| """ |
| BEATRIX FLOW-MATCHING - CIFAR-10 (T5 Text Encoder) |
| =================================================== |
| |
SD 1.5 VAE + Flan-T5-XL text encoder (google/flan-t5-xl)
Tower collectives per modality: geometric + conv towers on vision and on text
| |
| Text prompts for CIFAR-10 classes: |
| "a photo of an airplane" |
| "a photo of an automobile" |
| etc. |
| |
| Requirements: |
| pip install transformers diffusers torchvision tqdm |
| pip install git+https://github.com/AbstractEyes/geofractal |
| |
Performance is currently slow; optimization is still pending.
| |
| apache 2.0 license |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
| from typing import Dict, Tuple, Optional, List |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision import datasets, transforms |
| from torchvision.utils import make_grid, save_image |
| from huggingface_hub import HfApi, upload_file, create_repo |
| import json |
| from tqdm import tqdm |
|
|
| |
| |
| |
|
|
| from geofractal.router.wide_router import WideRouter |
| from geofractal.router.prefab.agatha.beatrix_tension_oscillator import ( |
| BeatrixOscillator, |
| ScheduleType, |
| ) |
| from geofractal.router.prefab.geometric_tower_builder import ( |
| TowerConfig, |
| FusionType, |
| ConfigurableCollective, |
| build_tower_collective, |
| preset_pos_neg_pairs, |
| ) |
| from geofractal.router.prefab.geometric_conv_tower_builder import ( |
| ConvTowerConfig, |
| ConvTowerCollective, |
| build_conv_collective, |
| preset_conv_pos_neg, |
| ) |
|
|
|
|
| |
| |
| |
|
|
# Conditioning prompt for each CIFAR-10 class. Entries are looked up by the
# integer class label (CachedCIFAR10T5.__getitem__ indexes by label), so the
# order must match the dataset's class order.
CIFAR10_PROMPTS = [
    f"a photo of {article} {noun}"
    for article, noun in (
        ("an", "airplane"),
        ("an", "automobile"),
        ("a", "bird"),
        ("a", "cat"),
        ("a", "deer"),
        ("a", "dog"),
        ("a", "frog"),
        ("a", "horse"),
        ("a", "ship"),
        ("a", "truck"),
    )
]
|
|
|
|
| |
| |
| |
|
|
class SD15VAE(nn.Module):
    """Wrapper around the Stable Diffusion 1.5 ``AutoencoderKL`` (via diffusers).

    ``encode`` maps images to scaled latents and ``decode`` undoes the scaling
    before decoding. Both run under ``torch.no_grad``.

    Args:
        freeze: If True, put the VAE in eval mode and disable gradients on all
            of its parameters.
        model_id: Hub repo id (or local path) that contains a ``vae`` subfolder.
        scale_factor: Multiplier applied to latents after encoding; divided
            out again before decoding.
    """

    def __init__(
        self,
        freeze: bool = True,
        model_id: str = "runwayml/stable-diffusion-v1-5",
        scale_factor: float = 0.18215,
    ):
        super().__init__()
        # Local import: diffusers is only required when a VAE is actually built.
        from diffusers import AutoencoderKL

        # NOTE(review): the runwayml repo id may no longer be hosted on the Hub;
        # if loading fails, pass a mirror of the SD 1.5 weights via ``model_id``.
        self.vae = AutoencoderKL.from_pretrained(
            model_id,
            subfolder="vae",
            torch_dtype=torch.float32,
        )

        if freeze:
            self.vae.eval()
            for p in self.vae.parameters():
                p.requires_grad = False

        self.scale_factor = scale_factor

    @torch.no_grad()
    def encode(self, x: Tensor) -> Tensor:
        """Encode images ``x`` and return a sampled latent times ``scale_factor``.

        The latent is drawn with ``latent_dist.sample()`` (not the mean), so
        repeated calls on the same input can differ.
        """
        return self.vae.encode(x).latent_dist.sample() * self.scale_factor

    @torch.no_grad()
    def decode(self, z: Tensor) -> Tensor:
        """Divide latents ``z`` by ``scale_factor`` and return the decoded images."""
        return self.vae.decode(z / self.scale_factor).sample
|
|
|
|
| |
| |
| |
|
|
class T5TextEncoder(nn.Module):
    """Flan-T5 encoder with a trainable bottleneck projection.

    Args:
        model_name: Hugging Face id of the T5 checkpoint.
        freeze: If True, the T5 encoder is kept in eval mode with gradients
            disabled; only ``bottleneck`` stays trainable.
        max_length: Token length every prompt is padded / truncated to.
        bottleneck_dim: Output width of ``bottleneck``.
    """

    def __init__(
        self,
        model_name: str = "google/flan-t5-xl",
        freeze: bool = True,
        max_length: int = 77,
        bottleneck_dim: int = 256,
    ):
        super().__init__()
        from transformers import T5EncoderModel, T5Tokenizer

        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.encoder = T5EncoderModel.from_pretrained(model_name)
        self.max_length = max_length
        self.raw_dim = self.encoder.config.d_model
        self.output_dim = bottleneck_dim
        self.freeze = freeze

        # Trainable projection raw_dim -> bottleneck_dim applied per token.
        self.bottleneck = nn.Sequential(
            nn.Linear(self.raw_dim, bottleneck_dim),
            nn.GELU(),
            nn.Linear(bottleneck_dim, bottleneck_dim),
        )

        if freeze:
            self.encoder.eval()
            for p in self.encoder.parameters():
                p.requires_grad = False

    def train(self, mode: bool = True) -> "T5TextEncoder":
        """Set training mode, but keep a frozen encoder in eval mode.

        Without this override, calling ``.train()`` on this module (or on a
        parent) would switch the frozen encoder back to train mode, enabling
        any dropout it contains.
        """
        super().train(mode)
        if self.freeze:
            self.encoder.eval()
        return self

    def _run_encoder(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """Tokenize ``texts`` (padded to ``max_length``) and run the T5 encoder.

        Returns:
            (last_hidden_state, attention_mask) with the mask on ``device``.
        """
        tokens = self.tokenizer(
            texts,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(device)
        attention_mask = tokens.attention_mask.to(device)
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state, attention_mask

    @staticmethod
    def _masked_mean(sequence: Tensor, attention_mask: Tensor) -> Tensor:
        """Average ``sequence`` over dim 1 at positions where ``attention_mask`` is 1."""
        mask = attention_mask.unsqueeze(-1).float()
        # clamp avoids 0/0 -> NaN for a row with no unmasked tokens; for any
        # row with at least one token the result is unchanged.
        return (sequence * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def forward(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """Encode prompts and project them through the bottleneck.

        The T5 encoder runs without autograd when ``freeze`` is set (or when
        grad mode is already off); the bottleneck always participates in
        autograd when grad mode is on, so it can actually be trained.

        Returns:
            sequence: [B, L, bottleneck_dim] - compressed sequence embeddings
            pooled: [B, bottleneck_dim] - masked mean of ``sequence``
        """
        encoder_grad = torch.is_grad_enabled() and not self.freeze
        with torch.set_grad_enabled(encoder_grad):
            sequence_raw, attention_mask = self._run_encoder(texts, device)

        sequence = self.bottleneck(sequence_raw)
        pooled = self._masked_mean(sequence, attention_mask)
        return sequence, pooled

    @torch.no_grad()
    def encode_raw(self, texts: List[str], device: torch.device) -> Tuple[Tensor, Tensor]:
        """Encode prompts WITHOUT the bottleneck (used for caching raw embeddings).

        Returns:
            sequence: [B, L, raw_dim] - raw T5 embeddings
            pooled: [B, raw_dim] - masked mean of ``sequence``
        """
        sequence, attention_mask = self._run_encoder(texts, device)
        return sequence, self._masked_mean(sequence, attention_mask)
|
|
|
|
| |
| |
| |
|
|
class CachedCIFAR10T5(Dataset):
    """CIFAR-10 with pre-computed VAE latents and per-class raw T5 embeddings.

    The first construction for a given (split, image_size) encodes every image
    once and writes a ``.pt`` cache under ``cache_dir``; later constructions
    load that file instead. Text embeddings are computed once per class
    prompt in ``CIFAR10_PROMPTS`` (not per image) and are stored without any
    bottleneck.

    Items are ``(latent, text_sequence, text_pooled, label)``.

    Args:
        train: Use the training split if True, otherwise the test split.
        image_size: Images are resized to this square size before encoding.
        cache_dir: Directory holding the cache file.
        device: Device used while building the cache; falls back to CPU if a
            CUDA device is requested but unavailable.
        data_root: Where torchvision stores / downloads CIFAR-10.
    """

    T5_MODEL = "google/flan-t5-xl"

    def __init__(
        self,
        train: bool = True,
        image_size: int = 256,
        cache_dir: str = "./cache",
        device: str = "cuda",
        data_root: str = "./data",
    ):
        self.train = train
        split = 'train' if train else 'val'

        t5_suffix = self.T5_MODEL.replace("/", "_")
        self.cache_path = Path(cache_dir) / f"cifar10_{t5_suffix}_{split}_{image_size}.pt"

        if self.cache_path.exists():
            print(f"Loading cache: {self.cache_path}")
            # The cache holds only tensors and an int, so the restricted
            # weights-only unpickler is sufficient (and safer than full pickle).
            cache = torch.load(self.cache_path, map_location="cpu", weights_only=True)
            self.latents = cache['latents']
            self.labels = cache['labels']
            self.text_sequence = cache['text_sequence']
            self.text_pooled = cache['text_pooled']
            self.text_dim = cache.get('text_dim', self.text_pooled.shape[-1])
        else:
            print(f"Building cache for {split} set...")
            self._build_cache(image_size, device, data_root)

    def _build_cache(self, image_size: int, device: str, data_root: str = "./data"):
        """Encode prompts and images, populate the attributes, and save the cache."""
        if str(device).startswith("cuda") and not torch.cuda.is_available():
            print("  CUDA unavailable; building cache on CPU.")
            device = "cpu"

        print("  Loading VAE...")
        vae = SD15VAE(freeze=True).to(device)
        print(f"  Loading T5 ({self.T5_MODEL})...")
        t5 = T5TextEncoder(model_name=self.T5_MODEL, freeze=True).to(device)

        # One raw (un-bottlenecked) embedding per class prompt.
        print(f"  Encoding text prompts (T5 raw_dim={t5.raw_dim})...")
        text_seq, text_pool = t5.encode_raw(CIFAR10_PROMPTS, device)
        self.text_sequence = text_seq.cpu()
        self.text_pooled = text_pool.cpu()
        self.text_dim = t5.raw_dim

        # Resize and map pixels to [-1, 1] before VAE encoding.
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        dataset = datasets.CIFAR10(data_root, train=self.train, download=True, transform=transform)
        loader = DataLoader(
            dataset,
            batch_size=64,
            shuffle=False,
            num_workers=4,
            pin_memory=str(device).startswith("cuda"),
        )

        all_latents, all_labels = [], []

        print("  Encoding images...")
        with torch.no_grad():
            for images, labels in tqdm(loader, desc="  Caching", leave=False):
                images = images.to(device)
                all_latents.append(vae.encode(images).cpu())
                all_labels.append(labels)

        self.latents = torch.cat(all_latents, dim=0)
        self.labels = torch.cat(all_labels, dim=0)

        del vae, t5
        torch.cuda.empty_cache()

        # Write to a temp file and rename, so an interrupted save cannot leave a
        # truncated cache that the next run would try to load.
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.cache_path.with_name(self.cache_path.name + ".tmp")
        torch.save({
            'latents': self.latents,
            'labels': self.labels,
            'text_sequence': self.text_sequence,
            'text_pooled': self.text_pooled,
            'text_dim': self.text_dim,
        }, tmp_path)
        tmp_path.replace(self.cache_path)
        print(f"  Saved: {self.cache_path}")

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Text embeddings are per class, so they are looked up by label.
        label = self.labels[idx]
        return (
            self.latents[idx],
            self.text_sequence[label],
            self.text_pooled[label],
            label,
        )
|
|
|
|
| |
| |
| |
|
|
class SinusoidalEmbed(nn.Module):
    """Parameter-free sinusoidal embedding of a scalar per element of ``t``.

    Maps ``t`` of shape ``[...]`` to ``[..., dim]`` as
    ``[cos(t * f), sin(t * f)]`` with ``h = dim // 2`` frequencies
    ``f_i = exp(-ln(10000) * i / h)`` for ``i = 0 .. h-1`` (so the highest
    frequency is 1). For odd ``dim`` a zero column is appended, so the
    output width always equals ``dim``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: Tensor) -> Tensor:
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, device=t.device) / half)
        args = t.unsqueeze(-1) * freqs
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2:
            # cos/sin halves only fill 2 * (dim // 2) columns; pad the last one.
            emb = F.pad(emb, (0, 1))
        return emb
