import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with batch norm and a skip connection; preserves shape."""

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual
        out = self.relu(out)
        return out

class Encoder(nn.Module):
    def __init__(self, input_channels=1, hidden_dims=[64, 128, 256, 512, 1024], latent_dim=32):
        super(Encoder, self).__init__()
        self.hidden_dims = hidden_dims

        # Each stage halves the spatial resolution (stride-2 conv) and then
        # refines the features with a residual block at the new channel width.
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(input_channels, h_dim, kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                    ResidualBlock(h_dim)
                )
            )
            input_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        # Assumes 512x512 inputs: five stride-2 stages leave a 16x16 feature map,
        # so the flattened feature size is hidden_dims[-1] * 16 * 16.
        self.fc_mu = nn.Linear(hidden_dims[-1] * 16 * 16, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 16 * 16, latent_dim)

    def forward(self, x):
        for layer in self.encoder:
            x = layer(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var

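# Encoder shape trace for the default configuration (512x512 grayscale input):
# (N, 1, 512, 512) -> (N, 64, 256, 256) -> (N, 128, 128, 128) -> (N, 256, 64, 64)
# -> (N, 512, 32, 32) -> (N, 1024, 16, 16) -> flattened (N, 262144) -> mu, log_var: (N, 32).
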
class Decoder(nn.Module):
    def __init__(self, latent_dim=32, output_channels=1, hidden_dims=[64, 128, 256, 512, 1024]):
        super(Decoder, self).__init__()
        self.hidden_dims = hidden_dims

        # Mirror the encoder: work through the channel widths in reverse order.
        hidden_dims = hidden_dims[::-1]
        # Project the latent vector back to the flattened 16x16 feature map
        # (hidden_dims[0] channels, matching the encoder's final stage).
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[0] * 16 * 16)

        modules = []
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i+1], kernel_size=3, stride=2, padding=1, output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i+1]),
                    nn.LeakyReLU(),
                    ResidualBlock(hidden_dims[i+1])
                )
            )

        self.decoder = nn.Sequential(*modules)
        # Final upsampling stage, then a 3x3 conv to the output channels;
        # the sigmoid keeps reconstructions in [0, 1] for the BCE loss.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], output_channels, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, z):
        z = self.decoder_input(z)
        # Reshape the flat projection into the 16x16 feature map with the
        # widest channel count (1024 for the default hidden_dims).
        z = z.view(-1, self.hidden_dims[-1], 16, 16)
        for layer in self.decoder:
            z = layer(z)
        result = self.final_layer(z)
        return result

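# Decoder shape trace for the default configuration:
# (N, 32) -> (N, 262144) -> (N, 1024, 16, 16) -> (N, 512, 32, 32) -> (N, 256, 64, 64)
# -> (N, 128, 128, 128) -> (N, 64, 256, 256) -> final layer -> (N, 1, 512, 512) in [0, 1].
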
class VAE(nn.Module):
    def __init__(self,
                 input_channels=1,
                 latent_dim=32,
                 hidden_dims=None):
        super(VAE, self).__init__()

        if hidden_dims is None:
            hidden_dims = [64, 128, 256, 512, 1024]

        self.encoder = Encoder(input_channels=input_channels,
                               hidden_dims=hidden_dims,
                               latent_dim=latent_dim)

        self.decoder = Decoder(latent_dim=latent_dim,
                               output_channels=input_channels,
                               hidden_dims=hidden_dims)

    def encode(self, input):
        mu, log_var = self.encoder(input)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # which keeps sampling differentiable with respect to mu and log_var.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, input):
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var


def loss_function(recon_x, x, mu, log_var):
    # Negative ELBO: pixel-wise binary cross-entropy reconstruction term
    # (recon_x must lie in [0, 1], which the decoder's sigmoid guarantees)
    # plus the analytic KL divergence between N(mu, sigma^2) and N(0, I).
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD
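

# A minimal smoke test, assuming 512x512 single-channel inputs in [0, 1]
# (the size the fc layers and the decoder's 16x16 reshape are written for).
# The batch size, optimizer, and learning rate here are illustrative only.
if __name__ == "__main__":
    model = VAE(input_channels=1, latent_dim=32)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    x = torch.rand(2, 1, 512, 512)      # dummy batch of "images" in [0, 1]
    recon, mu, log_var = model(x)
    print(recon.shape)                  # expected: torch.Size([2, 1, 512, 512])

    loss = loss_function(recon, x, mu, log_var)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(loss.item())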