import math
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding for scalar diffusion timesteps."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        if half_dim == 0:
            # dim == 1: fall back to a single sine feature.
            return torch.sin(x).unsqueeze(-1)
        elif half_dim == 1:
            # dim == 2: a single frequency with sin and cos features.
            emb = x[:, None] * 1.0
            return torch.cat((emb.sin(), emb.cos()), dim=-1)
        else:
            # Standard log-spaced frequencies, as in transformer position encodings.
            emb = math.log(10000) / (half_dim - 1)
            emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
            emb = x[:, None] * emb[None, :]
            emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
            return emb
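
# A minimal shape check for SinusoidalPosEmb (illustrative sketch; the batch
# size and dim below are arbitrary assumptions, not values the model requires):
#
#     pos_emb = SinusoidalPosEmb(dim=64)
#     out = pos_emb(torch.rand(8))  # -> torch.Size([8, 64])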


class FilmLayer(nn.Module):
    """Feature-wise linear modulation (FiLM): predicts a per-channel scale and
    bias from a conditioning vector and applies them to a 5D feature map."""

    def __init__(self, embedding_dim, num_channels):
        super().__init__()
        # The MLP emits 2 * num_channels values: one scale and one bias per channel.
        self.mlp = nn.Sequential(nn.Linear(embedding_dim, num_channels * 2), nn.ReLU())

    def forward(self, x, context):
        mlp_out = self.mlp(context)
        # Split the MLP output into its scale and bias halves.
        scale = mlp_out[:, : x.shape[1]]
        bias = mlp_out[:, x.shape[1] :]

        # Broadcast over the depth, height, and width axes of (B, C, D, H, W).
        scale = scale.view(x.shape[0], x.shape[1], 1, 1, 1)
        bias = bias.view(x.shape[0], x.shape[1], 1, 1, 1)

        # (1 + scale) keeps the modulation close to identity at initialization.
        return (1.0 + scale) * x + bias
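
# A minimal shape check for FilmLayer (illustrative sketch; all sizes below are
# arbitrary assumptions):
#
#     film = FilmLayer(embedding_dim=128, num_channels=32)
#     x = torch.randn(2, 32, 6, 16, 16)  # (B, C, D, H, W)
#     ctx = torch.randn(2, 128)          # (B, embedding_dim)
#     out = film(x, ctx)                 # -> same shape as x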


class ResNetBlock3D(nn.Module):
    """
    A 3D ResNet block with FiLM conditioning.

    FiLM is applied only to the frames after the first ``context_frames``
    frames along the depth axis, so the conditioning context frames pass
    through unmodulated.
    """

    def __init__(
        self, in_channels: int, out_channels: int, embedding_dim: int, context_frames: int
    ):
        super().__init__()
        self.context_frames = context_frames

        self.conv1 = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=(3, 3, 3),
            padding=(1, 1, 1),
            bias=False,
        )
        self.bn1 = nn.Identity()  # placeholder: no normalization after conv1
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(
            out_channels,
            out_channels,
            kernel_size=(3, 3, 3),
            padding=(1, 1, 1),
            bias=False,
        )
        self.bn2 = nn.InstanceNorm3d(out_channels, affine=True)

        self.film = FilmLayer(embedding_dim, out_channels)

        # 1x1x1 projection on the residual path when the channel count changes.
        self.shortcut = nn.Sequential()
        if in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.InstanceNorm3d(out_channels, affine=True),
            )

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        h = self.relu(self.bn1(self.conv1(x)))

        # Split along depth: context frames stay untouched, only the noisy
        # frames receive FiLM modulation.
        h_context = h[:, :, : self.context_frames, :, :]
        h_noisy = h[:, :, self.context_frames :, :, :]

        h_noisy_filmed = self.film(h_noisy, context)

        h = torch.cat([h_context, h_noisy_filmed], dim=2)

        h = self.bn2(self.conv2(h))
        return self.relu(h + self.shortcut(x))
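
# A minimal shape check for ResNetBlock3D (illustrative sketch; all sizes below
# are arbitrary assumptions):
#
#     block = ResNetBlock3D(in_channels=8, out_channels=16,
#                           embedding_dim=128, context_frames=4)
#     x = torch.randn(2, 8, 6, 32, 32)  # depth 6 = 4 context + 2 noisy frames
#     ctx = torch.randn(2, 128)
#     out = block(x, ctx)               # -> torch.Size([2, 16, 6, 32, 32])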


class UNet_DCAE_3D(nn.Module):
    """
    A 3D U-Net that only performs spatial down/up-sampling (depth is preserved),
    with FiLM conditioning.

    The conditioning tensor ``t`` has shape (B, context_dim): the first
    context_dim - 1 columns are a context vector passed through unchanged,
    and the last column is a scalar diffusion time that is expanded with a
    sinusoidal embedding. Note that the decoder channel arithmetic (skip
    channels feature * 2 plus incoming channels feature * 2 equal feature * 4)
    assumes consecutive entries of ``features`` differ by a factor of two.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        features: List[int] = [32, 64, 128, 256],
        context_dim: int = 4,
        embedding_dim: int = 128,
        context_frames: int = 4,
        num_additional_resnet_blocks: int = 0,
        time_emb_dim: int = 64,
    ):
        super().__init__()
        self.features = features
        self.context_dim = context_dim
        self.embedding_dim = embedding_dim
        self.context_frames = context_frames
        self.num_additional_resnet_blocks = num_additional_resnet_blocks
        self.time_emb_dim = time_emb_dim

        # The time MLP consumes the context vector (context_dim - 1 values)
        # concatenated with the sinusoidal time embedding.
        time_mlp_input_dim = context_dim - 1 + self.time_emb_dim
        self.time_mlp = nn.Sequential(
            nn.Linear(time_mlp_input_dim, 128), nn.ReLU(), nn.Linear(128, embedding_dim)
        )
        self.time_emb = SinusoidalPosEmb(dim=self.time_emb_dim)

        self.encoder_convs = nn.ModuleList()
        self.decoder_convs = nn.ModuleList()
        self.downs = nn.ModuleList()

        # Encoder: each stage doubles the channel budget and halves H and W;
        # the (1, 2, 2) pooling leaves the depth axis untouched.
        current_channels = in_channels
        for feature in features:
            self.encoder_convs.append(
                ResNetBlock3D(
                    current_channels, feature * 2, embedding_dim, self.context_frames
                )
            )
            self.downs.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
            current_channels = feature * 2

        # Bottleneck at the coarsest spatial resolution.
        bottleneck_channels = features[-1] * 2
        self.bottleneck = ResNetBlock3D(
            bottleneck_channels, bottleneck_channels, embedding_dim, self.context_frames
        )

        # Decoder: each stage consumes the upsampled features concatenated
        # with the matching encoder skip connection.
        for feature in reversed(features):
            self.decoder_convs.append(
                ResNetBlock3D(feature * 4, feature, embedding_dim, self.context_frames)
            )

        # Optional extra refinement blocks after each decoder stage.
        self.additional_resnet_blocks = nn.ModuleList()
        for feature in reversed(features):
            blocks = nn.ModuleList()
            for _ in range(self.num_additional_resnet_blocks):
                blocks.append(
                    ResNetBlock3D(feature, feature, embedding_dim, self.context_frames)
                )
            self.additional_resnet_blocks.append(blocks)

        self.final_conv = nn.Conv3d(features[0], out_channels, kernel_size=(1, 1, 1))

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Split the conditioning tensor: the last column is the diffusion
        # time, the rest is the context vector.
        time_val = t[:, -1]
        emb = self.time_emb(time_val)
        spatial = t[:, :-1]
        combined = torch.cat([spatial, emb], dim=1)
        context = self.time_mlp(combined)
        skip_connections = []

        # Encoder path.
        for i in range(len(self.features)):
            x = self.encoder_convs[i](x, context)
            skip_connections.append(x)
            x = self.downs[i](x)

        x = self.bottleneck(x, context)

        # Decoder path, walking the skip connections in reverse.
        skip_connections = skip_connections[::-1]
        for i in range(len(self.decoder_convs)):
            x = F.interpolate(x, scale_factor=(1, 2, 2), mode='nearest')
            skip_connection = skip_connections[i]

            # Guard against off-by-one spatial sizes from odd input dimensions.
            if x.shape != skip_connection.shape:
                x = F.interpolate(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.decoder_convs[i](concat_skip, context)

            for block in self.additional_resnet_blocks[i]:
                x = block(x, context)

        return self.final_conv(x)


if __name__ == "__main__":
    print(
        "--- Testing Full 3D U-Net with DC-AE, ResNet Blocks, and FiLM conditioning ---"
    )

    CONTEXT_FRAMES = 4
    IMG_DEPTH = CONTEXT_FRAMES + 2  # 4 context frames + 2 noisy frames
    IMG_HEIGHT, IMG_WIDTH = 128, 128
    IN_CHANNELS = 3
    OUT_CHANNELS = 3
    BATCH_SIZE = 2
    CONTEXT_DIM = 128

    input_tensor = torch.randn(
        BATCH_SIZE, IN_CHANNELS, IMG_DEPTH, IMG_HEIGHT, IMG_WIDTH
    )
    t = torch.rand(BATCH_SIZE, CONTEXT_DIM)
    print(f"Input shape: {input_tensor.shape}")
    print(f"Time shape: {t.shape}")

    model = UNet_DCAE_3D(
        in_channels=IN_CHANNELS,
        out_channels=OUT_CHANNELS,
        features=[64, 128, 256],
        context_dim=CONTEXT_DIM,
        embedding_dim=128,
        context_frames=CONTEXT_FRAMES,
        num_additional_resnet_blocks=3,
    )

    output_tensor = model(input_tensor, t)

    print(f"Output shape: {output_tensor.shape}")

    expected_shape = (BATCH_SIZE, OUT_CHANNELS, IMG_DEPTH, IMG_HEIGHT, IMG_WIDTH)
    assert output_tensor.shape == expected_shape, (
        f"Shape mismatch! Expected {expected_shape}, got {output_tensor.shape}"
    )

    print("✅ 3D U-Net model shape test PASSED.")

    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {num_params:,}")