| """The final fusion stage for the film_net frame interpolator. | |
| The inputs to this module are the warped input images, image features and | |
| flow fields, all aligned to the target frame (often midway point between the | |
| two original inputs). The output is the final image. FILM has no explicit | |
| occlusion handling -- instead using the abovementioned information this module | |
| automatically decides how to best blend the inputs together to produce content | |
| in areas where the pixels can only be borrowed from one of the inputs. | |
| Similarly, this module also decides on how much to blend in each input in case | |
| of fractional timestep that is not at the halfway point. For example, if the two | |
| inputs images are at t=0 and t=1, and we were to synthesize a frame at t=0.1, | |
| it often makes most sense to favor the first input. However, this is not | |
| always the case -- in particular in occluded pixels. | |
| The architecture of the Fusion module follows U-net [1] architecture's decoder | |
| side, e.g. each pyramid level consists of concatenation with upsampled coarser | |
| level output, and two 3x3 convolutions. | |
| The upsampling is implemented as 'resize convolution', e.g. nearest neighbor | |
| upsampling followed by 2x2 convolution as explained in [2]. The classic U-net | |
| uses max-pooling which has a tendency to create checkerboard artifacts. | |
| [1] Ronneberger et al. U-Net: Convolutional Networks for Biomedical Image | |
| Segmentation, 2015, https://arxiv.org/pdf/1505.04597.pdf | |
| [2] https://distill.pub/2016/deconv-checkerboard/ | |
| """ | |
from typing import List

import torch
from torch import nn
from torch.nn import functional as F

from util import Conv2d

_NUMBER_OF_COLOR_CHANNELS = 3
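

# Illustrative sketch only (not part of the original module, and not used by
# Fusion): the 'resize convolution' upsampling pattern described in the module
# docstring -- a nearest-neighbor resize followed by a small convolution -- as
# opposed to a transposed convolution. Fusion builds the same pattern from
# F.interpolate and util.Conv2d inside forward(); the helper name below is
# hypothetical, and padding='same' requires a reasonably recent PyTorch.
def _resize_conv_sketch(in_channels: int, out_channels: int) -> nn.Module:
    """Illustrative only: 2x nearest-neighbor upsample, then a 2x2 convolution."""
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv2d(in_channels, out_channels, kernel_size=2, padding='same'),
    )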


def get_channels_at_level(level, filters):
    """Number of channels in the concatenated fusion inputs at a pyramid level.

    Counts, for each of the two warped inputs, the RGB image (3 channels), the
    flow field (2 channels) and `filters << i` feature channels for each of the
    `level` feature-extractor scales.
    """
    n_images = 2
    channels = _NUMBER_OF_COLOR_CHANNELS
    flows = 2

    return (sum(filters << i for i in range(level)) + channels + flows) * n_images
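# For reference, with filters=64 (the Fusion default) this gives 10, 138, 394,
# 906 and 1930 channels at levels 0 through 4, respectively.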


class Fusion(nn.Module):
    """The decoder."""

    def __init__(self, n_layers=4, specialized_layers=3, filters=64):
        """
        Args:
            n_layers: number of pyramid levels processed by the decoder.
            specialized_layers: number of finest levels that get their own
                (doubling) filter count; coarser levels are capped at
                filters << specialized_layers.
            filters: number of filters at the finest level.
        """
        super().__init__()

        # The final convolution that outputs RGB:
        self.output_conv = nn.Conv2d(filters, 3, kernel_size=1)

        # Each item of 'convs' holds the convolutions for one pyramid level,
        # stored coarsest processed level first; forward() maps its loop index
        # back to the corresponding pyramid level.
        self.convs = nn.ModuleList()

        # Create the convolutions. Roughly following the feature extractor, we
        # double the number of filters when the resolution halves, but only up
        # to specialized_layers, after which the same number of filters is used
        # on all levels.
        #
        # With the default arguments the loop builds, coarsest level first:
        #   convs[0] (level 3): Conv2d(1930->512, 2x2), Conv2d(2442->512, 3x3), Conv2d(512->512, 3x3)
        #   convs[1] (level 2): Conv2d(512->256, 2x2),  Conv2d(1162->256, 3x3), Conv2d(256->256, 3x3)
        #   convs[2] (level 1): Conv2d(256->128, 2x2),  Conv2d(522->128, 3x3),  Conv2d(128->128, 3x3)
        #   convs[3] (level 0): Conv2d(128->64, 2x2),   Conv2d(202->64, 3x3),   Conv2d(64->64, 3x3)
        in_channels = get_channels_at_level(n_layers, filters)
        increase = 0
        for i in reversed(range(n_layers)):
            num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers)
            convs = nn.ModuleList([
                Conv2d(in_channels, num_filters, size=2, activation=None),
                Conv2d(in_channels + (increase or num_filters), num_filters, size=3),
                Conv2d(num_filters, num_filters, size=3),
            ])
            self.convs.append(convs)

            in_channels = num_filters
            # 'increase' adjusts the input width of the next (finer) level's
            # 3x3 conv to account for the pyramid features concatenated there.
            increase = get_channels_at_level(i, filters) - num_filters // 2
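
        # With the default arguments, forward() therefore expects a pyramid of
        # n_layers + 1 = 5 tensors whose channel counts, finest level first, are
        # 138, 394, 906, 1930 and 1930 (derived from the conv widths above).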

    def forward(self, pyramid: List[torch.Tensor]) -> torch.Tensor:
        """Runs the fusion module.

        Args:
            pyramid: The input feature pyramid as a list of tensors, each in
                (B x C x H x W) format, with the finest level tensor first.

        Returns:
            A batch of RGB images.
        """

        # As a slight difference to a conventional decoder (e.g. U-net), we don't
        # apply any extra convolutions to the coarsest level, but just pass it
        # to finer levels for concatenation. This choice has not been thoroughly
        # evaluated, but is motivated by the educated guess that the fusion part
        # probably does not need large spatial context, because at this point the
        # features are spatially aligned by the preceding warp.
        net = pyramid[-1]

        # Loop starting from the 2nd coarsest level:
        # for i in reversed(range(0, len(pyramid) - 1)):
        for k, layers in enumerate(self.convs):
            i = len(self.convs) - 1 - k

            # Resize the tensor from coarser level to match for concatenation.
            level_size = pyramid[i].shape[2:4]
            net = F.interpolate(net, size=level_size, mode='nearest')
            net = layers[0](net)

            net = torch.cat([pyramid[i], net], dim=1)
            net = layers[1](net)
            net = layers[2](net)

        net = self.output_conv(net)

        return net
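

# Minimal smoke-test sketch, not part of the original module. It assumes that
# util.Conv2d preserves spatial size and that the pyramid carries the channel
# counts the default Fusion() expects (derived from the constructor above).
if __name__ == '__main__':
    batch, height, width = 1, 64, 64
    # Channels expected at pyramid levels 0..4 for Fusion() with default args.
    level_channels = [138, 394, 906, 1930, 1930]
    pyramid = [
        torch.randn(batch, c, height >> i, width >> i)
        for i, c in enumerate(level_channels)
    ]
    fusion = Fusion()
    with torch.no_grad():
        rgb = fusion(pyramid)
    print(rgb.shape)  # Expected: (1, 3, 64, 64), matching the finest level.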