# Spatial-temporal-ERF/models/qk_model_with_delay/delay_synaptic_func_inter.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
def set_sigma_for_DCLS(model, s):
for name, module in model.named_modules():
if module.__class__.__name__ == 'DelayConv':
if hasattr(module, 'sigma'):
module.sigma = s
print('Set sigma to ',s)
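# Illustrative usage (sketch): anneal sigma from a training loop; `model`, `base_sigma`,
# and `epoch` below are assumed names, not defined in this file.
#   set_sigma_for_DCLS(model, max(1.0, base_sigma * 0.9 ** epoch))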
class DropoutNd(nn.Module):
def __init__(self, p: float = 0.5, tie=True, transposed=True):
"""
tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d)
"""
super().__init__()
if p < 0 or p >= 1:
raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p))
self.p = p
self.tie = tie
self.transposed = transposed
self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p)
def forward(self, X):
"""X: (batch, dim, lengths...)."""
if self.training:
if not self.transposed: X = rearrange(X, 'b ... d -> b d ...')
# binomial = torch.distributions.binomial.Binomial(probs=1-self.p) # This is incredibly slow because of CPU -> GPU copying
mask_shape = X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape
# mask = self.binomial.sample(mask_shape)
mask = torch.rand(*mask_shape, device=X.device) < 1. - self.p
X = X * mask * (1.0 / (1 - self.p))
if not self.transposed: X = rearrange(X, 'b d ... -> b ... d')
return X
return X
class DelayConv(nn.Module):
def __init__(
self,
in_c,
k,
dropout=0.0,
n_delay=1,
dilation=1,
kernel_type='triangle_r_temp'
):
super().__init__()
        self.C = in_c  # number of input and output channels
self.win_len = k
self.dilation = dilation
self.n_delay = n_delay
self.kernel_type = kernel_type
self.t = torch.arange(self.win_len).float().unsqueeze(0) # [1, k]
self.sigma = self.win_len // 2
self.delay_kernel = None
self.bump = None
        # ========== d: one learnable delay per (C_out, C_in) pair, shape [C_out, C_in, n_delay] ==========
d = torch.rand(self.C, self.C, self.n_delay)
with torch.no_grad():
for co in range(self.C):
for ci in range(self.C):
d[co, ci, :] = torch.randperm(self.win_len - 2)[:self.n_delay] + 1
self.register("d", d, lr=1e-2)
        # Initialize weights [C_out, C_in, k] with a geometric decay toward earlier taps:
        # weight[..., i] = weight[..., i + 1] / 2
weight = torch.ones([self.C, self.C, k])
with torch.no_grad():
for co in range(self.C): # output channel
for ci in range(self.C): # input channel
for i in range(k - 2, -1, -1):
weight[co, ci, i] = weight[co, ci, i + 1] / 2
self.weight = nn.Parameter(weight)
self.dropout = nn.Dropout(dropout / 5) if dropout > 0.0 else nn.Identity()
def register(self, name, tensor, lr=None):
"""注册可训练或固定参数"""
if lr == 0.0:
self.register_buffer(name, tensor)
else:
self.register_parameter(name, nn.Parameter(tensor))
optim = {"weight_decay": 0}
if lr is not None:
optim["lr"] = lr
setattr(getattr(self, name), "_optim", optim)
def update_kernel(self, device):
"""
        Build the delay kernel: shape [C_out, C_in, k]
"""
t = self.t.to(device).view(1, 1, 1, -1) # [1,1,1,k]
d = self.d.to(device) # [C_out, C_in, n_delay]
        # ---------- compute the bump functions ----------
if self.kernel_type == 'gauss':
bump = torch.exp(-0.5 * ((t - self.win_len + d.unsqueeze(-1) + 1) / self.sigma) ** 2)
bump = (bump - 1e-3).relu() + 1e-3
bump = bump / (bump.sum(dim=-1, keepdim=True) + 1e-7)
elif self.kernel_type == 'triangle':
bump = torch.relu(1 - torch.abs((t - self.win_len + d.unsqueeze(-1) + 1) / self.sigma))
bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)
elif self.kernel_type == 'triangle_r':
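            # Straight-through rounding: the forward pass uses round(d) while
            # gradients flow to the continuous delay d.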
d_int = (d.round() - d).detach() + d
bump = torch.relu(1 - torch.abs((t - self.win_len + d_int.unsqueeze(-1) + 1) / self.sigma))
bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)
elif self.kernel_type == 'triangle_r_temp':
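            # Temperature-scaled straight-through rounding: scale = min(1, 1/sigma),
            # so a small sigma snaps d to integer delays while a large sigma keeps d mostly continuous.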
scale = min(1.0, 1.0 / self.sigma)
d_int = (d.round() - d).detach() * scale + d
bump = torch.relu(1 - torch.abs((t - self.win_len + d_int.unsqueeze(-1) + 1) / self.sigma))
bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7) # [C_out, C_in, n_delay, k]
            # ------ harden bump into a one-hot delay in eval mode ------
            if not self.training:
                max_idx = bump.argmax(dim=-1, keepdim=True)  # index of the peak tap
hard_mask = torch.zeros_like(bump)
hard_mask.scatter_(-1, max_idx, 1.0)
bump = bump * hard_mask
# --------------------------------
else:
raise ValueError(f"Unknown kernel_type: {self.kernel_type}")
# bump: [C_out, C_in, n_delay, k]
self.bump = bump.detach().clone().to(device)
        # ---------- sum over the n_delay dimension: [C_out, C_in, k] ----------
bump_sum = bump.sum(dim=2)
        # ---------- build the final convolution kernel ----------
# weight: [C_out, C_in, k]
self.delay_kernel = (self.weight * bump_sum).to(device) # [C_out, C_in, k]
def forward(self, x):
"""
        x: (T, B, C, N)
return: (T*B, C, N)
"""
        # rearrange dimensions
x = x.permute(0, 1, 3, 2).contiguous() # (T, B, N, C)
T, B, N, C = x.shape
assert C == self.C, f"Input channel mismatch: {C} vs {self.C}"
x = x.permute(1, 2, 3, 0).contiguous() # (B, N, C, T)
        # merge B * N into the batch dimension
x_reshaped = x.view(B * N, C, T) # (B*N, C, T)
device = x.device
        # update the delay kernel
self.update_kernel(device) # -> [C_out, C_in, k]
kernel = self.delay_kernel
# padding
pad_left = (self.win_len - 1) * self.dilation
x_padded = F.pad(x_reshaped, (pad_left, 0)) # (B*N, C, T+pad)
        # convolution across all channels: groups=1 (cross-channel interaction)
y = F.conv1d(x_padded, kernel, stride=1, dilation=self.dilation, groups=1) # (B*N, C, T)
        # restore the original layout
y = y.view(B, N, C, T).permute(3, 0, 2, 1).contiguous().view(-1, C, N) # (T*B, C, N)
return self.dropout(y)
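

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch): the shapes and hyperparameters below
    # are assumptions for demonstration, not values taken from the rest of the project.
    T, B, C, N = 16, 2, 8, 4
    layer = DelayConv(in_c=C, k=5, n_delay=2, kernel_type='triangle_r_temp')
    x = torch.randn(T, B, C, N)     # forward expects (T, B, C, N)
    y = layer(x)
    print(y.shape)                  # expected: (T * B, C, N) == (32, 8, 4)
    set_sigma_for_DCLS(layer, 1.0)  # e.g. shrink sigma while annealing during training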