|
|
|
|
| |
| |
| |
|
|
@dataclass
class FlowConfig:
    """Hyper-parameters for the Beatrix flow model, data cache and trainer."""

    # Data / VAE latent geometry.
    image_size: int = 256
    num_classes: int = 10  # only exported to config.json in this module
    latent_channels: int = 4
    # NOTE(review): latent_size is set independently of image_size; keep the
    # two consistent with the VAE's spatial downsampling.
    latent_size: int = 32

    # Text conditioning. text_raw_dim is overwritten by Trainer with the
    # width stored in the embedding cache.
    text_raw_dim: int = 2048
    text_seq_len: int = 77  # appears unused in this module
    bottleneck_dim: int = 256

    # Geometric (transformer) towers.
    tower_dim: int = 256
    tower_depth: int = 2
    num_heads: int = 8
    geometric_types: Tuple[str, ...] = ('cantor', 'beatrix', 'helix', 'simplex')

    # Conv towers.
    conv_types: Tuple[str, ...] = ('wide_resnet', 'frequency', 'bottleneck', 'squeeze_excite')
    conv_spatial_size: int = 8

    # Oscillator / manifold.
    manifold_dim: int = 1024
    num_tower_pairs: int = 16  # appears unused: the model derives the pair count from its tower configs
    osc_steps: int = 50  # appears unused in this module
    fingerprint_dim: int = 64

    # Flow sampling.
    num_flow_steps: int = 50
    sigma_min: float = 0.001  # appears unused in this module

    # Optimization.
    batch_size: int = 64
    lr: float = 1e-4
    weight_decay: float = 0.01
    num_epochs: int = 100

    cache_dir: str = "./cache"
    device: str = "cuda"
    output_dir: str = "./beatrix_cifar_t5"

    @property
    def latent_flat_dim(self) -> int:
        """Flattened latent size ``latent_channels * latent_size ** 2`` (4 x 32 x 32 = 4096 by default)."""
        return self.latent_channels * self.latent_size * self.latent_size
