| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """VAE model for WorldEngine frame encoding/decoding.""" |
|
|
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from .dcae import Decoder, Encoder
|
|
|
|
@dataclass
class EncoderDecoderConfig:
    """Config object for Encoder/Decoder initialization.

    Bundles the architecture hyperparameters consumed by the DCAE
    ``Encoder``/``Decoder`` constructors (see ``.dcae``). The same dataclass
    is used for both sides; ``ch_0``/``ch_max`` differ per side while the
    remaining fields are shared.
    """

    # (height, width) of the input sample in pixels.
    sample_size: Tuple[int, int]
    # Number of image channels (e.g. 3 for RGB).
    channels: int
    # Number of channels in the latent representation.
    latent_channels: int
    # Base channel width at the first stage.
    ch_0: int
    # Upper bound on channel width across stages.
    ch_max: int
    # Residual-block count per encoder stage; length = number of stages.
    encoder_blocks_per_stage: List[int]
    # Residual-block count per decoder stage; length = number of stages.
    decoder_blocks_per_stage: List[int]
    # Whether to insert a middle block between encoder and decoder stages.
    use_middle_block: bool
    # If True, the encoder skips producing a log-variance head.
    skip_logvar: bool = False
    # If True, skip residual connections — TODO confirm exact semantics in .dcae.
    skip_residuals: bool = False
    # If True, normalize the posterior mean — TODO confirm exact semantics in .dcae.
    normalize_mu: bool = False
|
|
|
|
class WorldEngineVAE(ModelMixin, ConfigMixin):
    """
    VAE for encoding/decoding video frames using DCAE architecture.

    Encodes RGB uint8 images to latent space and decodes latents back to RGB.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        sample_size: Tuple[int, int] = (360, 640),
        channels: int = 3,
        latent_channels: int = 16,
        encoder_ch_0: int = 64,
        encoder_ch_max: int = 256,
        encoder_blocks_per_stage: Optional[List[int]] = None,
        decoder_ch_0: int = 128,
        decoder_ch_max: int = 1024,
        decoder_blocks_per_stage: Optional[List[int]] = None,
        use_middle_block: bool = False,
        skip_logvar: bool = False,
        scale_factor: float = 1.0,
        shift_factor: float = 0.0,
    ):
        """Build the encoder/decoder pair from the given hyperparameters.

        Args:
            sample_size: (height, width) of input frames in pixels.
            channels: Number of image channels (3 for RGB).
            latent_channels: Channel count of the latent representation.
            encoder_ch_0 / encoder_ch_max: Base and max channel widths
                for the encoder stages.
            encoder_blocks_per_stage: Blocks per encoder stage; defaults
                to one block per stage over four stages.
            decoder_ch_0 / decoder_ch_max: Base and max channel widths
                for the decoder stages.
            decoder_blocks_per_stage: Blocks per decoder stage; defaults
                to one block per stage over four stages.
            use_middle_block: Whether to insert a middle block.
            skip_logvar: If True, the encoder skips the log-variance head.
            scale_factor / shift_factor: Registered in the config.
                NOTE(review): they are not applied in encode()/decode()
                below — presumably callers apply them externally; confirm.
        """
        super().__init__()

        # Use None sentinels instead of mutable list defaults, which would
        # be shared across instances.
        if encoder_blocks_per_stage is None:
            encoder_blocks_per_stage = [1, 1, 1, 1]
        if decoder_blocks_per_stage is None:
            decoder_blocks_per_stage = [1, 1, 1, 1]

        # Encoder and decoder share the sample/latent geometry but have
        # independent base/max channel widths.
        encoder_config = EncoderDecoderConfig(
            sample_size=tuple(sample_size),
            channels=channels,
            latent_channels=latent_channels,
            ch_0=encoder_ch_0,
            ch_max=encoder_ch_max,
            encoder_blocks_per_stage=list(encoder_blocks_per_stage),
            decoder_blocks_per_stage=list(decoder_blocks_per_stage),
            use_middle_block=use_middle_block,
            skip_logvar=skip_logvar,
        )

        decoder_config = EncoderDecoderConfig(
            sample_size=tuple(sample_size),
            channels=channels,
            latent_channels=latent_channels,
            ch_0=decoder_ch_0,
            ch_max=decoder_ch_max,
            encoder_blocks_per_stage=list(encoder_blocks_per_stage),
            decoder_blocks_per_stage=list(decoder_blocks_per_stage),
            use_middle_block=use_middle_block,
            skip_logvar=skip_logvar,
        )

        self.encoder = Encoder(encoder_config)
        self.decoder = Decoder(decoder_config)

    def encode(self, img: Tensor) -> Tensor:
        """Encode a single ``[H, W, C]`` uint8 image to a latent tensor.

        The comment "RGB -> RGB+D -> latent" in the original suggests a
        depth channel may be involved downstream — the visible code only
        normalizes and forwards the input; confirm against the Encoder.

        Args:
            img: ``[H, W, C]`` image tensor with values in ``[0, 255]``.

        Returns:
            Whatever ``self.encoder`` produces for the normalized frame.

        Raises:
            ValueError: If ``img`` is not a 3-dimensional tensor.
        """
        # Raise instead of assert: asserts are stripped under `python -O`,
        # silently disabling input validation.
        if img.dim() != 3:
            raise ValueError(
                f"Expected [H, W, C] image tensor, got {img.dim()} dims"
            )
        img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
        # [1, H, W, C] -> [1, C, H, W], then map [0, 255] -> [-1, 1].
        rgb = img.permute(0, 3, 1, 2).contiguous().div(255).mul(2).sub(1)
        return self.encoder(rgb)

    @torch.compile
    def decode(self, latent: Tensor) -> Tensor:
        """Decode a latent back to a ``[H, W, 3]`` uint8 RGB image.

        Args:
            latent: Latent tensor as produced by :meth:`encode`.

        Returns:
            ``[H, W, 3]`` uint8 tensor with values in ``[0, 255]``. The
            trailing ``[..., :3]`` slice drops any extra decoder output
            channels (presumably depth — confirm against the Decoder).
        """
        decoded = self.decoder(latent)
        # Map [-1, 1] -> [0, 1], then quantize to uint8.
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        decoded = (decoded * 255).round().to(torch.uint8)
        return decoded.squeeze(0).permute(1, 2, 0)[..., :3]

    def forward(self, x: Tensor, encode: bool = True) -> Tensor:
        """
        Forward pass - encode or decode based on flag.

        Args:
            x: Input tensor (image for encode, latent for decode)
            encode: If True, encode; if False, decode

        Returns:
            Encoded latent or decoded image
        """
        if encode:
            return self.encode(x)
        else:
            return self.decode(x)
|
|