# slam3r/blocks/multiview_blocks.py
import torch
import torch.nn as nn
from .basic_blocks import Mlp, Attention, CrossAttention, DropPath

try:
    import xformers.ops as xops
    XFORMERS_AVALIABLE = True
except ImportError:
    print("xformers not available, using self-implemented attention instead")
    XFORMERS_AVALIABLE = False


class XFormer_Attention(nn.Module):
    """Wrapper for a self-attention module that uses xformers.

    Attention scores are computed with xformers' memory_efficient_attention.
    When inference runs on the CPU, or when xformers is not installed, it falls
    back to the standard self-implemented attention.
    """
    def __init__(self, old_module: Attention):
        super().__init__()
        self.num_heads = old_module.num_heads
        self.scale = old_module.scale
        self.qkv = old_module.qkv
        self.attn_drop_prob = old_module.attn_drop.p
        self.proj = old_module.proj
        self.proj_drop = old_module.proj_drop
        self.rope = old_module.rope
        self.attn_drop = old_module.attn_drop

    def forward(self, x, xpos):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1, 3)
        q, k, v = [qkv[:, :, i] for i in range(3)]  # shape: (B, num_heads, N, C//num_heads)
        if self.rope is not None:
            q = self.rope(q, xpos)  # (B, H, N, K)
            k = self.rope(k, xpos)
        if x.is_cuda and XFORMERS_AVALIABLE:
            q = q.permute(0, 2, 1, 3)  # (B, N, H, K)
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)
            drop_prob = self.attn_drop_prob if self.training else 0
            x = xops.memory_efficient_attention(q, k, v, scale=self.scale, p=drop_prob)  # (B, N, H, K)
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2)
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
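
# Note (illustrative): an existing basic_blocks.Attention module can be wrapped in place,
# e.g. `self.attn = XFormer_Attention(self.attn)`, as MultiviewDecoderBlock_max.__init__
# does below. The wrapper reuses the original qkv/proj layers and dropout modules, so the
# swap introduces no new parameters and changes nothing when running on the CPU.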


class MultiviewDecoderBlock_max(nn.Module):
    """Multiview decoder block that takes as input an arbitrary number of
    source-view and target-view features.

    Max-pooling is used to merge the features queried from the different source views.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.norm3 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
        if XFORMERS_AVALIABLE:
            self.attn = XFormer_Attention(self.attn)

    def batched_cross_attn(self, xs, ys, xposes, yposes, rel_ids_list_d, M):
        """
        Calculate cross-attention between the Vx target views and the Vy source views in a single batch.
        """
        xs_normed = self.norm2(xs)
        ys_normed = self.norm_y(ys)
        cross_attn = self.cross_attn
        Vx, B, Nx, C = xs.shape
        Vy, B, Ny, C = ys.shape
        num_heads = cross_attn.num_heads
        # precompute q, k, v for each view to save computation
        qs = cross_attn.projq(xs_normed).reshape(Vx*B, Nx, num_heads, C//num_heads).permute(0, 2, 1, 3)  # (Vx*B, num_heads, Nx, C//num_heads)
        ks = cross_attn.projk(ys_normed).reshape(Vy*B, Ny, num_heads, C//num_heads).permute(0, 2, 1, 3)  # (Vy*B, num_heads, Ny, C//num_heads)
        vs = cross_attn.projv(ys_normed).reshape(Vy, B, Ny, num_heads, C//num_heads)  # (Vy, B, Ny, num_heads, C//num_heads)
        # add rope
        if cross_attn.rope is not None:
            qs = cross_attn.rope(qs, xposes)
            ks = cross_attn.rope(ks, yposes)
        qs = qs.permute(0, 2, 1, 3).reshape(Vx, B, Nx, num_heads, C//num_heads)  # (Vx, B, Nx, H, K)
        ks = ks.permute(0, 2, 1, 3).reshape(Vy, B, Ny, num_heads, C//num_heads)  # (Vy, B, Ny, H, K)
        # construct query, key, value for each target view
        ks_respect = torch.index_select(ks, 0, rel_ids_list_d)  # (Vx*M, B, Ny, H, K)
        vs_respect = torch.index_select(vs, 0, rel_ids_list_d)  # (Vx*M, B, Ny, H, K)
        qs_corresp = torch.unsqueeze(qs, 1).expand(-1, M, -1, -1, -1, -1)  # (Vx, M, B, Nx, H, K)
        ks_compact = ks_respect.reshape(Vx*M*B, Ny, num_heads, C//num_heads)
        vs_compact = vs_respect.reshape(Vx*M*B, Ny, num_heads, C//num_heads)
        qs_compact = qs_corresp.reshape(Vx*M*B, Nx, num_heads, C//num_heads)
        # calculate the attention results for all target views in one go
        if xs.is_cuda and XFORMERS_AVALIABLE:
            drop_prob = cross_attn.attn_drop.p if self.training else 0
            attn_outputs = xops.memory_efficient_attention(qs_compact, ks_compact, vs_compact,
                                                           scale=cross_attn.scale, p=drop_prob)  # (Vx*M*B, Nx, H, K)
        else:
            ks_compact = ks_compact.permute(0, 2, 1, 3)  # (Vx*M*B, H, Ny, K)
            qs_compact = qs_compact.permute(0, 2, 1, 3)  # (Vx*M*B, H, Nx, K)
            vs_compact = vs_compact.permute(0, 2, 1, 3)  # (Vx*M*B, H, Ny, K)
            attn = (qs_compact @ ks_compact.transpose(-2, -1)) * cross_attn.scale  # (Vx*M*B, H, Nx, Ny)
            attn = attn.softmax(dim=-1)  # (Vx*M*B, H, Nx, Ny)
            attn = cross_attn.attn_drop(attn)
            attn_outputs = (attn @ vs_compact).transpose(1, 2).reshape(Vx*M*B, Nx, num_heads, C//num_heads)  # (Vx*M*B, Nx, H, K)
        attn_outputs = attn_outputs.reshape(Vx, M, B, Nx, C)  # (Vx, M, B, Nx, C)
        attn_outputs = cross_attn.proj_drop(cross_attn.proj(attn_outputs))  # (Vx, M, B, Nx, C)
        return attn_outputs

    def forward(self, xs: torch.Tensor, ys: torch.Tensor,
                xposes: torch.Tensor, yposes: torch.Tensor,
                rel_ids_list_d: torch.Tensor, M: int):
        """Refine the features of Vx target views in parallel, using information from Vy source views.

        Args:
            xs: (Vx,B,S,D) features of the target views to refine (S: number of tokens, D: feature dimension)
            ys: (Vy,B,S,D) features of the source views to query from
            xposes: (Vx,B,S,2) token positions of the target views (used by rope)
            yposes: (Vy,B,S,2) token positions of the source views (used by rope)
            M: number of source views each target view queries from
            rel_ids_list_d: (Vx*M,) indices of the source views that each target view queries from

        For example:
            Suppose we have 3 target views and 4 source views,
            so xs should have shape (3,B,S,D) and ys should have shape (4,B,S,D).
            If we require xs[0] to query features from ys[0], ys[1],
                          xs[1] to query features from ys[2], ys[2] (duplicate indices are supported),
                          xs[2] to query features from ys[2], ys[3],
            then we should set M=2 and rel_ids_list_d=[0,1, 2,2, 2,3].
        """
        Vx, B, Nx, C = xs.shape
        # self-attention on each target view's features
        xs = xs.reshape(-1, *xs.shape[2:])  # (Vx*B, Nx, C)
        xposes = xposes.reshape(-1, *xposes.shape[2:])  # (Vx*B, Nx, 2)
        yposes = yposes.reshape(-1, *yposes.shape[2:])
        xs = xs + self.drop_path(self.attn(self.norm1(xs), xposes))  # (Vx*B, Nx, C)
        # each target view cross-attends to its source views to query features
        attn_outputs = self.batched_cross_attn(xs.reshape(Vx, B, Nx, C), ys, xposes, yposes, rel_ids_list_d, M)
        # max-pooling to aggregate the features queried from different source views
        merged_ys, indices = torch.max(attn_outputs, dim=1)  # (Vx, B, Nx, C)
        merged_ys = merged_ys.reshape(Vx*B, Nx, C)  # (Vx*B, Nx, C)
        xs = xs + self.drop_path(merged_ys)
        xs = xs + self.drop_path(self.mlp(self.norm3(xs)))  # (Vx*B, Nx, C)
        xs = xs.reshape(Vx, B, Nx, C)
        return xs
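

if __name__ == "__main__":
    # Minimal smoke test, mirroring the example in the forward() docstring:
    # 3 target views, 4 source views, M=2 source views per target view.
    # Sizes are arbitrary and rope is disabled, so the position tensors are dummies.
    Vx, Vy, B, S, D, M = 3, 4, 2, 16, 64, 2
    block = MultiviewDecoderBlock_max(dim=D, num_heads=4, rope=None)
    xs = torch.randn(Vx, B, S, D)
    ys = torch.randn(Vy, B, S, D)
    # token positions are only consumed when a rope module is provided
    xposes = torch.zeros(Vx, B, S, 2, dtype=torch.long)
    yposes = torch.zeros(Vy, B, S, 2, dtype=torch.long)
    rel_ids_list_d = torch.tensor([0, 1, 2, 2, 2, 3], dtype=torch.long)  # (Vx*M,)
    out = block(xs, ys, xposes, yposes, rel_ids_list_d, M)
    print(out.shape)  # expected: torch.Size([3, 2, 16, 64])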