|
|
|
|
| |
| |
| |
|
|
class BeatrixFlowT5(WideRouter):
    """
    Flow-matching velocity model with four tower collectives.

    Vision side (operates on 4x4 patches of the latent):
      - Geometric (transformer) towers built from ``cfg.geometric_types`` (pos/neg)
      - Conv towers built from ``cfg.conv_types`` (pos/neg)

    Text side (operates on bottlenecked T5 token embeddings), mirrored:
      - Geometric (transformer) towers built from ``cfg.geometric_types`` (pos/neg)
      - Conv towers built from ``cfg.conv_types`` (pos/neg)

    All tower opinions feed the oscillator's force generator. The resulting
    tower force is blended with a spring force that pulls the projected latent
    toward a text + time reference point. The blend is projected back to
    latent space as the velocity prediction.

    Path convention: ``z_t = (1 - t) * z_0 + t * eps`` with target velocity
    ``eps - z_0``; sampling integrates from t = 1 (noise) toward t = 0.
    """

    def __init__(self, cfg: FlowConfig):
        super().__init__(name='beatrix_flow_t5', strict=False, auto_discover=False)
        self.objects['cfg'] = cfg

        # --- Text bottlenecks: raw T5 width -> bottleneck_dim (sequence and pooled paths) ---
        self.attach('text_bottleneck_seq', nn.Sequential(
            nn.Linear(cfg.text_raw_dim, cfg.bottleneck_dim),
            nn.GELU(),
            nn.Linear(cfg.bottleneck_dim, cfg.bottleneck_dim),
        ))
        self.attach('text_bottleneck_pool', nn.Sequential(
            nn.Linear(cfg.text_raw_dim, cfg.bottleneck_dim),
            nn.GELU(),
            nn.Linear(cfg.bottleneck_dim, cfg.bottleneck_dim),
        ))

        # --- Vision geometric (transformer) towers ---
        vision_geo_configs = preset_pos_neg_pairs(list(cfg.geometric_types))

        vision_geo_collective = build_tower_collective(
            configs=vision_geo_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            num_heads=cfg.num_heads,
            ffn_mult=4.0,
            dropout=0.1,
            fingerprint_dim=cfg.fingerprint_dim,
            fusion_type='adaptive',
            name='vision_geo',
        )
        self.attach('vision_geo', vision_geo_collective)

        # --- Vision conv towers ---
        vision_conv_configs = preset_conv_pos_neg(list(cfg.conv_types))

        vision_conv_collective = build_conv_collective(
            configs=vision_conv_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            fingerprint_dim=cfg.fingerprint_dim,
            spatial_size=cfg.conv_spatial_size,
            name='vision_conv',
        )
        self.attach('vision_conv', vision_conv_collective)

        # --- Text geometric (transformer) towers ---
        text_geo_configs = preset_pos_neg_pairs(list(cfg.geometric_types))

        text_geo_collective = build_tower_collective(
            configs=text_geo_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            num_heads=cfg.num_heads,
            ffn_mult=4.0,
            dropout=0.1,
            fingerprint_dim=cfg.fingerprint_dim,
            fusion_type='adaptive',
            name='text_geo',
        )
        self.attach('text_geo', text_geo_collective)

        # --- Text conv towers ---
        text_conv_configs = preset_conv_pos_neg(list(cfg.conv_types))

        text_conv_collective = build_conv_collective(
            configs=text_conv_configs,
            dim=cfg.tower_dim,
            default_depth=cfg.tower_depth,
            fingerprint_dim=cfg.fingerprint_dim,
            spatial_size=cfg.conv_spatial_size,
            name='text_conv',
        )
        self.attach('text_conv', text_conv_collective)

        # --- Latent patch embedding: non-overlapping 4x4 patches -> tower_dim tokens ---
        patch_size = 4
        num_patches = (cfg.latent_size // patch_size) ** 2
        patch_dim = cfg.latent_channels * patch_size * patch_size

        self.attach('patch_proj', nn.Linear(patch_dim, cfg.tower_dim))
        self.patch_pos_embed = nn.Parameter(torch.randn(1, num_patches, cfg.tower_dim) * 0.02)
        self.objects['patch_size'] = patch_size
        self.objects['num_patches'] = num_patches

        # --- Oscillator: towers are counted per modality and mirrored for text ---
        num_geo_towers = len(vision_geo_configs)
        num_conv_towers = len(vision_conv_configs)
        total_towers = (num_geo_towers + num_conv_towers) * 2  # vision + text

        oscillator = BeatrixOscillator(
            name='oscillator',
            manifold_dim=cfg.manifold_dim,
            tower_dim=cfg.tower_dim,
            num_tower_pairs=total_towers // 2,
            num_theta_probes=4,
            fingerprint_dim=cfg.fingerprint_dim,
            kappa_schedule=ScheduleType.TESLA_369,
            use_intrinsic_tension=True,
        )
        self.attach('oscillator', oscillator)

        # --- Conditioning: time embedding and projections into manifold space ---
        # NOTE(review): t lies in [0, 1] while SinusoidalEmbed's highest
        # frequency is 1, so the time embedding varies slowly with t. Scaling t
        # (e.g. by 1000) would change the embedding of existing checkpoints.
        time_embed = nn.Sequential(
            SinusoidalEmbed(256),
            nn.Linear(256, cfg.tower_dim),
            nn.GELU(),
            nn.Linear(cfg.tower_dim, cfg.tower_dim),
        )
        self.attach('time_embed', time_embed)

        self.attach('text_to_ref', nn.Sequential(
            nn.Linear(cfg.bottleneck_dim, cfg.manifold_dim),
            nn.GELU(),
            nn.Linear(cfg.manifold_dim, cfg.manifold_dim),
        ))

        self.attach('time_to_ref', nn.Linear(cfg.tower_dim, cfg.manifold_dim))

        # --- Flattened latent <-> manifold projections ---
        self.attach('latent_down', nn.Linear(cfg.latent_flat_dim, cfg.manifold_dim))
        self.attach('latent_up', nn.Linear(cfg.manifold_dim, cfg.latent_flat_dim))

        # Pre-sigmoid weight of the tower force vs. the spring force.
        self.velocity_mix = nn.Parameter(torch.tensor(0.5))

    def patchify(self, z: Tensor) -> Tensor:
        """Split latents into non-overlapping p x p patches and embed them.

        [B, C, H, W] -> [B, num_patches, tower_dim], plus the learned
        positional embedding.
        """
        B, C, H, W = z.shape
        p = self.objects['patch_size']

        z = z.unfold(2, p, p).unfold(3, p, p)
        z = z.permute(0, 2, 3, 1, 4, 5).contiguous()
        z = z.view(B, -1, C * p * p)

        return self['patch_proj'](z) + self.patch_pos_embed

    def get_tower_outputs(self, z: Tensor, text_seq: Tensor) -> List[Tensor]:
        """Run all four tower collectives and gather every tower's opinion.

        ``text_seq`` is the raw T5 sequence; it is bottlenecked here before
        being fed to the text towers.

        Returns:
            A flat list of opinions in the order vision-geo, vision-conv,
            text-geo, text-conv. The length is one entry per tower, so it
            depends on the configured tower types.
        """
        patches = self.patchify(z)
        text_bottlenecked = self['text_bottleneck_seq'](text_seq)

        vision_geo = self['vision_geo'](patches)
        vision_conv_fused, vision_conv_ops = self['vision_conv'](patches)
        text_geo = self['text_geo'](text_bottlenecked)
        text_conv_fused, text_conv_ops = self['text_conv'](text_bottlenecked)

        return (
            [op.opinion for op in vision_geo.opinions.values()] +
            list(vision_conv_ops.values()) +
            [op.opinion for op in text_geo.opinions.values()] +
            list(text_conv_ops.values())
        )

    def _predict_velocity(
        self,
        z_t: Tensor,
        t: Tensor,
        text_seq: Tensor,
        text_pooled_bn: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Predict the flattened velocity at latents ``z_t`` and times ``t``.

        Shared by ``forward`` (training) and ``sample`` (inference) so both
        use the identical spring / tower blend.

        Args:
            z_t: Latents at time ``t``, ``[B, C, H, W]``.
            t: Per-sample times, ``[B]``.
            text_seq: Raw T5 sequence embeddings (bottlenecked inside
                ``get_tower_outputs``).
            text_pooled_bn: Pooled text embedding already passed through
                ``text_bottleneck_pool``.

        Returns:
            ``(v_pred, tau)``: ``v_pred`` is ``[B, latent_flat_dim]`` and
            ``tau = sigmoid(velocity_mix)`` is the weight of the tower force.
        """
        z_t_proj = self['latent_down'](z_t.flatten(1))

        # Reference point for the spring force: text + time conditioning.
        time_emb = self['time_embed'](t)
        x_ref = self['text_to_ref'](text_pooled_bn) + self['time_to_ref'](time_emb)

        tower_outputs = self.get_tower_outputs(z_t, text_seq)
        tower_force, _ = self['oscillator'].force_generator(
            z_t_proj, tower_outputs, state_fingerprint=None,
        )
        spring_force = x_ref - z_t_proj

        # Learnable convex blend of the two forces.
        tau = torch.sigmoid(self.velocity_mix)
        v_pred_proj = (1 - tau) * spring_force + tau * tower_force

        return self['latent_up'](v_pred_proj), tau

    def forward(
        self,
        z_0: Tensor,
        text_seq: Tensor,
        text_pooled: Tensor,
        labels: Tensor,
        t: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Training step: single-time velocity prediction and its MSE loss.

        Args:
            z_0: Clean latents, ``[B, C, H, W]``.
            text_seq: Raw T5 sequence embeddings.
            text_pooled: Raw pooled T5 embeddings.
            labels: Class labels; accepted for interface compatibility, unused.
            t: Optional per-sample times; drawn from U[0, 1) when omitted.

        Returns:
            ``{'loss': mse(v_pred, eps - z_0), 'tau': detached blend weight}``.
        """
        B = z_0.shape[0]
        if t is None:
            t = torch.rand(B, device=z_0.device)

        # Linear interpolation between data (t=0) and noise (t=1).
        eps = torch.randn_like(z_0)
        t_exp = t.view(B, 1, 1, 1)
        z_t = (1 - t_exp) * z_0 + t_exp * eps

        # d z_t / d t along the linear path.
        v_target = eps.flatten(1) - z_0.flatten(1)

        text_pooled_bn = self['text_bottleneck_pool'](text_pooled)
        v_pred, tau = self._predict_velocity(z_t, t, text_seq, text_pooled_bn)

        loss = F.mse_loss(v_pred, v_target)
        return {'loss': loss, 'tau': tau.detach()}

    @torch.no_grad()
    def sample(
        self,
        text_seq: Tensor,
        text_pooled: Tensor,
        vae: SD15VAE,
        num_steps: Optional[int] = None,
    ) -> Tensor:
        """Generate images from text conditioning.

        Starts from Gaussian noise at t = 1 and takes ``num_steps`` Euler steps
        of size ``1 / num_steps`` backwards in time, then decodes with ``vae``.

        Args:
            text_seq: Raw T5 sequence embeddings, one row per sample.
            text_pooled: Raw pooled T5 embeddings, one row per sample.
            vae: Decoder used to map final latents to images.
            num_steps: Euler steps; falls back to ``cfg.num_flow_steps`` when
                None (or 0).

        Returns:
            The output of ``vae.decode`` on the final latents.
        """
        cfg = self.objects['cfg']
        B = text_seq.shape[0]
        device = text_seq.device
        num_steps = num_steps or cfg.num_flow_steps

        # Pooled text conditioning is step-invariant; bottleneck it once.
        text_pooled_bn = self['text_bottleneck_pool'](text_pooled)

        z = torch.randn(B, cfg.latent_channels, cfg.latent_size, cfg.latent_size, device=device)
        dt = 1.0 / num_steps

        for step in range(num_steps):
            t = torch.full((B,), 1.0 - step * dt, device=device)
            v_pred, _ = self._predict_velocity(z, t, text_seq, text_pooled_bn)
            # Euler step toward t = 0: z_{t - dt} = z_t - dt * v.
            z = (z.flatten(1) - dt * v_pred).view(B, cfg.latent_channels, cfg.latent_size, cfg.latent_size)

        return vae.decode(z)
|
|
|
|
| |
| |
| |
|
|
class Trainer:
    """Training driver for ``BeatrixFlowT5`` on cached CIFAR-10.

    On construction it:
      - builds or loads the latent/text caches;
      - builds the model, AdamW optimizer and per-step cosine LR schedule;
      - resumes from the newest checkpoint (local ``ckpt_latest.pt``, then
        local ``ckpt_epoch*.pt``, then the HF repo);
      - prepares the HF repo for uploads.

    ``train()`` runs the epoch loop, saving a sample grid and a checkpoint
    every epoch.
    """

    def __init__(self, cfg: FlowConfig):
        self.cfg = cfg
        self.device = torch.device(cfg.device if torch.cuda.is_available() else "cpu")
        self.output_dir = Path(cfg.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        use_cuda = self.device.type == "cuda"

        if use_cuda:
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        # A disabled scaler is a pass-through, so the same step code runs on CPU.
        self.scaler = torch.amp.GradScaler('cuda', enabled=use_cuda)

        print("\n=== Building Cached Datasets ===")
        # Pass the resolved device (not cfg.device) so cache building works without CUDA.
        self.train_dataset = CachedCIFAR10T5(
            train=True, image_size=cfg.image_size, cache_dir=cfg.cache_dir, device=str(self.device),
        )
        self.val_dataset = CachedCIFAR10T5(
            train=False, image_size=cfg.image_size, cache_dir=cfg.cache_dir, device=str(self.device),
        )

        # The model's text bottlenecks must match the cached embedding width.
        cfg.text_raw_dim = self.train_dataset.text_dim
        print(f"T5 raw dimension: {cfg.text_raw_dim} -> bottleneck: {cfg.bottleneck_dim}")

        self.train_loader = DataLoader(
            self.train_dataset, batch_size=cfg.batch_size, shuffle=True,
            num_workers=0, pin_memory=use_cuda, drop_last=True,
        )
        self.val_loader = DataLoader(
            self.val_dataset, batch_size=cfg.batch_size, shuffle=False,
            num_workers=0, pin_memory=use_cuda,
        )

        # Per-class text embeddings, kept on device for sampling.
        self.text_sequence = self.train_dataset.text_sequence.to(self.device)
        self.text_pooled = self.train_dataset.text_pooled.to(self.device)

        print("\n=== Building Model (Vision + Text Towers) ===")
        self.model = BeatrixFlowT5(cfg).to(self.device)

        if hasattr(torch, 'compile'):
            print("Compiling with WideRouter.prepare_and_compile()...")
            self.model = self.model.prepare_and_compile(
                mode="reduce-overhead",
                fullgraph=False,
            )

        num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"Trainable parameters: {num_params:,}")

        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        # The scheduler is stepped once per optimizer step across the whole run.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=cfg.num_epochs * len(self.train_loader),
        )

        self.start_epoch = 0
        self.hf_repo = "AbstractPhil/beatrix-diffusion-proto"
        self._load_latest_checkpoint()

        self._vae = None

        self._setup_hf_repo()

    def _autocast(self):
        """Mixed-precision context for model calls: enabled on CUDA only."""
        return torch.amp.autocast(self.device.type, enabled=self.device.type == "cuda")

    def _setup_hf_repo(self):
        """Create the HF repo if needed and upload the model config.

        Any failure is reported and sets ``hf_api`` to None, which disables
        later uploads.
        """
        try:
            self.hf_api = HfApi()
            create_repo(self.hf_repo, exist_ok=True, repo_type="model")
            print(f"HF repo: {self.hf_repo}")

            config_dict = {
                'image_size': self.cfg.image_size,
                'num_classes': self.cfg.num_classes,
                'latent_channels': self.cfg.latent_channels,
                'latent_size': self.cfg.latent_size,
                'text_raw_dim': self.cfg.text_raw_dim,
                'bottleneck_dim': self.cfg.bottleneck_dim,
                'tower_dim': self.cfg.tower_dim,
                'tower_depth': self.cfg.tower_depth,
                'num_heads': self.cfg.num_heads,
                'geometric_types': self.cfg.geometric_types,
                'conv_types': self.cfg.conv_types,
                'conv_spatial_size': self.cfg.conv_spatial_size,
                'manifold_dim': self.cfg.manifold_dim,
                'fingerprint_dim': self.cfg.fingerprint_dim,
                'num_flow_steps': self.cfg.num_flow_steps,
            }
            config_path = self.output_dir / "config.json"
            with open(config_path, 'w') as f:
                json.dump(config_dict, f, indent=2)

            upload_file(
                path_or_fileobj=str(config_path),
                path_in_repo="config.json",
                repo_id=self.hf_repo,
            )
        except Exception as e:
            print(f"HF setup warning: {e}")
            self.hf_api = None

    def _upload_to_hf(self, epoch: int, sample_path: Path, metrics: Optional[dict] = None):
        """Log ``metrics`` locally, then upload checkpoint, samples and metrics to HF.

        The local ``metrics.jsonl`` is appended to even when HF is
        unavailable. Upload failures are reported and swallowed so that
        training continues.
        """
        metrics_path = self.output_dir / "metrics.jsonl"
        if metrics:
            try:
                with open(metrics_path, 'a') as f:
                    f.write(json.dumps({'epoch': epoch, **metrics}) + '\n')
            except OSError as e:
                print(f"  ! Could not write {metrics_path}: {e}")

        if self.hf_api is None:
            return

        try:
            ckpt_path = self.output_dir / "ckpt_latest.pt"
            if ckpt_path.exists():
                upload_file(
                    path_or_fileobj=str(ckpt_path),
                    path_in_repo="ckpt_latest.pt",
                    repo_id=self.hf_repo,
                )

            if sample_path.exists():
                upload_file(
                    path_or_fileobj=str(sample_path),
                    path_in_repo=f"samples/epoch_{epoch:03d}.png",
                    repo_id=self.hf_repo,
                )
                upload_file(
                    path_or_fileobj=str(sample_path),
                    path_in_repo="samples/latest.png",
                    repo_id=self.hf_repo,
                )

            if metrics and metrics_path.exists():
                upload_file(
                    path_or_fileobj=str(metrics_path),
                    path_in_repo="metrics.jsonl",
                    repo_id=self.hf_repo,
                )

            print("  - Uploaded to HF")
        except Exception as e:
            print(f"  ! HF upload failed: {e}")

    def _find_checkpoint(self) -> Optional[dict]:
        """Return the newest available checkpoint dict, or None if there is none."""
        # NOTE(security): weights_only=False runs the full unpickler. The HF
        # fallback loads a file downloaded from a remote repo, so only point
        # ``hf_repo`` at repositories you trust.
        latest_path = self.output_dir / "ckpt_latest.pt"
        if latest_path.exists():
            print("Resuming from local ckpt_latest.pt...")
            return torch.load(latest_path, weights_only=False)

        ckpts = sorted(self.output_dir.glob("ckpt_epoch*.pt"))
        if ckpts:
            latest_path = ckpts[-1]
            print(f"Resuming from {latest_path.name}...")
            return torch.load(latest_path, weights_only=False)

        try:
            from huggingface_hub import hf_hub_download
            print("Checking HF for checkpoint...")
            hf_path = hf_hub_download(
                repo_id=self.hf_repo,
                filename="ckpt_latest.pt",
                local_dir=str(self.output_dir),
            )
            print("Downloaded checkpoint from HF")
            return torch.load(hf_path, weights_only=False)
        except Exception as e:
            print(f"No checkpoint found (local or HF): {e}")
            return None

    def _load_latest_checkpoint(self):
        """Restore model / optimizer / scheduler (and GradScaler, if saved) state."""
        ckpt = self._find_checkpoint()
        if ckpt is None:
            return

        self.model.load_state_dict(ckpt['model'])
        self.optimizer.load_state_dict(ckpt['optimizer'])
        self.scheduler.load_state_dict(ckpt['scheduler'])
        # Older checkpoints have no scaler entry; a disabled scaler saves {}.
        if ckpt.get('scaler'):
            self.scaler.load_state_dict(ckpt['scaler'])
        self.start_epoch = ckpt['epoch']
        print(f"  Resumed at epoch {self.start_epoch}")

    def _load_vae(self):
        """Load the VAE for sampling (kept only for the duration of sampling)."""
        print("Loading VAE for sampling...")
        return SD15VAE(freeze=True).to(self.device)

    def _unload_vae(self, vae):
        """Release cached CUDA memory after sampling.

        ``del`` here only drops this function's local reference; the caller
        must drop its own reference for the VAE memory to actually be freed.
        """
        del vae
        torch.cuda.empty_cache()

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Run one pass over ``train_loader``; return the mean loss and tau.

        Returns NaN means if the loader yields no batches.
        """
        self.model.train()
        total_loss, total_tau, n = 0.0, 0.0, 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.cfg.num_epochs}", leave=False)
        for latents, text_seq, text_pooled, labels in pbar:
            latents = latents.to(self.device)
            text_seq = text_seq.to(self.device)
            text_pooled = text_pooled.to(self.device)
            labels = labels.to(self.device)

            with self._autocast():
                out = self.model(latents, text_seq, text_pooled, labels)
                loss = out['loss']

            self.optimizer.zero_grad()
            self.scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to true gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.scheduler.step()

            # One host sync per value, reused for totals and the progress bar.
            loss_val = loss.item()
            tau_val = out['tau'].item()
            total_loss += loss_val
            total_tau += tau_val
            n += 1

            pbar.set_postfix(loss=f"{loss_val:.4f}", tau=f"{tau_val:.2f}")

        if n == 0:
            return {'loss': float('nan'), 'tau': float('nan')}
        return {'loss': total_loss / n, 'tau': total_tau / n}

    @torch.no_grad()
    def validate(self) -> Dict[str, float]:
        """Return the mean loss over ``val_loader`` (NaN if it yields no batches)."""
        self.model.eval()
        total_loss, n = 0.0, 0

        for latents, text_seq, text_pooled, labels in self.val_loader:
            latents = latents.to(self.device)
            text_seq = text_seq.to(self.device)
            text_pooled = text_pooled.to(self.device)
            labels = labels.to(self.device)

            with self._autocast():
                out = self.model(latents, text_seq, text_pooled, labels)
            total_loss += out['loss'].item()
            n += 1

        if n == 0:
            return {'val_loss': float('nan')}
        return {'val_loss': total_loss / n}

    @torch.no_grad()
    def sample_images(self, n_per_class: int = 10) -> Tensor:
        """Generate ``n_per_class`` samples for every class prompt.

        Samples are generated class by class in batches of at most 10, with
        the VAE loaded only for this call. The decoder output is mapped with
        ``(x + 1) / 2`` and clamped to [0, 1]. The result is ordered
        class-major.
        """
        self.model.eval()
        torch.cuda.empty_cache()

        vae = self._load_vae()

        all_samples = []
        batch_size = 10
        num_classes = self.text_sequence.shape[0]

        try:
            for class_idx in range(num_classes):
                for batch_start in range(0, n_per_class, batch_size):
                    batch_n = min(batch_size, n_per_class - batch_start)

                    text_seq = self.text_sequence[class_idx:class_idx + 1].expand(batch_n, -1, -1)
                    text_pooled = self.text_pooled[class_idx:class_idx + 1].expand(batch_n, -1)

                    with self._autocast():
                        samples = self.model.sample(text_seq, text_pooled, vae)

                    all_samples.append(samples.cpu())
        finally:
            # Drop this frame's reference first so the VAE memory can actually be freed.
            del vae
            torch.cuda.empty_cache()

        samples = torch.cat(all_samples, dim=0).to(self.device)
        return ((samples + 1) / 2).clamp(0, 1)

    def save_checkpoint(self, epoch: int, milestone: bool = False):
        """Write ``ckpt_latest.pt`` (atomically) and, for milestones, ``ckpt_epochNNN.pt``."""
        ckpt = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'scaler': self.scaler.state_dict(),
        }

        # Temp file + rename: an interrupted save cannot corrupt the resume file.
        latest_path = self.output_dir / "ckpt_latest.pt"
        tmp_path = latest_path.with_name(latest_path.name + ".tmp")
        torch.save(ckpt, tmp_path)
        tmp_path.replace(latest_path)

        if milestone:
            torch.save(ckpt, self.output_dir / f"ckpt_epoch{epoch:03d}.pt")

    def train(self):
        """Run epochs from ``start_epoch`` to ``num_epochs``, then save final samples."""
        num_geo = len(self.cfg.geometric_types) * 2  # pos/neg per type
        num_conv = len(self.cfg.conv_types) * 2  # pos/neg per type
        total_towers = (num_geo + num_conv) * 2  # vision + text

        print(f"\n{'='*60}")
        print("BEATRIX FLOW - Dual Geometric + Conv Towers (Bottlenecked)")
        print(f"{'='*60}")
        print(f"Device: {self.device}")
        print(f"Geometric towers: {self.cfg.geometric_types} (pos/neg)")
        print(f"Conv towers: {self.cfg.conv_types} (pos/neg)")
        print(f"Tower dim: {self.cfg.tower_dim}")
        print(f"T5 raw -> bottleneck: {self.cfg.text_raw_dim} -> {self.cfg.bottleneck_dim}")
        print(f"Latent -> manifold: {self.cfg.latent_flat_dim} -> {self.cfg.manifold_dim}")
        print(f"Total towers: {total_towers}")
        print(f"Batch size: {self.cfg.batch_size}")
        print(f"Epochs: {self.start_epoch}/{self.cfg.num_epochs}")
        print(f"{'='*60}\n")

        for epoch in range(self.start_epoch, self.cfg.num_epochs):
            train_metrics = self.train_epoch(epoch)
            val_metrics = self.validate()

            lr = self.scheduler.get_last_lr()[0]
            print(
                f"Epoch {epoch+1:3d} | loss={train_metrics['loss']:.4f} | "
                f"val={val_metrics['val_loss']:.4f} | tau={train_metrics['tau']:.2f} | lr={lr:.2e}"
            )

            samples = self.sample_images(10)
            grid = make_grid(samples, nrow=10, padding=2)
            sample_path = self.output_dir / f"samples_epoch{epoch+1:03d}.png"
            save_image(grid, sample_path)
            print("  - Saved samples")

            self.save_checkpoint(epoch + 1, milestone=((epoch + 1) % 10 == 0))

            metrics = {
                'loss': train_metrics['loss'],
                'val_loss': val_metrics['val_loss'],
                'tau': train_metrics['tau'],
                'lr': lr,
            }
            self._upload_to_hf(epoch + 1, sample_path, metrics)

        samples = self.sample_images(10)
        grid = make_grid(samples, nrow=10, padding=2)
        final_path = self.output_dir / "samples_final.png"
        save_image(grid, final_path)
        self.save_checkpoint(self.cfg.num_epochs, milestone=True)
        self._upload_to_hf(self.cfg.num_epochs, final_path)
        print("\nTraining complete!")
|
|
|
|
| |
| |
| |
|
|
def _train_with(**overrides) -> None:
    """Build a ``FlowConfig`` from the shared preset defaults plus ``overrides`` and train it."""
    settings = dict(
        image_size=256,
        tower_dim=256,
        tower_depth=2,
        num_heads=8,
        conv_spatial_size=8,
        bottleneck_dim=256,
        batch_size=64,
        num_epochs=100,
        cache_dir="./cache",
    )
    settings.update(overrides)
    trainer = Trainer(FlowConfig(**settings))
    trainer.train()


def main():
    """Reduced configuration: 2 geometric + 2 conv tower types per modality."""
    _train_with(
        geometric_types=('cantor', 'beatrix'),
        conv_types=('wide_resnet', 'squeeze_excite'),
        manifold_dim=512,
        output_dir="./beatrix_cifar_t5",
    )


def main_full():
    """Full 32-tower configuration (4 geometric + 4 conv types, pos/neg, both modalities).

    Uses its own output directory. Sharing the reduced preset's directory
    would make it resume that preset's checkpoint, whose parameter shapes
    differ (e.g. ``manifold_dim`` 512 vs 1024), and ``load_state_dict``
    would fail.

    NOTE(review): Trainer still uses a single hard-coded HF repo, so the HF
    resume fallback can still pick up the other preset's checkpoint.
    """
    _train_with(
        geometric_types=('cantor', 'beatrix', 'helix', 'simplex'),
        conv_types=('wide_resnet', 'frequency', 'bottleneck', 'squeeze_excite'),
        manifold_dim=1024,
        output_dir="./beatrix_cifar_t5_full",
    )


if __name__ == "__main__":
    main()