import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


def set_sigma_for_DCLS(model, s):
    """Set the sigma of every DelayConv module in `model` to `s`."""
    for name, module in model.named_modules():
        if module.__class__.__name__ == 'DelayConv':
            if hasattr(module, 'sigma'):
                module.sigma = s
                print('Set sigma to', s)


class DropoutNd(nn.Module):
    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError("dropout probability has to be in [0, 1), "
                             "but got {}".format(p))
        self.p = p
        self.tie = tie
        self.transposed = transposed
        # Kept from an earlier implementation; sampling from it every forward
        # pass was too slow (see comment below), so forward() uses torch.rand.
        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)."""
        if self.training:
            if not self.transposed:
                X = rearrange(X, 'b ... d -> b d ...')
            # binomial = torch.distributions.binomial.Binomial(probs=1-self.p)
            # This is incredibly slow because of CPU -> GPU copying
            mask_shape = X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape
            # mask = self.binomial.sample(mask_shape)
            mask = torch.rand(*mask_shape, device=X.device) < 1. - self.p
            X = X * mask * (1.0 / (1 - self.p))
            if not self.transposed:
                X = rearrange(X, 'b d ... -> b ... d')
            return X
        return X
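# A minimal usage sketch for DropoutNd (the shapes below are illustrative
# assumptions, not values from this file). With tie=True the mask has shape
# (batch, dim, 1, ...), so each channel is kept or dropped as a whole, and
# surviving values are rescaled by 1/(1-p) during training:
#
#     drop = DropoutNd(p=0.1, tie=True, transposed=True)
#     x = torch.randn(4, 8, 100)   # (batch, dim, length)
#     y = drop(x)                  # same shape as x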
class DelayConv(nn.Module):
    def __init__(self, in_c, k, dropout=0.0, n_delay=1, dilation=1,
                 kernel_type='triangle_r_temp'):
        super().__init__()
        self.C = in_c  # number of input and output channels
        self.win_len = k
        self.dilation = dilation
        self.n_delay = n_delay
        self.kernel_type = kernel_type
        self.t = torch.arange(self.win_len).float().unsqueeze(0)  # [1, k]
        self.sigma = self.win_len // 2
        self.delay_kernel = None
        self.bump = None

        # Learnable delays d, shape [C_out, C_in, n_delay], initialized to
        # distinct integer positions in [1, win_len - 2].
        d = torch.rand(self.C, self.C, self.n_delay)
        with torch.no_grad():
            for co in range(self.C):
                for ci in range(self.C):
                    d[co, ci, :] = torch.randperm(self.win_len - 2)[:self.n_delay] + 1
        self.register("d", d, lr=1e-2)

        # Initialize weights, shape [C_out, C_in, k]: each tap is half of the
        # tap to its right, so weights decay geometrically toward the past.
        weight = torch.ones([self.C, self.C, k])
        with torch.no_grad():
            for co in range(self.C):  # output channel
                for ci in range(self.C):  # input channel
                    for i in range(k - 2, -1, -1):
                        weight[co, ci, i] = weight[co, ci, i + 1] / 2
        self.weight = nn.Parameter(weight)

        self.dropout = nn.Dropout(dropout / 5) if dropout > 0.0 else nn.Identity()

    def register(self, name, tensor, lr=None):
        """Register a trainable parameter, or a fixed buffer when lr == 0.0."""
        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

        optim = {"weight_decay": 0}
        if lr is not None:
            optim["lr"] = lr
        setattr(getattr(self, name), "_optim", optim)

    def update_kernel(self, device):
        """Build the delay kernel: shape [C_out, C_in, k]."""
        t = self.t.to(device).view(1, 1, 1, -1)  # [1, 1, 1, k]
        d = self.d.to(device)                    # [C_out, C_in, n_delay]

        # ---------- compute the bump around each delay ----------
        if self.kernel_type == 'gauss':
            bump = torch.exp(-0.5 * ((t - self.win_len + d.unsqueeze(-1) + 1) / self.sigma) ** 2)
            bump = (bump - 1e-3).relu() + 1e-3
            bump = bump / (bump.sum(dim=-1, keepdim=True) + 1e-7)
        elif self.kernel_type == 'triangle':
            bump = torch.relu(1 - torch.abs((t - self.win_len + d.unsqueeze(-1) + 1) / self.sigma))
            bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)
        elif self.kernel_type == 'triangle_r':
            # Straight-through rounding: forward uses d.round(), backward
            # passes the gradient through to the continuous d.
            d_int = (d.round() - d).detach() + d
            bump = torch.relu(1 - torch.abs((t - self.win_len + d_int.unsqueeze(-1) + 1) / self.sigma))
            bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)
        elif self.kernel_type == 'triangle_r_temp':
            # Tempered straight-through rounding: the rounding correction is
            # scaled down as sigma grows.
            scale = min(1.0, 1.0 / self.sigma)
            d_int = (d.round() - d).detach() * scale + d
            bump = torch.relu(1 - torch.abs((t - self.win_len + d_int.unsqueeze(-1) + 1) / self.sigma))
            bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)  # [C_out, C_in, n_delay, k]
            # ------ harden the bump in eval mode ------
            if not self.training:
                max_idx = bump.argmax(dim=-1, keepdim=True)  # index of the maximum
                hard_mask = torch.zeros_like(bump)
                hard_mask.scatter_(-1, max_idx, 1.0)
                bump = bump * hard_mask
            # -------------------------------------------
        else:
            raise ValueError(f"Unknown kernel_type: {self.kernel_type}")

        # bump: [C_out, C_in, n_delay, k]
        self.bump = bump.detach().clone().to(device)

        # ---------- sum over the n_delay dimension: [C_out, C_in, k] ----------
        bump_sum = bump.sum(dim=2)

        # ---------- build the final convolution kernel ----------
        # weight: [C_out, C_in, k]
        self.delay_kernel = (self.weight * bump_sum).to(device)  # [C_out, C_in, k]

    def forward(self, x):
        """
        x: (T, B, C, N)
        return: (T*B, C, N)
        """
        x = x.permute(0, 1, 3, 2).contiguous()  # (T, B, N, C)
        T, B, N, C = x.shape
        assert C == self.C, f"Input channel mismatch: {C} vs {self.C}"
        x = x.permute(1, 2, 3, 0).contiguous()  # (B, N, C, T)

        # fold B*N into the batch dimension
        x_reshaped = x.view(B * N, C, T)  # (B*N, C, T)
        device = x.device

        # rebuild the kernel from the current delays
        self.update_kernel(device)  # -> [C_out, C_in, k]
        kernel = self.delay_kernel

        # causal (left-only) padding
        pad_left = (self.win_len - 1) * self.dilation
        x_padded = F.pad(x_reshaped, (pad_left, 0))  # (B*N, C, T+pad)

        # full cross-channel convolution: groups=1
        y = F.conv1d(x_padded, kernel, stride=1, dilation=self.dilation, groups=1)  # (B*N, C, T)

        # restore the original layout
        y = y.view(B, N, C, T).permute(3, 0, 2, 1).contiguous().view(-1, C, N)  # (T*B, C, N)
        return self.dropout(y)
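# A minimal smoke test (an illustrative sketch; the shapes T, B, N, C and the
# hyperparameters below are assumptions, not values from this file). It runs
# one forward pass of DelayConv and then adjusts sigma via set_sigma_for_DCLS,
# as one might do when annealing sigma over training.
if __name__ == "__main__":
    torch.manual_seed(0)
    T, B, N, C, k = 12, 2, 5, 3, 8
    conv = DelayConv(in_c=C, k=k, n_delay=2)
    x = torch.randn(T, B, C, N)    # (T, B, C, N), as expected by forward()
    y = conv(x)
    print(y.shape)                 # torch.Size([24, 3, 5]) == (T*B, C, N)
    set_sigma_for_DCLS(conv, 2.0)  # shrink sigma to sharpen the bumps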