| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader |
| from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR |
| import json |
| import numpy as np |
| from tqdm import tqdm |
|
|
| |
# ---------------------------------------------------------------------------
# Embedding geometry
# ---------------------------------------------------------------------------
ESM_DIM = 1280                    # feature width of each input embedding position
COMP_RATIO = 16                   # feature-width compression factor
COMP_DIM = ESM_DIM // COMP_RATIO  # compressed feature width (1280 // 16 == 80)
MAX_SEQ_LEN = 50                  # expected sequence length (dataset only warns on mismatch)

# ---------------------------------------------------------------------------
# Optimisation schedule
# ---------------------------------------------------------------------------
BATCH_SIZE = 32
EPOCHS = 30
BASE_LR = 1e-3                    # learning rate reached at the end of linear warm-up
LR_MIN = 8e-5                     # floor of the cosine decay
WARMUP_STEPS = 10000              # warm-up length, counted in optimizer steps

# ---------------------------------------------------------------------------
# Transformer shape
# ---------------------------------------------------------------------------
DEPTH = 4                         # each TransformerEncoder stack uses DEPTH // 2 layers
HEADS = 8
DIM_FF = 4 * ESM_DIM              # feed-forward hidden width
POOLING = True                    # average adjacent positions between the compressor stacks
|
|
| |
class PrecomputedEmbeddingDataset(Dataset):
    """In-memory dataset of pre-computed embeddings, stored as one stacked tensor.

    ``self.embeddings`` is validated to be 3-D; ``__getitem__`` returns the
    2-D slice ``self.embeddings[idx]``.
    """

    def __init__(self, embeddings_path):
        """
        Load pre-computed embeddings from the final_sequence_encoder.py output.

        Args:
            embeddings_path: Path to the directory containing individual .pt
                embedding files (one 2-D tensor per file). A path to a single
                .pt file holding either a 3-D tensor or a list/tuple of 2-D
                tensors is also accepted.

        Raises:
            ValueError: if no valid embeddings are found, if the loaded
                embeddings do not all share one shape, or if the result is
                not 3-D.
        """
        import os

        print(f"Loading pre-computed embeddings from {embeddings_path}...")

        if os.path.isfile(embeddings_path):
            self.embeddings = self._load_file(embeddings_path)
        else:
            self.embeddings = self._load_directory(embeddings_path)
        print(f"Loaded {len(self.embeddings)} embeddings with shape {self.embeddings.shape}")

        if len(self.embeddings.shape) != 3:
            raise ValueError(f"Expected 3D tensor, got shape {self.embeddings.shape}")

        if self.embeddings.shape[1] != MAX_SEQ_LEN:
            print(f"Warning: Expected sequence length {MAX_SEQ_LEN}, got {self.embeddings.shape[1]}")

        if self.embeddings.shape[2] != ESM_DIM:
            print(f"Warning: Expected embedding dim {ESM_DIM}, got {self.embeddings.shape[2]}")

    @staticmethod
    def _stack(embeddings_list):
        """Stack 2-D embeddings into one 3-D tensor, failing with a clear message."""
        if not embeddings_list:
            raise ValueError("No valid embeddings found!")
        shapes = sorted({tuple(e.shape) for e in embeddings_list})
        if len(shapes) > 1:
            # torch.stack would fail here too, but with a much less helpful error.
            raise ValueError(
                f"Embeddings have inconsistent shapes {shapes[:5]}; "
                "pad or truncate them to a common length before loading"
            )
        return torch.stack(embeddings_list)

    @classmethod
    def _load_directory(cls, directory):
        """Load and stack every ``*.pt`` file in *directory* that holds a 2-D tensor."""
        import glob
        import os

        # Sorted so the item order (and hence unshuffled evaluation) is
        # reproducible; glob returns files in arbitrary filesystem order.
        embedding_files = sorted(glob.glob(os.path.join(directory, "*.pt")))
        print(f"Found {len(embedding_files)} embedding files")

        embeddings_list = []
        for file_path in embedding_files:
            try:
                # map_location="cpu": tensors saved from a GPU still load on a
                # CPU-only machine, and the dataset is not kept in GPU memory.
                # NOTE(review): torch.load unpickles -- only use on trusted files.
                embedding = torch.load(file_path, map_location="cpu")
                if embedding.dim() == 2:
                    embeddings_list.append(embedding)
                else:
                    print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
            except Exception as e:
                # Best effort: one unreadable file should not abort the whole load.
                print(f"Warning: Could not load {file_path}: {e}")
        return cls._stack(embeddings_list)

    @classmethod
    def _load_file(cls, file_path):
        """Load embeddings stored together in a single ``.pt`` file."""
        # NOTE(review): torch.load unpickles -- only use on trusted files.
        data = torch.load(file_path, map_location="cpu")
        if isinstance(data, torch.Tensor):
            # A lone 2-D tensor is one embedding; any other rank is returned
            # unchanged and rejected by the 3-D check in __init__.
            return data.unsqueeze(0) if data.dim() == 2 else data
        if isinstance(data, (list, tuple)):
            return cls._stack([t for t in data if isinstance(t, torch.Tensor) and t.dim() == 2])
        # NOTE(review): other containers (e.g. a dict keyed by sequence id) are
        # not handled; confirm the encoder's output format if this is raised.
        raise ValueError(f"Unsupported content in {file_path}: {type(data).__name__}")

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        return self.embeddings[idx]
|
|
| |
class Compressor(nn.Module):
    """Transformer encoder that compresses embeddings to a narrower latent.

    Pipeline: optional feature-wise normalisation -> LayerNorm -> DEPTH//2
    encoder layers -> optional mean-pooling of adjacent positions ->
    DEPTH//2 encoder layers -> LayerNorm + Linear(in_dim -> out_dim) + Tanh.
    """

    def __init__(self, in_dim=ESM_DIM, out_dim=COMP_DIM):
        super().__init__()

        def make_layer():
            # Template layer; nn.TransformerEncoder deep-copies it num_layers times.
            return nn.TransformerEncoderLayer(
                d_model=in_dim,
                nhead=HEADS,
                dim_feedforward=DIM_FF,
                batch_first=True,
            )

        # Submodules are created in a fixed order: it determines both the
        # state_dict layout and the sequence of random draws used at init.
        self.norm = nn.LayerNorm(in_dim)
        half_depth = DEPTH // 2
        self.pre_tr = nn.TransformerEncoder(make_layer(), num_layers=half_depth)
        self.post_tr = nn.TransformerEncoder(make_layer(), num_layers=half_depth)
        self.proj = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, out_dim),
            nn.Tanh(),
        )
        self.pooling = POOLING

    def forward(self, x, stats=None):
        """Compress ``x``.

        Args:
            x: float tensor; the pooling step unpacks it as (batch, seq, in_dim)
                and the encoder layers are batch-first.
            stats: optional dict with 'mean', 'std', 'min', 'max' tensors. When a
                non-empty dict is given, x is standardised and clipped to
                [-4, 4], then min-max scaled and clipped to [0, 1].

        Returns:
            Tensor with last dimension out_dim and values in [-1, 1] (Tanh).
            With pooling enabled the sequence axis is halved (an odd final
            position is dropped); otherwise it is unchanged.
        """
        if stats:
            dev = x.device
            mean, std = stats['mean'].to(dev), stats['std'].to(dev)
            lo, hi = stats['min'].to(dev), stats['max'].to(dev)
            x = ((x - mean) / std).clamp(-4, 4)
            x = ((x - lo) / (hi - lo + 1e-8)).clamp(0, 1)

        x = self.pre_tr(self.norm(x))

        if self.pooling:
            batch, length, dim = x.shape
            pairs = length // 2
            # Average each pair of neighbouring positions; a lone trailing
            # position (odd length) has no partner and is discarded.
            x = x[:, : 2 * pairs, :].view(batch, pairs, 2, dim).mean(2)

        return self.proj(self.post_tr(x))
