from functools import partial

import torch
import torch.nn as nn
import torchinfo

from timm.models.layers import trunc_normal_, DropPath
from timm.models.vision_transformer import Block

from spikingjelly.clock_driven import layer

import models.encoder as encoder
from .util.pos_embed import get_2d_sincos_pos_embed

# Global number of quantization levels (virtual time steps) for the spike
# activation.
T = 4


class multispike(torch.autograd.Function):
    """Quantized multi-spike activation with a straight-through-style surrogate
    gradient: the forward pass rounds the clamped input to an integer spike
    count in [0, lens]; the backward pass lets the gradient through wherever
    0 < input < lens."""

    @staticmethod
    def forward(ctx, input, lens=T):
        ctx.save_for_backward(input)
        ctx.lens = lens
        return torch.floor(torch.clamp(input, 0, lens) + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        temp1 = 0 < input
        temp2 = input < ctx.lens
        return grad_input * temp1.float() * temp2.float(), None

class Multispike(nn.Module):
    def __init__(self, spike=multispike, norm=T):
        super().__init__()
        self.lens = norm
        self.spike = spike
        self.norm = norm

    def forward(self, inputs):
        # Integer spike counts in [0, T], normalized to [0, 1].
        return self.spike.apply(inputs) / self.norm

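# Quantization sketch: with the default T = 4 levels, Multispike maps membrane
# potentials to {0, 1/4, 2/4, 3/4, 1}:
#
#     Multispike()(torch.tensor([-0.3, 0.3, 1.7, 5.2]))
#     # -> tensor([0.0000, 0.0000, 0.5000, 1.0000])
#
# (-0.3 and 5.2 are clamped to [0, 4] before rounding; 1.7 rounds to 2 spikes,
# i.e. 2 / 4 = 0.5 after normalization.)
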
def MS_conv_unit(in_channels, out_channels, kernel_size=1, padding=0, groups=1):
    """Sparse Conv2d + BN, wrapped so that a leading time dimension is folded
    into the batch dimension (SeqToANNContainer)."""
    return nn.Sequential(
        layer.SeqToANNContainer(
            encoder.SparseConv2d(in_channels, out_channels, kernel_size=kernel_size,
                                 padding=padding, groups=groups, bias=True),
            encoder.SparseBatchNorm2d(out_channels),
        )
    )


class MS_ConvBlock(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0):
        super().__init__()

        self.neuron1 = Multispike()
        # Cast to int: mlp_ratio is a float, but channel counts must be integers.
        self.conv1 = MS_conv_unit(dim, int(dim * mlp_ratio), 3, 1)

        self.neuron2 = Multispike()
        self.conv2 = MS_conv_unit(int(dim * mlp_ratio), dim, 3, 1)

    def forward(self, x, mask=None):
        short_cut = x
        x = self.neuron1(x)
        x = self.conv1(x)
        x = self.neuron2(x)
        x = self.conv2(x)
        x = x + short_cut
        return x

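# Shape sketch (hypothetical sizes, assuming the sparse conv/BN layers behave
# like their dense counterparts when encoder._cur_active is set appropriately):
#
#     blk = MS_ConvBlock(dim=64, mlp_ratio=4.0)
#     y = blk(torch.randn(4, 2, 64, 56, 56))   # (T, B, C, H, W)
#     # y.shape == (4, 2, 64, 56, 56): the block expands to int(4.0 * 64)
#     # channels internally and projects back before the residual sum.
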
class MS_MLP(nn.Module):
    def __init__(
        self, in_features, hidden_features=None, out_features=None, drop=0.0, layer=0
    ):
        # `layer` is unused; kept for API compatibility.
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1_conv = nn.Conv1d(in_features, hidden_features, kernel_size=1, stride=1)
        self.fc1_bn = nn.BatchNorm1d(hidden_features)
        self.fc1_lif = Multispike()

        self.fc2_conv = nn.Conv1d(hidden_features, out_features, kernel_size=1, stride=1)
        self.fc2_bn = nn.BatchNorm1d(out_features)
        self.fc2_lif = Multispike()

        self.c_hidden = hidden_features
        self.c_output = out_features

    def forward(self, x):
        T, B, C, N = x.shape

        x = self.fc1_lif(x)
        x = self.fc1_conv(x.flatten(0, 1))
        x = self.fc1_bn(x).reshape(T, B, self.c_hidden, N).contiguous()

        x = self.fc2_lif(x)
        x = self.fc2_conv(x.flatten(0, 1))
        x = self.fc2_bn(x).reshape(T, B, C, N).contiguous()

        return x

class RepConv(nn.Module):
    def __init__(
        self,
        in_channel,
        out_channel,
        bias=False,
    ):
        super().__init__()
        # Two stacked pointwise (1x1) Conv1d + BN stages with a 1.5x expansion.
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channel, int(in_channel * 1.5), kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(int(in_channel * 1.5)),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(int(in_channel * 1.5), out_channel, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm1d(out_channel),
        )

    def forward(self, x):
        return self.conv2(self.conv1(x))


# RepConv2 was a verbatim duplicate of RepConv; keep it as an alias.
RepConv2 = RepConv

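# Re-parameterization note: each RepConv stage is an affine map (1x1 Conv1d
# followed by BatchNorm1d), so at inference time the two stages could in
# principle be fused into a single 1x1 convolution. A sketch of the algebra,
# with the 1x1 kernels viewed as (out, in) matrices and BN folded into W, b:
#
#     W_fused = W2 @ W1
#     b_fused = W2 @ b1 + b2
#
# This file keeps the unfused training-time form.
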
class MS_Attention_Conv_qkv_id(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        self.scale = 0.125
        self.sr_ratio = sr_ratio

        self.head_lif = Multispike()

        self.q_conv = nn.Sequential(RepConv(dim, dim), nn.BatchNorm1d(dim))
        self.k_conv = nn.Sequential(RepConv(dim, dim), nn.BatchNorm1d(dim))
        self.v_conv = nn.Sequential(RepConv(dim, dim * sr_ratio), nn.BatchNorm1d(dim * sr_ratio))

        self.q_lif = Multispike()
        self.k_lif = Multispike()
        self.v_lif = Multispike()
        self.attn_lif = Multispike()

        self.proj_conv = nn.Sequential(RepConv(sr_ratio * dim, dim), nn.BatchNorm1d(dim))

    def forward(self, x):
        T, B, C, N = x.shape

        x = self.head_lif(x)
        x_for_qkv = x.flatten(0, 1)  # (T*B, C, N)

        q = self.q_conv(x_for_qkv).reshape(T, B, C, N)
        q = self.q_lif(q)
        q = q.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4)

        k = self.k_conv(x_for_qkv).reshape(T, B, C, N)
        k = self.k_lif(k)
        k = k.transpose(-1, -2).reshape(T, B, N, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4)

        v = self.v_conv(x_for_qkv).reshape(T, B, self.sr_ratio * C, N)
        v = self.v_lif(v)
        v = v.transpose(-1, -2).reshape(T, B, N, self.num_heads, self.sr_ratio * C // self.num_heads).permute(0, 1, 3, 2, 4)

        # Linear-complexity attention: contract K with V first, then apply Q.
        x = k.transpose(-2, -1) @ v   # (T, B, heads, d, d_v)
        x = (q @ x) * self.scale      # (T, B, heads, N, d_v)

        x = x.transpose(3, 4).reshape(T, B, self.sr_ratio * C, N)
        x = self.attn_lif(x)

        x = self.proj_conv(x.flatten(0, 1)).reshape(T, B, C, N)
        return x

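# Complexity note: the attention above computes
#
#     out = Q @ (K^T @ V) * scale
#
# which costs O(N * d^2) per head instead of the O(N^2 * d) of the softmax
# form softmax(Q K^T) V. Because the spike-quantized Q/K/V use no softmax,
# the two orderings are equivalent by associativity, and the cheaper one
# (contracting over tokens first) is used.
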
class MS_DownSampling(nn.Module):
    def __init__(
        self,
        in_channels=2,
        embed_dims=256,
        kernel_size=3,
        stride=2,
        padding=1,
        first_layer=True,
    ):
        super().__init__()

        self.encode_conv = encoder.SparseConv2d(
            in_channels,
            embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.encode_bn = encoder.SparseBatchNorm2d(embed_dims)
        self.first_layer = first_layer
        # The very first layer receives raw pixels, so it has no spike activation.
        if not first_layer:
            self.encode_spike = Multispike()

    def forward(self, x):
        T, B, _, _, _ = x.shape
        if hasattr(self, "encode_spike"):
            x = self.encode_spike(x)
        x = self.encode_conv(x.flatten(0, 1))
        _, _, H, W = x.shape
        x = self.encode_bn(x).reshape(T, B, -1, H, W)
        return x

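# Shape arithmetic (a sketch): with kernel_size=7, stride=2, padding=3 a
# 224 x 224 input maps to floor((224 + 2*3 - 7) / 2) + 1 = 112, so each
# MS_DownSampling stage halves the spatial resolution. The four stages used
# below (224 -> 112 -> 56 -> 28 -> 14) yield the 14 x 14 = 196-token grid
# consumed by the transformer blocks.
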
class MS_Block(nn.Module):
    def __init__(
        self,
        dim,
        choice,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        sr_ratio=1,
        init_values=1e-6,
        finetune=False,
    ):
        super().__init__()
        self.model = choice
        if self.model == "base":
            self.rep_conv = RepConv2(dim, dim)
            self.lif = Multispike()
        self.attn = MS_Attention_Conv_qkv_id(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio,
        )
        self.finetune = finetune
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MS_MLP(in_features=dim, hidden_features=mlp_hidden_dim, drop=drop)

        if self.finetune:
            self.layer_scale1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
            self.layer_scale2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        # x: (T, B, C, N); unpack the shape here rather than relying on globals.
        T, B, C, N = x.shape
        if self.model == "base":
            x = x + self.rep_conv(self.lif(x).flatten(0, 1)).reshape(T, B, C, N)

        if self.finetune:
            x = x + self.drop_path(self.attn(x) * self.layer_scale1.unsqueeze(0).unsqueeze(0).unsqueeze(-1))
            x = x + self.drop_path(self.mlp(x) * self.layer_scale2.unsqueeze(0).unsqueeze(0).unsqueeze(-1))
        else:
            x = x + self.attn(x)
            x = x + self.mlp(x)
        return x

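# A shape sketch for MS_Block (hypothetical values): inputs are flattened
# token maps of shape (T, B, C, N), e.g. (1, 2, 512, 196); every residual
# branch preserves that shape, so blocks stack freely. With finetune=True the
# attn/MLP branches are additionally scaled per channel by the learnable
# layer_scale parameters and regularized with DropPath.
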
class Spikmae(nn.Module):
    def __init__(self, T=1, choice=None,
                 img_size_h=224,
                 img_size_w=224,
                 patch_size=16,
                 embed_dim=[128, 256, 512],
                 num_heads=8,
                 mlp_ratios=4,
                 in_channels=3,
                 qk_scale=None,
                 drop_rate=0.0,
                 attn_drop_rate=0.0,
                 drop_path_rate=0.0,
                 num_classes=1000,
                 qkv_bias=False,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 depths=8,
                 sr_ratios=1,
                 decoder_embed_dim=768,
                 decoder_depth=4,
                 decoder_num_heads=16,
                 mlp_ratio=4.,
                 norm_pix_loss=False, nb_classes=1000):
        super().__init__()

        self.num_classes = num_classes
        self.depths = depths
        self.T = T

        # Stochastic depth decay rule.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths)]

        self.downsample1_1 = MS_DownSampling(
            in_channels=in_channels,
            embed_dims=embed_dim[0] // 2,
            kernel_size=7,
            stride=2,
            padding=3,
            first_layer=True,
        )

        self.ConvBlock1_1 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[0] // 2, mlp_ratio=mlp_ratios)]
        )

        self.downsample1_2 = MS_DownSampling(
            in_channels=embed_dim[0] // 2,
            embed_dims=embed_dim[0],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )

        self.ConvBlock1_2 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[0], mlp_ratio=mlp_ratios)]
        )

        self.downsample2 = MS_DownSampling(
            in_channels=embed_dim[0],
            embed_dims=embed_dim[1],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )

        self.ConvBlock2_1 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
        )

        self.ConvBlock2_2 = nn.ModuleList(
            [MS_ConvBlock(dim=embed_dim[1], mlp_ratio=mlp_ratios)]
        )

        self.downsample3 = MS_DownSampling(
            in_channels=embed_dim[1],
            embed_dims=embed_dim[2],
            kernel_size=3,
            stride=2,
            padding=1,
            first_layer=False,
        )

        self.block3 = nn.ModuleList(
            [
                MS_Block(
                    dim=embed_dim[2],
                    choice=choice,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratios,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[j],
                    norm_layer=norm_layer,
                    sr_ratio=sr_ratios,
                    finetune=False,
                )
                for j in range(depths)
            ]
        )

        self.norm = nn.BatchNorm1d(embed_dim[-1])
        self.downsample_ratio = 16  # total spatial reduction of the encoder

        num_patches = 196  # (224 // 16) ** 2 for 224 x 224 inputs

        # Encoder positional embedding (fixed sin-cos; initialized below but
        # not added in forward_encoder, where the convolutions carry position).
        self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim[-1], num_patches), requires_grad=False)

        # MAE-style lightweight decoder.
        self.decoder_embed = nn.Linear(embed_dim[-1], decoder_embed_dim, bias=True)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches, decoder_embed_dim), requires_grad=False)
        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=False, norm_layer=norm_layer)
            for i in range(decoder_depth)])
        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_channels, bias=True)
        self.initialize_weights()

    def initialize_weights(self):
        num_patches = 196
        # Fixed 2-D sin-cos positional embeddings for encoder and decoder.
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[1], int(num_patches ** .5),
                                            cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed.transpose(1, 0)).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1],
                                                    int(num_patches ** .5), cls_token=False)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        torch.nn.init.normal_(self.mask_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio):
        """
        Per-sample random masking by per-sample shuffling (argsort of random
        noise).
        x: [T, B, C, H, W] input batch; only the batch size is used here.
        Returns ids_keep, the binary `active` mask (1 = visible patch,
        0 = masked patch, in original patch order), and ids_restore.
        """
        num_patches = 196
        T, N, _, _, _ = x.shape
        L = num_patches
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1)

        ids_shuffle = torch.argsort(noise, dim=1)  # ascending: small noise = keep
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]

        # "Active" mask in shuffled order (1 = keep, 0 = remove), then mapped
        # back to the original patch order with ids_restore.
        active = torch.ones([N, L], device=x.device)
        active[:, len_keep:] = 0
        active = torch.gather(active, dim=1, index=ids_restore)

        return ids_keep, active, ids_restore

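    # Masking arithmetic (a sketch): with L = 196 patches and mask_ratio = 0.5,
    # len_keep = int(196 * 0.5) = 98 patches remain visible. ids_restore is the
    # inverse permutation of ids_shuffle, so gathering a shuffled-order pattern
    # with ids_restore returns it to the original patch order:
    #
    #     noise = torch.rand(2, 196)
    #     ids_shuffle = noise.argsort(1)
    #     ids_restore = ids_shuffle.argsort(1)
    #     assert (ids_shuffle.gather(1, ids_restore) == torch.arange(196)).all()
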
    def forward_encoder(self, x, mask_ratio=1.0):
        # Add the time dimension: (B, C, H, W) -> (T, B, C, H, W).
        x = x.unsqueeze(0).repeat(self.T, 1, 1, 1, 1)

        ids_keep, active, ids_restore = self.random_masking(x, mask_ratio)
        B, N = active.shape
        active_b1ff = active.reshape(B, 1, 14, 14)  # 14 x 14 patch grid

        # Sparse masking: zero out masked regions and let the sparse conv/BN
        # layers skip them via the module-level active map.
        encoder._cur_active = active_b1ff
        active_hw = active_b1ff.repeat_interleave(self.downsample_ratio, 2).repeat_interleave(self.downsample_ratio, 3)
        active_hw = active_hw.unsqueeze(0)
        x = x * active_hw  # zero out masked pixels

        x = self.downsample1_1(x)
        for blk in self.ConvBlock1_1:
            x = blk(x)
        x = self.downsample1_2(x)
        for blk in self.ConvBlock1_2:
            x = blk(x)

        x = self.downsample2(x)
        for blk in self.ConvBlock2_1:
            x = blk(x)
        for blk in self.ConvBlock2_2:
            x = blk(x)

        x = self.downsample3(x)
        x = x.flatten(3)  # (T, B, C, H*W)
        for blk in self.block3:
            x = blk(x)

        x = x.mean(0)  # average over time steps
        x = self.norm(x).transpose(-1, -2).contiguous()  # (B, N, C)
        return x, active, ids_restore, active_hw

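    # Unlike the original MAE encoder, which drops masked tokens and runs the
    # transformer on the visible tokens only, forward_encoder above follows a
    # SparK-style recipe: the full image is kept, masked pixels are zeroed,
    # and encoder._cur_active lets the sparse conv/BN layers restrict
    # computation and statistics to the visible regions.
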
    def forward_decoder(self, x, ids_restore):
        x = self.decoder_embed(x)

        # The sparse encoder keeps all patch positions, so normally zero mask
        # tokens are appended here (ids_restore.shape[1] == x.shape[1]). Note
        # that the gather then applies ids_restore as a permutation of the
        # full sequence; this MAE-style un-shuffle is kept as in the original
        # implementation.
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
        x = torch.cat([x, mask_tokens], dim=1)
        x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))

        x = x + self.decoder_pos_embed

        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)
        x = self.decoder_pred(x)

        return x

    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 * 3)
        """
        p = 16
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 * 3)
        imgs: (N, 3, H, W)
        """
        p = 16
        h = w = int(x.shape[1] ** .5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs

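    # Round-trip sketch (assumes a constructed model instance `m`): patchify
    # and unpatchify are inverse reshapes for square inputs whose side is a
    # multiple of the fixed 16-pixel patch size:
    #
    #     imgs = torch.randn(2, 3, 224, 224)
    #     assert torch.equal(m.unpatchify(m.patchify(imgs)), imgs)
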
    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L] "active" mask, 1 = visible patch, 0 = masked patch.
        The reconstruction loss is averaged over the masked patches only.
        """
        inp, rec = self.patchify(imgs), pred
        # Per-patch normalization of the target (as in MAE's norm_pix_loss);
        # note `var` is in fact the standard deviation.
        mean = inp.mean(dim=-1, keepdim=True)
        var = (inp.var(dim=-1, keepdim=True) + 1e-6) ** .5
        inp = (inp - mean) / var
        l2_loss = ((rec - inp) ** 2).mean(dim=2, keepdim=False)
        non_active = mask.logical_not().int().view(mask.shape[0], -1)  # 1 = masked
        recon_loss = l2_loss.mul_(non_active).sum() / (non_active.sum() + 1e-8)
        return recon_loss, mean, var

    def forward(self, imgs, mask_ratio=0.5, vis=False):
        latent, active, ids_restore, active_hw = self.forward_encoder(imgs, mask_ratio)
        rec = self.forward_decoder(latent, ids_restore)
        recon_loss, mean, var = self.forward_loss(imgs, rec, active)
        if vis:
            # Visualization: the masked input, and the reconstruction pasted
            # into the masked regions of the original image.
            masked_bchw = imgs * active_hw.flatten(0, 1)
            rec_bchw = self.unpatchify(rec * var + mean)  # de-normalize prediction
            rec_or_inp = torch.where(active_hw.flatten(0, 1).bool(), imgs, rec_bchw)
            return imgs, masked_bchw, rec_or_inp
        else:
            return recon_loss


def spikmae_12_512(**kwargs):
    model = Spikmae(
        T=1,
        choice="base",
        img_size_h=224,
        img_size_w=224,
        patch_size=16,
        embed_dim=[128, 256, 512],
        num_heads=8,
        mlp_ratios=4,
        in_channels=3,
        num_classes=1000,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=12,
        sr_ratios=1, decoder_embed_dim=256, decoder_depth=4, decoder_num_heads=4,
        **kwargs)
    return model


def spikmae_12_768(**kwargs):
    model = Spikmae(
        T=1,
        choice="large",
        img_size_h=224,
        img_size_w=224,
        patch_size=16,
        embed_dim=[192, 384, 768],
        num_heads=8,
        mlp_ratios=4,
        in_channels=3,
        num_classes=1000,
        qkv_bias=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        depths=12,
        sr_ratios=1, decoder_embed_dim=256, decoder_depth=4, decoder_num_heads=4,
        **kwargs)
    return model


if __name__ == "__main__":
    model = spikmae_12_768()
    x = torch.randn(1, 3, 224, 224)
    loss = model(x, mask_ratio=0.50)
    print('loss', loss)
    torchinfo.summary(model, (1, 3, 224, 224))
    print(f"number of params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
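    # A minimal visualization sketch (assumes matplotlib is installed and that
    # `x` holds real images rather than random noise):
    #
    #     import matplotlib.pyplot as plt
    #     model.eval()
    #     with torch.no_grad():
    #         imgs, masked, recon = model(x, mask_ratio=0.50, vis=True)
    #     plt.imshow(recon[0].permute(1, 2, 0).clamp(0, 1).numpy())
    #     plt.show()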
|