import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


def set_sigma_for_DCLS(model, s):
    """Set the sigma (interpolation width) of every DelayConv module in `model`."""
    for name, module in model.named_modules():
        if module.__class__.__name__ == 'DelayConv':
            if hasattr(module, 'sigma'):
                module.sigma = s
    print('Set sigma to', s)


class DropoutNd(nn.Module):
    def __init__(self, p: float = 0.5, tie=True, transposed=True):
        """
        tie: tie the dropout mask across sequence lengths (Dropout1d/2d/3d behaviour)
        """
        super().__init__()
        if p < 0 or p >= 1:
            raise ValueError(f"dropout probability has to be in [0, 1), but got {p}")
        self.p = p
        self.tie = tie
        self.transposed = transposed
        # Kept for compatibility; forward() draws Bernoulli masks with torch.rand instead.
        self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p)

    def forward(self, X):
        """X: (batch, dim, lengths...)."""
        if self.training:
            if not self.transposed:
                X = rearrange(X, 'b ... d -> b d ...')
            # One mask per (batch, channel) when tie=True, broadcast over the length dims;
            # otherwise an independent mask per element.
            mask_shape = X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape
            mask = torch.rand(*mask_shape, device=X.device) < 1. - self.p
            # Inverted dropout: rescale so the expected activation is unchanged.
            X = X * mask * (1.0 / (1 - self.p))
            if not self.transposed:
                X = rearrange(X, 'b d ... -> b ... d')
            return X
        return X

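
# A minimal usage sketch (illustrative, not part of the original code): with tie=True and
# transposed=True, DropoutNd zeroes whole channels per sample and shares the mask along the
# sequence axis. The shapes below are assumptions chosen only for illustration.
#
#     drop = DropoutNd(p=0.1, tie=True, transposed=True)
#     x = torch.randn(4, 8, 128)   # (batch, dim, length)
#     y = drop(x)                  # same shape; mask broadcast along the length axis

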
class DelayConv(nn.Module):
    def __init__(
        self,
        in_c,
        k,
        dropout=0.0,
        n_delay=1,
        dilation=1,
        kernel_type='triangle_r_temp'
    ):
        super().__init__()
        self.C = in_c
        self.win_len = k
        self.dilation = dilation
        self.n_delay = n_delay
        self.kernel_type = kernel_type

        # Time axis of the kernel window; moved to the input's device in update_kernel().
        self.t = torch.arange(self.win_len).float().unsqueeze(0)
        # Width of the interpolation bump around each delay; can be annealed via set_sigma_for_DCLS.
        self.sigma = self.win_len // 2

        self.delay_kernel = None
        self.bump = None

        # Learnable delays, one set per (out-channel, in-channel) pair, initialised to
        # distinct integer positions in [1, k - 2] (requires n_delay <= k - 2).
        d = torch.rand(self.C, self.C, self.n_delay)
        with torch.no_grad():
            for co in range(self.C):
                for ci in range(self.C):
                    d[co, ci, :] = torch.randperm(self.win_len - 2)[:self.n_delay] + 1
        self.register("d", d, lr=1e-2)

        # Per-tap weights, decaying by a factor of 2 per step into the past
        # (the most recent tap has weight 1).
        weight = torch.ones([self.C, self.C, k])
        with torch.no_grad():
            for co in range(self.C):
                for ci in range(self.C):
                    for i in range(k - 2, -1, -1):
                        weight[co, ci, i] = weight[co, ci, i + 1] / 2

        self.weight = nn.Parameter(weight)

        # Dropout rate is deliberately scaled down for the delay branch.
        self.dropout = nn.Dropout(dropout / 5) if dropout > 0.0 else nn.Identity()

    def register(self, name, tensor, lr=None):
        """Register a tensor either as a fixed buffer (lr == 0.0) or as a trainable
        parameter with a per-parameter optimizer hint (no weight decay, optional lr)."""
        if lr == 0.0:
            self.register_buffer(name, tensor)
        else:
            self.register_parameter(name, nn.Parameter(tensor))

            optim = {"weight_decay": 0}
            if lr is not None:
                optim["lr"] = lr
            setattr(getattr(self, name), "_optim", optim)

    def update_kernel(self, device):
        """
        Build the delay kernel, shape [C_out, C_in, k].
        """
        t = self.t.to(device).view(1, 1, 1, -1)   # (1, 1, 1, k) time grid
        d = self.d.to(device)                     # (C_out, C_in, n_delay) learnable delays

        # Each delay d places a bump centred at position (win_len - 1 - d) on the time grid.
        if self.kernel_type == 'gauss':
            bump = torch.exp(-0.5 * ((t - self.win_len + d.unsqueeze(-1) + 1) / self.sigma) ** 2)
            bump = (bump - 1e-3).relu() + 1e-3
            bump = bump / (bump.sum(dim=-1, keepdim=True) + 1e-7)

        elif self.kernel_type == 'triangle':
            bump = torch.relu(1 - torch.abs((t - self.win_len + d.unsqueeze(-1) + 1) / self.sigma))
            bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)

        elif self.kernel_type == 'triangle_r':
            # Straight-through rounding: integer delays in the forward pass, real-valued gradients.
            d_int = (d.round() - d).detach() + d
            bump = torch.relu(1 - torch.abs((t - self.win_len + d_int.unsqueeze(-1) + 1) / self.sigma))
            bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)

        elif self.kernel_type == 'triangle_r_temp':
            # Temperature-controlled straight-through rounding: the rounding strength grows as sigma shrinks.
            scale = min(1.0, 1.0 / self.sigma)
            d_int = (d.round() - d).detach() * scale + d
            bump = torch.relu(1 - torch.abs((t - self.win_len + d_int.unsqueeze(-1) + 1) / self.sigma))
            bump = bump / (bump.sum(dim=-1, keepdim=True).detach() + 1e-7)

            # At eval time, collapse each bump to a single hard tap at its argmax.
            if not self.training:
                max_idx = bump.argmax(dim=-1, keepdim=True)
                hard_mask = torch.zeros_like(bump)
                hard_mask.scatter_(-1, max_idx, 1.0)
                bump = bump * hard_mask

        else:
            raise ValueError(f"Unknown kernel_type: {self.kernel_type}")

        self.bump = bump.detach().clone().to(device)

        # Sum the bumps over the delay dimension: (C_out, C_in, n_delay, k) -> (C_out, C_in, k).
        bump_sum = bump.sum(dim=2)

        # Final kernel: per-tap weights gated by the delay bumps.
        self.delay_kernel = (self.weight * bump_sum).to(device)

    def forward(self, x):
        """
        x: (T, B, C, N)
        return: (T*B, C, N)
        """
        # (T, B, C, N) -> (T, B, N, C)
        x = x.permute(0, 1, 3, 2).contiguous()
        T, B, N, C = x.shape
        assert C == self.C, f"Input channel mismatch: {C} vs {self.C}"
        # (T, B, N, C) -> (B, N, C, T), then fold batch and spatial dims for conv1d.
        x = x.permute(1, 2, 3, 0).contiguous()
        x_reshaped = x.view(B * N, C, T)
        device = x.device

        # Rebuild the delay kernel from the current delays and sigma.
        self.update_kernel(device)
        kernel = self.delay_kernel

        # Left-only (causal) padding so each output step only sees past inputs.
        pad_left = (self.win_len - 1) * self.dilation
        x_padded = F.pad(x_reshaped, (pad_left, 0))

        y = F.conv1d(x_padded, kernel, stride=1, dilation=self.dilation, groups=1)

        # (B*N, C, T) -> (T, B, C, N) -> (T*B, C, N)
        y = y.view(B, N, C, T).permute(3, 0, 2, 1).contiguous().view(-1, C, N)

        return self.dropout(y)
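

# Minimal smoke test (illustrative, not part of the original code). The shapes and
# hyper-parameters below are assumptions chosen only to exercise the module.
if __name__ == "__main__":
    T, B, C, N = 4, 2, 8, 16   # time steps, batch, channels, spatial size (assumed)
    layer = DelayConv(in_c=C, k=7, dropout=0.1, n_delay=2, kernel_type='triangle_r_temp')

    x = torch.randn(T, B, C, N)
    y = layer(x)
    print(y.shape)             # expected: (T*B, C, N)

    # Anneal sigma (e.g. over training) to sharpen the bumps toward hard integer delays.
    set_sigma_for_DCLS(layer, 1.0)
    layer.eval()
    with torch.no_grad():
        y_eval = layer(x)
    print(y_eval.shape)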