# Copyright 2024 Lnyan (https://github.com/lkwq007). All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax import Array as Tensor

import flax
from flax import nnx
import flax.linen


def fake_init(key, feature_shape, param_dtype):
    # Initializer that returns an abstract jax.ShapeDtypeStruct placeholder instead of a
    # real array, so modules can be declared without allocating parameter memory.
    return jax.ShapeDtypeStruct(feature_shape, param_dtype)
# torch.nn-style constructor shims that build flax.nnx modules with abstract (fake_init) parameters.
def wrap_LayerNorm(dim, *, eps=1e-5, elementwise_affine=True, bias=True, rngs: nnx.Rngs):
    return nnx.LayerNorm(dim, epsilon=eps, use_bias=elementwise_affine and bias, use_scale=elementwise_affine, bias_init=fake_init, scale_init=fake_init, rngs=rngs)

def wrap_Linear(dim, inner_dim, *, bias=True, rngs: nnx.Rngs):
    return nnx.Linear(dim, inner_dim, use_bias=bias, kernel_init=fake_init, bias_init=fake_init, rngs=rngs)

def wrap_GroupNorm(num_groups, num_channels, *, eps=1e-5, affine=True, rngs: nnx.Rngs):
    # torch.nn.GroupNorm takes (num_groups, num_channels); nnx.GroupNorm takes num_features first.
    return nnx.GroupNorm(num_channels, num_groups=num_groups, epsilon=eps, use_bias=affine, use_scale=affine, bias_init=fake_init, scale_init=fake_init, rngs=rngs)

def wrap_Conv(in_channels, out_channels, kernel_size, *, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', rngs: nnx.Rngs, conv_dim: int):
    if isinstance(kernel_size, int):
        kernel_tuple = (kernel_size,) * conv_dim
    else:
        assert len(kernel_size) == conv_dim
        kernel_tuple = kernel_size
    # dilation, groups, and padding_mode are currently ignored; a fuller mapping would also
    # pass kernel_dilation=dilation and feature_group_count=groups to nnx.Conv.
    return nnx.Conv(in_channels, out_channels, kernel_tuple, strides=stride, padding=padding, use_bias=bias, kernel_init=fake_init, bias_init=fake_init, rngs=rngs)
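# Usage sketch (an assumption, not from the original file): because every *_init above is
# fake_init, the constructed modules carry jax.ShapeDtypeStruct placeholders instead of real
# arrays, presumably to be filled in later (e.g. with weights converted from a torch checkpoint).
#
#     rngs = nnx.Rngs(0)
#     lin = wrap_Linear(4, 8, rngs=rngs)
#     lin.kernel.value          # ShapeDtypeStruct(shape=(4, 8), dtype=float32)
#     conv = wrap_Conv(3, 16, 3, stride=2, padding=1, rngs=rngs, conv_dim=2)
#     conv.kernel.value.shape   # (3, 3, 3, 16)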
class nn_GELU(nnx.Module):
    def __init__(self, approximate="none") -> None:
        # torch.nn.GELU takes approximate as a string ("none"/"tanh"); nnx.gelu takes a bool.
        self.approximate = approximate == "tanh"

    def __call__(self, x):
        return nnx.gelu(x, approximate=self.approximate)


class nn_SiLU(nnx.Module):
    def __init__(self) -> None:
        pass

    def __call__(self, x):
        return nnx.silu(x)


class nn_AvgPool(nnx.Module):
    def __init__(self, window_shape, strides=None, padding="VALID") -> None:
        self.window_shape = window_shape
        self.strides = strides
        self.padding = padding

    def __call__(self, x):
        return flax.linen.avg_pool(x, window_shape=self.window_shape, strides=self.strides, padding=self.padding)
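# Usage sketch (an assumption, not from the original file): flax.linen.avg_pool pools over
# the spatial axes of a channels-last (NHWC) array, unlike torch.nn.AvgPool2d's NCHW layout.
#
#     pool = nn_AvgPool((2, 2), strides=(2, 2))
#     pool(jnp.ones((1, 8, 8, 3))).shape   # (1, 4, 4, 3)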
# A wrapper class exposing torch.nn-style constructors that build flax.nnx modules.
class TorchWrapper:
    def __init__(self, rngs: nnx.Rngs, dtype=jnp.float32):
        self.rngs = rngs
        self.dtype = dtype

    def declare_with_rng(self, *args):
        # Bind this wrapper's dtype and rngs to the given module factories.
        ret = [partial(f, dtype=self.dtype, rngs=self.rngs) for f in args]
        return ret if len(ret) > 1 else ret[0]

    def conv_nd(self, dims, *args, **kwargs):
        return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=dims)

    def avg_pool(self, *args, **kwargs):
        return nn_AvgPool(*args, **kwargs)

    def linear(self, *args, **kwargs):
        return self.Linear(*args, **kwargs)

    def SiLU(self):
        return nn_SiLU()

    def GELU(self, approximate="none"):
        return nn_GELU(approximate)

    def Identity(self):
        return lambda x: x

    def LayerNorm(self, dim, eps=1e-5, elementwise_affine=True, bias=True):
        return wrap_LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, bias=bias, rngs=self.rngs)

    def GroupNorm(self, *args, **kwargs):
        return wrap_GroupNorm(*args, **kwargs, rngs=self.rngs)

    def Linear(self, *args, **kwargs):
        return wrap_Linear(*args, **kwargs, rngs=self.rngs)

    def Parameter(self, value):
        return nnx.Param(value)

    def Dropout(self, p):
        return nnx.Dropout(rate=p, rngs=self.rngs)

    def Sequential(self, *args):
        return nnx.Sequential(*args)

    def Conv1d(self, *args, **kwargs):
        return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=1)

    def Conv2d(self, *args, **kwargs):
        return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=2)

    def Conv3d(self, *args, **kwargs):
        return wrap_Conv(*args, **kwargs, rngs=self.rngs, conv_dim=3)

    def ModuleList(self, lst=None):
        if lst is None:
            return []
        return list(lst)

    def Module(self, *args, **kwargs):
        # Stand-in for an empty torch.nn.Module container.
        return nnx.Dict()
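if __name__ == "__main__":
    # Minimal sketch (an assumption about intended use, not part of the original file):
    # build layers through the torch-style wrapper. The layers are only constructed and
    # inspected, not called, because fake_init leaves their parameters as abstract
    # jax.ShapeDtypeStruct placeholders rather than real arrays.
    nn = TorchWrapper(nnx.Rngs(0), dtype=jnp.float32)
    block = nn.Sequential(nn.Linear(16, 32), nn.SiLU(), nn.Linear(32, 16))
    conv = nn.Conv2d(64, 64, 3, stride=1, padding=1)
    print(block.layers[0].kernel.value)   # ShapeDtypeStruct(shape=(16, 32), dtype=float32)
    print(conv.kernel.value.shape)        # (3, 3, 64, 64)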