|
|
| |
class Decompressor(nn.Module):
    """Expand compressed latents back to the embedding width.

    Pipeline: LayerNorm + Linear(in_dim -> out_dim) -> optional 2x upsampling
    by repeating every position -> DEPTH//2 transformer encoder layers.
    """

    def __init__(self, in_dim=COMP_DIM, out_dim=ESM_DIM):
        super().__init__()
        # Projection is built before the transformer so the init RNG order and
        # state_dict layout stay fixed.
        self.proj = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, out_dim),
        )
        template = nn.TransformerEncoderLayer(
            d_model=out_dim,
            nhead=HEADS,
            dim_feedforward=DIM_FF,
            batch_first=True,
        )
        self.decoder = nn.TransformerEncoder(template, num_layers=DEPTH // 2)
        self.pooling = POOLING

    def forward(self, z):
        """Decode latent ``z`` (batch-first) to tensors with last dimension out_dim.

        With pooling enabled every position is duplicated, doubling the
        sequence axis. A compressor input of odd length therefore comes back
        one position shorter, because the compressor drops that position.
        """
        expanded = self.proj(z)
        expanded = expanded.repeat_interleave(2, dim=1) if self.pooling else expanded
        return self.decoder(expanded)
|
|
| |
def _compute_normalization_stats(embeddings):
    """Return the per-feature statistics dict consumed by ``Compressor.forward``.

    'mean'/'std' standardise each feature over all positions of all samples.
    'min'/'max' are taken over the standardised values after clipping to
    [-4, 4], matching the two-stage normalisation in the compressor.
    """
    # Use the actual feature width rather than assuming ESM_DIM.
    flat = embeddings.reshape(-1, embeddings.shape[-1])
    mean = flat.mean(0)
    std = flat.std(0) + 1e-8
    # One full-size temporary, updated in place, instead of recomputing the
    # standardised data separately for min and max.
    standardized = (flat - mean).div_(std).clamp_(-4, 4)
    return {
        'mean': mean,
        'std': std,
        'min': standardized.min(0)[0],
        'max': standardized.max(0)[0],
    }


def _build_scheduler(opt, total_steps):
    """Linear warm-up over WARMUP_STEPS optimizer steps, then cosine decay to LR_MIN."""
    warmup_sched = LinearLR(opt, start_factor=1e-8, end_factor=1.0, total_iters=WARMUP_STEPS)
    # SequentialLR only hands control to the cosine schedule at the milestone,
    # so its period must cover the *remaining* steps; otherwise the LR never
    # decays all the way to LR_MIN by the end of training.
    cosine_sched = CosineAnnealingLR(opt, T_max=max(1, total_steps - WARMUP_STEPS), eta_min=LR_MIN)
    return SequentialLR(opt, [warmup_sched, cosine_sched], milestones=[WARMUP_STEPS])


def train_with_precomputed_embeddings(embeddings_path, device='cuda'):
    """
    Train compressor using pre-computed embeddings from final_sequence_encoder.py

    Writes to the current working directory: normalization_stats.pt,
    checkpoint_epoch_{N}.pth every 5 epochs, and compressor_final.pth /
    decompressor_final.pth at the end.

    Args:
        embeddings_path: passed to PrecomputedEmbeddingDataset.
        device: torch device (string or torch.device) to train on.
    """
    ds = PrecomputedEmbeddingDataset(embeddings_path)

    print("Computing normalization statistics...")
    stats = _compute_normalization_stats(ds.embeddings)

    # load_and_test_models reads this file so evaluation uses identical scaling.
    torch.save(stats, 'normalization_stats.pt')
    print("Saved normalization statistics to normalization_stats.pt")

    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)

    comp = Compressor().to(device)
    decomp = Decompressor().to(device)

    opt = optim.AdamW(list(comp.parameters()) + list(decomp.parameters()), lr=BASE_LR)

    total_steps = EPOCHS * len(dl)
    if total_steps <= WARMUP_STEPS:
        print(f"Warning: only {total_steps} optimizer steps in total but WARMUP_STEPS={WARMUP_STEPS}; "
              f"the learning rate will never reach BASE_LR")
    sched = _build_scheduler(opt, total_steps)

    print(f"Starting training for {EPOCHS} epochs...")
    print(f"Device: {device}")
    print(f"Batch size: {BATCH_SIZE}")
    print(f"Total batches per epoch: {len(dl)}")

    for epoch in range(1, EPOCHS + 1):
        total_loss = 0.0
        comp.train()
        decomp.train()

        for batch_idx, x in enumerate(tqdm(dl, desc=f"Epoch {epoch}/{EPOCHS}")):
            x = x.to(device)
            z = comp(x, stats)
            xr = decomp(z)
            # The target is the raw (un-normalised) input embedding.
            loss = (x - xr).pow(2).mean()

            opt.zero_grad()
            loss.backward()
            opt.step()
            sched.step()  # stepped per batch: the schedule is measured in optimizer steps

            total_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f" Batch {batch_idx}/{len(dl)} - Loss: {loss.item():.6f}")

        avg_loss = total_loss / len(dl)
        print(f"Epoch {epoch}/{EPOCHS} — Average MSE: {avg_loss:.6f}")

        if epoch % 5 == 0:
            torch.save({
                'epoch': epoch,
                'compressor_state_dict': comp.state_dict(),
                'decompressor_state_dict': decomp.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                # Needed to resume at the same point of the LR schedule and
                # with the same input normalisation.
                'scheduler_state_dict': sched.state_dict(),
                'normalization_stats': stats,
                'loss': avg_loss,
            }, f'checkpoint_epoch_{epoch}.pth')

    torch.save(comp.state_dict(), 'compressor_final.pth')
    torch.save(decomp.state_dict(), 'decompressor_final.pth')
    print("Training completed! Models saved as compressor_final.pth and decompressor_final.pth")
