| | from __future__ import annotations |
| |
|
| | from dataclasses import dataclass |
| | from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | |
| |
|
def _get_activation(name: Optional[str]) -> nn.Module:
    """Build a fresh activation module from its (case-insensitive) name.

    Args:
        name: One of ``"relu"``, ``"gelu"``, ``"silu"``, ``"swish"`` (alias
            of SiLU), ``"tanh"``, ``"sigmoid"``, ``"leaky_relu"`` (negative
            slope 0.2), ``"elu"``, ``"mish"``, ``"softplus"``,
            ``"identity"`` or ``"none"``. ``None`` is treated like
            ``"identity"``.

    Returns:
        A newly constructed ``nn.Module``; every call returns a new instance.

    Raises:
        ValueError: If ``name`` is not a recognised activation.
    """
    if name is None:
        return nn.Identity()

    # Map names to zero-argument factories so that only the requested
    # module is instantiated, instead of building every activation per call.
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "swish": nn.SiLU,
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
        "leaky_relu": lambda: nn.LeakyReLU(0.2),
        "elu": nn.ELU,
        "mish": nn.Mish,
        "softplus": nn.Softplus,
        "identity": nn.Identity,
        # Accept the "none" spelling too, matching _get_norm's convention.
        "none": nn.Identity,
    }

    name = name.lower()
    try:
        factory = factories[name]
    except KeyError:
        raise ValueError(f"Unknown activation: {name}") from None
    return factory()
| |
|
| |
|
def _get_norm(name: Optional[str], num_features: int) -> nn.Module:
    """Build a normalization layer for ``num_features`` channels.

    Args:
        name: Case-insensitive layer kind: ``"batch"``, ``"layer"``,
            ``"instance"``, ``"group"`` or ``"none"``. ``None`` is treated
            like ``"none"``.
        num_features: Feature/channel count passed to the layer.

    Returns:
        ``nn.Identity`` for ``None``/``"none"``; otherwise the matching
        ``BatchNorm1d``, ``LayerNorm``, ``InstanceNorm1d`` or ``GroupNorm``.
        For ``"group"``, the largest group count not exceeding 8 that
        divides ``num_features`` is used; if only 1 divides it, a
        ``LayerNorm`` is returned instead.

    Raises:
        ValueError: If ``name`` is not a recognised normalization.
    """
    if name is None:
        return nn.Identity()

    # Lowercase before testing for "none" so that "None"/"NONE" behave like
    # every other (case-insensitive) option instead of raising.
    name = name.lower()
    if name == "none":
        return nn.Identity()
    if name == "batch":
        return nn.BatchNorm1d(num_features)
    if name == "layer":
        return nn.LayerNorm(num_features)
    if name == "instance":
        return nn.InstanceNorm1d(num_features)
    if name == "group":
        # Start from at most 8 groups and walk down to the first count that
        # divides num_features evenly (GroupNorm requires divisibility).
        groups = max(1, min(8, num_features))
        while groups > 1 and num_features % groups != 0:
            groups -= 1
        if groups == 1:
            # No usable divisor other than 1: fall back to LayerNorm.
            return nn.LayerNorm(num_features)
        return nn.GroupNorm(groups, num_features)
    raise ValueError(f"Unknown normalization: {name}")
| |
|
| |
|
def _flatten_3d_to_2d(x: torch.Tensor) -> Tuple[torch.Tensor, Optional[Tuple[int, int]]]:
    """Merge the first two axes of a 3-D tensor into one.

    Returns:
        A pair ``(tensor, hint)``. For a 3-D input of shape ``(b, t, f)``
        the tensor is reshaped to ``(b * t, f)`` and ``hint`` is ``(b, t)``.
        Any other rank is passed through untouched with ``hint`` set to
        ``None``.
    """
    # Guard clause: only rank-3 tensors are flattened.
    if x.dim() != 3:
        return x, None

    lead, mid, feat = x.shape
    flat = x.reshape(lead * mid, feat)
    return flat, (lead, mid)
| |
|
| |
|
def _maybe_restore_3d(x: torch.Tensor, shape_hint: Optional[Tuple[int, int]]) -> torch.Tensor:
    """Reshape ``x`` back to three axes when a shape hint is supplied.

    Args:
        x: Tensor whose last axis is kept as the feature axis.
        shape_hint: ``(b, t)`` pair giving the two leading sizes to
            restore, or ``None`` to return ``x`` unchanged.

    Returns:
        ``x`` reshaped to ``(b, t, x.shape[-1])``, or ``x`` itself when
        ``shape_hint`` is ``None``.
    """
    if shape_hint is not None:
        lead, mid = shape_hint
        # The trailing size is read from x itself, not from the hint.
        return x.reshape(lead, mid, x.shape[-1])
    return x