|
|
from functools import partial

import torch
import torch.nn as nn

from spikingjelly.clock_driven import layer
from timm.models.layers import DropPath
# Default multi-spike cap: activations quantize to integers in [0, T].
T = 4
class multispike(torch.autograd.Function):
    """Multi-spike activation: clamp to [0, lens] and round to the nearest
    integer in the forward pass; rectangular surrogate gradient in backward."""

    @staticmethod
    def forward(ctx, input, lens=T):
        ctx.save_for_backward(input)
        ctx.lens = lens
        # floor(x + 0.5) rounds half up, giving integer spike counts 0..lens.
        return torch.floor(torch.clamp(input, 0, lens) + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Pass gradients only where the input fell inside the clamp window.
        inside = (0 < input) & (input < ctx.lens)
        return grad_input * inside.float(), None
class Multispike(nn.Module):
    """Module wrapper around `multispike` that rescales the integer spike
    count back to [0, 1] by dividing by the cap."""

    def __init__(self, spike=multispike, norm=T):
        super().__init__()
        self.spike = spike
        self.norm = norm

    def forward(self, inputs):
        return self.spike.apply(inputs) / self.norm
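# A minimal sanity check of the multi-spike nonlinearity (illustrative only,
# not used by the model): values quantize to {0, 1/T, ..., 1} and gradients
# flow only inside the clamp window.
def _multispike_demo():
    x = torch.tensor([-1.0, 0.3, 1.6, 5.0], requires_grad=True)
    y = Multispike()(x)
    print(y)       # tensor([0.0000, 0.0000, 0.5000, 1.0000]) with T = 4
    y.sum().backward()
    print(x.grad)  # rectangular surrogate: tensor([0.0000, 0.2500, 0.2500, 0.0000])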
|
|
|
|
|
|
|
|
def MS_conv_unit(in_channels, out_channels, kernel_size=1, padding=0, groups=1):
    """Conv2d + BatchNorm2d wrapped in SeqToANNContainer, which merges the
    time and batch axes of a (T, B, ...) input before the stateless layers
    and restores them afterwards."""
    return nn.Sequential(
        layer.SeqToANNContainer(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                      padding=padding, groups=groups, bias=True),
            nn.BatchNorm2d(out_channels),
        )
    )
class MS_ConvBlock(nn.Module):
    """Two spike-activated conv units with a residual connection."""

    def __init__(self, dim, mlp_ratio=4.0):
        super().__init__()
        self.neuron1 = Multispike()
        # Cast to int: mlp_ratio may be a float, but Conv2d needs integer channels.
        self.conv1 = MS_conv_unit(dim, int(dim * mlp_ratio), 3, 1)
        self.neuron2 = Multispike()
        self.conv2 = MS_conv_unit(int(dim * mlp_ratio), dim, 3, 1)

    def forward(self, x, mask=None):
        short_cut = x
        x = self.neuron1(x)
        x = self.conv1(x)
        x = self.neuron2(x)
        x = self.conv2(x)
        return x + short_cut
class MS_MLP(nn.Module):
    """Token-wise MLP implemented with 1x1 Conv1d layers, each preceded by a
    multi-spike activation. `drop` is accepted for API compatibility but unused."""

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.0):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = Multispike()

        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = Multispike()

        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        # x: (T, B, C, N) -- time, batch, channels, tokens.
        T, B, C, N = x.shape

        x = self.fc1_lif(x)
        # Merge time and batch so Conv1d/BatchNorm1d see (T*B, C, N).
        x = self.fc1_conv(x.flatten(0, 1))
        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()

        x = self.fc2_lif(x)
        x = self.fc2_conv(x.flatten(0, 1))
        x = self.fc2_bn(x).reshape(T, B, C, N).contiguous()

        return x
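# A minimal shape check for the time-flatten trick used throughout this file
# (illustrative helper with assumed sizes, not called by the model): layers
# see (T*B, C, N) and the time axis is restored afterwards.
def _mlp_shape_demo():
    mlp = MS_MLP(in_features=64, hidden_features=256)
    x = torch.randn(4, 2, 64, 49)          # (T, B, C, N)
    assert mlp(x).shape == (4, 2, 64, 49)  # channel count is preserved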
|
|
|
|
|
class RepConv(nn.Module):
    """Two stacked 1x1 Conv1d + BN layers with a 1.5x channel expansion in
    between; the "Rep" naming suggests the stack is meant to be folded into
    a single conv at inference time."""

    def __init__(self, in_channel, out_channel, bias=False):
        super().__init__()
        hidden = int(in_channel * 1.5)
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channel, hidden, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(hidden),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(hidden, out_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(out_channel),
        )

    def forward(self, x):
        return self.conv2(self.conv1(x))
class RepConv2(nn.Module):
    """Same as RepConv but without the channel expansion."""

    def __init__(self, in_channel, out_channel, bias=False):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channel, in_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(in_channel),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(in_channel, out_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(out_channel),
        )

    def forward(self, x):
        return self.conv2(self.conv1(x))
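# At inference each Conv1d + BN pair above can be folded into a single conv.
# A minimal fusion sketch (illustrative helper, not called by the model;
# assumes the BN's running statistics have been populated by training):
def _fuse_conv1d_bn(conv: nn.Conv1d, bn: nn.BatchNorm1d) -> nn.Conv1d:
    fused = nn.Conv1d(conv.in_channels, conv.out_channels,
                      conv.kernel_size[0], stride=conv.stride[0], bias=True)
    std = (bn.running_var + bn.eps).sqrt()
    # BN(Wx + b) = (gamma/std) * W x + beta + gamma * (b - mean) / std
    fused.weight.data.copy_(conv.weight * (bn.weight / std).reshape(-1, 1, 1))
    bias = bn.bias - bn.weight * bn.running_mean / std
    if conv.bias is not None:
        bias = bias + bn.weight * conv.bias / std
    fused.bias.data.copy_(bias)
    return fused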
|
|
|
|
|
class MS_Attention_Conv_qkv_id(nn.Module):
    """Spike-driven self-attention: q, k, v are spike-quantized, and k^T v is
    contracted first, so the cost is linear in the token count N. `qkv_bias`,
    `qk_scale`, and the dropout args are accepted for API compatibility but
    unused."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0.0, proj_drop=0.0, sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        self.scale = 0.125  # hard-coded 1/8, i.e. 1/sqrt(64) for 64-dim heads
        self.sr_ratio = sr_ratio

        self.head_lif = Multispike()

        self.q_conv = nn.Sequential(RepConv(dim, dim), nn.BatchNorm1d(dim))
        self.k_conv = nn.Sequential(RepConv(dim, dim), nn.BatchNorm1d(dim))
        # v (and hence the attention output) is widened by sr_ratio.
        self.v_conv = nn.Sequential(RepConv(dim, dim * sr_ratio), nn.BatchNorm1d(dim * sr_ratio))

        self.q_lif = Multispike()
        self.k_lif = Multispike()
        self.v_lif = Multispike()
        self.attn_lif = Multispike()

        self.proj_conv = nn.Sequential(RepConv(sr_ratio * dim, dim), nn.BatchNorm1d(dim))

    def forward(self, x):
        T, B, C, N = x.shape

        x = self.head_lif(x)
        x_for_qkv = x.flatten(0, 1)  # (T*B, C, N)

        q = self.q_conv(x_for_qkv).reshape(T, B, C, N)
        q = self.q_lif(q)
        # Split heads: (T, B, num_heads, N, C // num_heads).
        q = q.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4)

        k = self.k_conv(x_for_qkv).reshape(T, B, C, N)
        k = self.k_lif(k)
        k = k.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4)

        v = self.v_conv(x_for_qkv).reshape(T, B, self.sr_ratio * C, N)
        v = self.v_lif(v)
        v = v.transpose(-1, -2).reshape(T, B, N, self.num_heads, self.sr_ratio * C // self.num_heads).permute(0, 1, 3, 2, 4)

        # Linear attention: contract over tokens first (k^T v), then apply q.
        x = k.transpose(-2, -1) @ v
        x = (q @ x) * self.scale

        x = x.transpose(3, 4).reshape(T, B, self.sr_ratio * C, N)
        x = self.attn_lif(x)

        x = self.proj_conv(x.flatten(0, 1)).reshape(T, B, C, N)
        return x
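# Because k^T v is formed first, each head costs O(N * d_h^2) rather than the
# O(N^2 * d_h) of softmax attention, and no softmax is applied to the
# spike-valued activations. A minimal shape check (illustrative sizes):
def _attention_shape_demo():
    attn = MS_Attention_Conv_qkv_id(dim=64, num_heads=8)
    x = torch.randn(4, 2, 64, 49)            # (T, B, C, N)
    assert attn(x).shape == (4, 2, 64, 49)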
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MS_Block(nn.Module):
    """Transformer block: an optional RepConv2 token-mixing branch (only for
    the "base" variant), spike-driven attention, and a spiking MLP, all with
    residual connections. With `finetune=True`, per-channel scales are learned
    for the attention and MLP branches."""

    def __init__(
        self,
        dim,
        choice,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        sr_ratio=1,
        init_values=1e-6,
        finetune=False,
    ):
        super().__init__()
        self.model = choice
        if self.model == "base":
            self.rep_conv = RepConv2(dim, dim)
            self.lif = Multispike()
        self.attn = MS_Attention_Conv_qkv_id(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
        )
        self.finetune = finetune
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MS_MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)

        if self.finetune:
            self.layer_scale1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
            self.layer_scale2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        T, B, C, N = x.shape
        if self.model == "base":
            x = x + self.rep_conv(self.lif(x).flatten(0, 1)).reshape(T, B, C, N)

        if self.finetune:
            # Scale each residual branch channel-wise: (dim,) -> (1, 1, dim, 1).
            x = x + self.drop_path(self.attn(x) * self.layer_scale1.unsqueeze(0).unsqueeze(0).unsqueeze(-1))
            x = x + self.drop_path(self.mlp(x) * self.layer_scale2.unsqueeze(0).unsqueeze(0).unsqueeze(-1))
        else:
            x = x + self.attn(x)
            x = x + self.mlp(x)
        return x
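# Usage note (illustrative, assumed sizes): choice="base" adds the extra
# RepConv2 token-mixing branch, while other choices (e.g. "large") skip it;
# finetune=True enables LayerScale-style residual scaling, initialized near
# zero (init_values=1e-6) so the pretrained mapping is barely perturbed at
# the start of fine-tuning. For example:
#   blk = MS_Block(dim=512, choice="base", num_heads=8, finetune=True)
#   y = blk(torch.randn(1, 2, 512, 196))   # -> (1, 2, 512, 196)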
|
|
|
|
|
|
|
|
class MS_DownSampling(nn.Module):
    """Strided-conv downsampling; every stage except the first is preceded
    by a multi-spike activation (the first layer sees raw pixels)."""

    def __init__(
        self,
        in_channels=2,
        embed_dims=256,
        kernel_size=3,
        stride=2,
        padding=1,
        first_layer=True,
    ):
        super().__init__()
        self.encode_conv = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.encode_bn = nn.BatchNorm2d(embed_dims)
        if not first_layer:
            self.encode_lif = Multispike()

    def forward(self, x):
        T, B, _, _, _ = x.shape
        if hasattr(self, "encode_lif"):
            x = self.encode_lif(x)
        x = self.encode_conv(x.flatten(0, 1))
        _, _, H, W = x.shape
        x = self.encode_bn(x).reshape(T, B, -1, H, W).contiguous()
        return x
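# Spatial trace for a 224x224 input to the Spikformer below (the ImageNet
# setting assumed in __main__): 224 -> 112 (downsample1_1, k7 s2 p3)
# -> 56 (downsample1_2) -> 28 (downsample2) -> 14 (downsample3), so the
# transformer stage runs on N = 14 * 14 = 196 tokens.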
|
|
|
|
|
|
|
|
class Spikformer(nn.Module):
    def __init__(
        self,
        T=1,
        choice=None,
        img_size_h=224,
        img_size_w=224,
        patch_size=16,
        embed_dim=[128, 256, 512, 640],
        num_heads=8,
        mlp_ratios=4,
        in_channels=3,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        num_classes=1000,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=8,
        sr_ratios=1,
        mlp_ratio=4.0,
        nb_classes=1000,
        kd=True,
    ):
        super().__init__()

        self.T = T
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        # Stochastic-depth rates, increasing linearly with block depth.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]

        # Stage 1: two downsampling steps, one conv block after each.
        self.downsample1_1 = MS_DownSampling(
            in_channels=in_channels,
            embed_dims=embed_dim[0] // 2,
            kernel_size=7,
            stride=2,
            padding=3,
            first_layer=True,
        )
        self.ConvBlock1_1 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[0] // 2, mlp_ratio=mlp_ratios)]
        )

        self.downsample1_2 = MS_DownSampling(
            in_channels=embed_dim[0] // 2,
            embed_dims=embed_dim[0],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )
        self.ConvBlock1_2 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[0], mlp_ratio=mlp_ratios)]
        )

        # Stage 2: one downsampling step followed by two conv blocks.
        self.downsample2 = MS_DownSampling(
            in_channels=embed_dim[0],
            embed_dims=embed_dim[1],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )
        self.ConvBlock2_1 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
        )
        self.ConvBlock2_2 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
        )

        # Stage 3: final downsampling followed by the transformer blocks.
        self.downsample3 = MS_DownSampling(
            in_channels=embed_dim[1],
            embed_dims=embed_dim[2],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )
        self.block3 = nn.ModuleList(
            [
                MS_Block(
                    dim=embed_dim[2],
                    choice=choice,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratios,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[j],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios,
                    finetune=True,
                )
                for j in range(depths)
            ]
        )

        # Main head (sized by nb_classes) plus an optional distillation head
        # (sized by num_classes) fed by spiking features.
        self.head = nn.Linear(embed_dim[2], nb_classes)
        self.lif = Multispike(norm=1)  # norm=1 keeps raw spike counts
        self.kd = kd
        if self.kd:
            self.head_kd = (
                nn.Linear(embed_dim[-1], num_classes)
                if num_classes > 0
                else nn.Identity()
            )
        self.initialize_weights()
    def initialize_weights(self):
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def forward_encoder(self, x):
        # Replicate the static image along the time axis: (B,C,H,W) -> (T,B,C,H,W).
        x = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)

        x = self.downsample1_1(x)
        for blk in self.ConvBlock1_1:
            x = blk(x)

        x = self.downsample1_2(x)
        for blk in self.ConvBlock1_2:
            x = blk(x)

        x = self.downsample2(x)
        for blk in self.ConvBlock2_1:
            x = blk(x)
        for blk in self.ConvBlock2_2:
            x = blk(x)

        x = self.downsample3(x)
        # Flatten the spatial grid into tokens: (T,B,C,H,W) -> (T,B,C,N).
        x = x.flatten(3)

        for blk in self.block3:
            x = blk(x)

        return x
    def forward(self, imgs):
        x = self.forward_encoder(imgs)

        # Global average pool over tokens: (T, B, C, N) -> (T, B, C).
        x = x.mean(3)
        x_lif = self.lif(x)
        x = self.head(x).mean(0)  # average logits over time steps

        if self.kd:
            x_kd = self.head_kd(x_lif).mean(0)
            if self.training:
                return x, x_kd
            else:
                # At inference, average the main and distillation heads.
                return (x + x_kd) / 2
        return x
def spikformer12_512(**kwargs):
    model = Spikformer(
        T=1,
        choice="base",
        img_size_h=32,
        img_size_w=32,
        patch_size=16,
        embed_dim=[128, 256, 512],
        num_heads=8,
        mlp_ratios=4,
        in_channels=3,
        num_classes=1000,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=12,
        **kwargs,
    )
    return model
def spikformer12_768(**kwargs):
    model = Spikformer(
        T=1,
        choice="large",
        img_size_h=32,
        img_size_w=32,
        patch_size=16,
        embed_dim=[196, 384, 768],
        num_heads=8,
        mlp_ratios=4,
        in_channels=3,
        num_classes=1000,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=12,
        **kwargs,
    )
    return model
if __name__ == "__main__":
    model = spikformer12_512()
    print(f"number of params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
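    # Quick smoke test (illustrative): one forward pass on a random batch in
    # eval mode, which also exercises the averaged dual-head output path.
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(2, 3, 224, 224))
    print(f"output shape: {logits.shape}")  # expected: torch.Size([2, 1000])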
|
|
|
|
|
|