|
|
| |
def load_and_test_models(compressor_path, decompressor_path, embeddings_path, device='cuda',
                         stats_path='normalization_stats.pt'):
    """
    Load trained models and test reconstruction quality

    Args:
        compressor_path: state_dict file for Compressor.
        decompressor_path: state_dict file for Decompressor.
        embeddings_path: passed to PrecomputedEmbeddingDataset.
        device: torch device (string or torch.device) to evaluate on.
        stats_path: normalisation statistics saved during training.

    Returns:
        Mean squared reconstruction error over every element of the dataset.
    """
    print("Loading trained models...")
    comp = Compressor().to(device)
    decomp = Decompressor().to(device)

    # map_location lets GPU-trained weights load on a CPU-only machine (the CLI
    # falls back to CPU when CUDA is unavailable).
    comp.load_state_dict(torch.load(compressor_path, map_location=device))
    decomp.load_state_dict(torch.load(decompressor_path, map_location=device))

    comp.eval()
    decomp.eval()

    ds = PrecomputedEmbeddingDataset(embeddings_path)
    test_loader = DataLoader(ds, batch_size=16, shuffle=False)

    # Compressor.forward moves each statistic to the input's device itself.
    stats = torch.load(stats_path, map_location='cpu')

    print("Testing reconstruction quality...")
    total_mse = 0.0
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            x = batch.to(device)
            z = comp(x, stats)
            xr = decomp(z)
            mse = (x - xr).pow(2).mean()
            # Weight by batch size: samples share one shape, so this yields the
            # exact element-wise mean even when the last batch is smaller.
            total_mse += mse.item() * len(x)
            total_samples += len(x)

    avg_mse = total_mse / total_samples
    print(f"Average reconstruction MSE: {avg_mse:.6f}")

    return avg_mse
|
|
| |
if __name__ == '__main__':
    import argparse

    cli = argparse.ArgumentParser(description='Train protein compressor with pre-computed embeddings')
    cli.add_argument(
        '--embeddings',
        type=str,
        default='/data2/edwardsun/flow_project/compressor_dataset/peptide_embeddings.pt',
        help='Path to pre-computed embeddings from final_sequence_encoder.py',
    )
    cli.add_argument('--device', type=str, default='cuda', help='Device to use (cuda/cpu)')
    cli.add_argument('--test', action='store_true', help='Test existing models instead of training')
    opts = cli.parse_args()

    # --device is honoured only when CUDA is available; otherwise run on CPU.
    if torch.cuda.is_available():
        device = torch.device(opts.device)
    else:
        device = torch.device('cpu')
    print(f"Using device: {device}")

    if not opts.test:
        train_with_precomputed_embeddings(opts.embeddings, device)
    else:
        # Evaluate the final weights written by a previous training run.
        load_and_test_models('compressor_final.pth', 'decompressor_final.pth', opts.embeddings, device)