import json
import os
import random
import re
from typing import Any, Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
|
|
def parse_fasta_with_amp_labels(fasta_path: str, max_seq_len: int = 50) -> Dict[str, Any]:
    """
    Parse a FASTA file and assign AMP/Non-AMP labels based on header prefixes.

    Label assignment strategy:
        - AMP (0): headers starting with '>AP'
        - Non-AMP (1): headers starting with '>sp' (unknown prefixes also
          default to Non-AMP, with a warning)
        - Mask (2): randomly assigned to ~10% of records, used as the
          "unconditional" class for classifier-free guidance (CFG) training

    Only sequences of length 2..max_seq_len composed exclusively of the 20
    canonical amino acids are kept; sequences may span multiple lines.

    Args:
        fasta_path: Path to the FASTA file.
        max_seq_len: Maximum sequence length to include.

    Returns:
        Dict with keys:
            'sequences': list of uppercase amino-acid strings
            'headers': matching headers (leading '>' stripped)
            'labels' / 'original_labels': per-sequence labels (0/1); both keys
                carry the same values so consumers expecting either name work
            'masked_labels': labels with ~10% replaced by 2 for CFG
            'mask_indices': indices whose labels were masked
    """
    canonical_aa = frozenset('ACDEFGHIKLMNPQRSTVWY')
    sequences: List[str] = []
    labels: List[int] = []
    headers: List[str] = []

    def _commit(header: str, sequence: str) -> None:
        """Validate one (header, sequence) record and append it if it passes."""
        # Length gate also rejects the empty "record" seen before the first '>'.
        if not (header and 2 <= len(sequence) <= max_seq_len):
            return
        sequence = sequence.upper()
        if not all(aa in canonical_aa for aa in sequence):
            return
        sequences.append(sequence)
        headers.append(header)
        if header.startswith('AP'):
            labels.append(0)
        elif header.startswith('sp'):
            labels.append(1)
        else:
            # Unknown prefix: be conservative and treat it as Non-AMP.
            labels.append(1)
            print(f"Warning: Unknown header prefix in '{header}', defaulting to Non-AMP")

    print(f"Parsing FASTA file: {fasta_path}")
    print("Label assignment: >AP = AMP (0), >sp = Non-AMP (1)")

    current_header = ""
    current_sequence = ""
    with open(fasta_path, 'r') as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                # A new header terminates the previous record.
                _commit(current_header, current_sequence)
                current_header = line[1:]
                current_sequence = ""
            else:
                # Sequence data may be wrapped across lines; accumulate it.
                current_sequence += line
    # Flush the final record (there is no trailing '>' to trigger it).
    _commit(current_header, current_sequence)

    # Randomly mask ~10% of labels to the CFG "unconditional" class (2).
    original_labels = np.array(labels)
    masked_labels = original_labels.copy()
    mask_probability = 0.1
    n_masked = int(len(original_labels) * mask_probability)
    if n_masked > 0:
        mask_indices = np.random.choice(
            len(original_labels),
            size=n_masked,
            replace=False
        )
    else:
        # Guard the empty/tiny case: np.random.choice(0, ...) can raise.
        mask_indices = np.array([], dtype=int)
    masked_labels[mask_indices] = 2

    print(f"✓ Parsed {len(sequences)} valid sequences from FASTA")
    print(f" AMP sequences: {np.sum(original_labels == 0)}")
    print(f" Non-AMP sequences: {np.sum(original_labels == 1)}")
    print(f" Masked for CFG: {len(mask_indices)}")

    return {
        'sequences': sequences,
        'headers': headers,
        'labels': original_labels.tolist(),
        # Duplicate key so consumers that expect 'original_labels'
        # (e.g. CFGUniProtDataset JSON files) can use this dict directly.
        'original_labels': original_labels.tolist(),
        'masked_labels': masked_labels.tolist(),
        'mask_indices': mask_indices.tolist()
    }
|
|
class CFGUniProtDataset(Dataset):
    """
    Dataset of UniProt peptide sequences with classifier-free guidance labels.

    Loads a processed JSON file (compatible with the output of
    `parse_fasta_with_amp_labels`) containing sequences, original AMP/Non-AMP
    labels, CFG-masked labels, and the indices that were masked.

    Label encoding: 0 = AMP, 1 = Non-AMP, 2 = mask (CFG unconditional class).
    """

    def __init__(self,
                 data_path: str,
                 use_masked_labels: bool = True,
                 mask_probability: float = 0.1,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):
        """
        Args:
            data_path: Path to the processed JSON data file.
            use_masked_labels: If True, __getitem__ serves the CFG-masked
                labels; otherwise the original labels.
            mask_probability: Masking rate used upstream (stored for reporting
                only; masking happens during preprocessing, not here).
            max_seq_len: Maximum sequence length (stored for reference).
            device: Target device name (stored for reference; tensors are
                created on CPU here).

        Raises:
            FileNotFoundError: If data_path does not exist.
        """
        self.data_path = data_path
        self.use_masked_labels = use_masked_labels
        self.mask_probability = mask_probability
        self.max_seq_len = max_seq_len
        self.device = device

        self._load_data()

        # Integer label -> human-readable name.
        self.label_map = {
            0: 'amp',
            1: 'non_amp',
            2: 'mask'
        }

        print(f"CFG Dataset initialized:")
        print(f" Total sequences: {len(self.sequences)}")
        print(f" Using masked labels: {use_masked_labels}")
        print(f" Mask probability: {mask_probability}")
        print(f" Label distribution: {self._get_label_distribution()}")

    def _load_data(self):
        """Load processed UniProt data from JSON.

        Accepts either 'original_labels' or 'labels' as the key for the
        unmasked labels, so files saved straight from
        `parse_fasta_with_amp_labels` work without renaming.
        """
        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"Data file not found: {self.data_path}")

        with open(self.data_path, 'r') as f:
            data = json.load(f)

        self.sequences = data['sequences']
        # Tolerate both key spellings used across the pipeline.
        raw_labels = data['original_labels'] if 'original_labels' in data else data['labels']
        self.original_labels = np.array(raw_labels)
        self.masked_labels = np.array(data['masked_labels'])
        self.mask_indices = set(data['mask_indices'])

    def _get_label_distribution(self, labels: Optional[np.ndarray] = None) -> Dict[str, int]:
        """Return {label_name: count} for the given labels.

        Defaults to the active label set (masked or original, depending on
        `use_masked_labels`) to stay backward compatible with no-arg callers.
        """
        if labels is None:
            labels = self.masked_labels if self.use_masked_labels else self.original_labels
        unique, counts = np.unique(labels, return_counts=True)
        # Cast numpy scalars to plain ints for clean printing/serialization.
        return {self.label_map[int(label)]: int(count) for label, count in zip(unique, counts)}

    def __len__(self) -> int:
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample: raw sequence string plus label tensors."""
        sequence = self.sequences[idx]

        if self.use_masked_labels:
            label = self.masked_labels[idx]
        else:
            label = self.original_labels[idx]

        is_masked = idx in self.mask_indices

        return {
            'sequence': sequence,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(self.original_labels[idx], dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_label_statistics(self) -> Dict[str, Dict]:
        """Get detailed statistics about labels.

        'original' always reflects the unmasked labels and 'masked' the
        CFG-masked labels (previously both reported the same distribution).
        """
        stats = {
            'original': self._get_label_distribution(self.original_labels),
            'masked': self._get_label_distribution(self.masked_labels) if self.use_masked_labels else None,
            'masking_info': {
                'total_masked': len(self.mask_indices),
                'mask_probability': self.mask_probability,
                'masked_indices': list(self.mask_indices)
            }
        }
        return stats
|
|
class CFGFlowDataset(Dataset):
    """
    Dataset pairing precomputed AMP embeddings with CFG labels for flow training.

    This dataset:
    1. Loads AMP embeddings from disk (a single combined tensor, or many
       individual per-peptide .pt files stacked together)
    2. Adds CFG labels parsed from a FASTA file (or loaded from JSON)
    3. Aligns embeddings with labels by index truncation
    4. Yields (embedding, label) samples in the format the flow training expects

    Label encoding: 0 = AMP, 1 = Non-AMP, 2 = mask (CFG unconditional class).

    NOTE(review): alignment is purely positional — it assumes the embedding
    order matches the CFG sequence order; verify against the preprocessing
    pipeline that produced both.
    """

    def __init__(self,
                 embeddings_path: str,
                 cfg_data_path: str,
                 use_masked_labels: bool = True,
                 max_seq_len: int = 50,
                 device: str = 'cuda'):
        # embeddings_path: directory expected to contain either
        #   'all_peptide_embeddings.pt' or individual *.pt files.
        # cfg_data_path: a .fasta/.fa file (auto-labeled) or a JSON file
        #   with 'sequences' and 'labels'.
        # device: stored for reference only; tensors are kept on CPU here.
        self.embeddings_path = embeddings_path
        self.cfg_data_path = cfg_data_path
        self.use_masked_labels = use_masked_labels
        self.max_seq_len = max_seq_len
        self.device = device

        # Load both data sources, then align them by index.
        self._load_embeddings()
        self._load_cfg_data()
        self._align_data()

        print(f"CFG Flow Dataset initialized:")
        print(f" AMP embeddings: {self.embeddings.shape}")
        print(f" CFG labels: {len(self.cfg_labels)}")
        print(f" Aligned samples: {len(self.aligned_indices)}")

    def _load_embeddings(self) -> None:
        """Load AMP embeddings: prefer the combined tensor, else stack individual files."""
        print(f"Loading AMP embeddings from {self.embeddings_path}...")

        # Fast path: a single pre-combined tensor of all embeddings.
        combined_path = os.path.join(self.embeddings_path, "all_peptide_embeddings.pt")

        if os.path.exists(combined_path):
            print(f"Loading combined embeddings from {combined_path} (FULL DATA)...")
            # Keep on CPU; the training loop handles device placement.
            self.embeddings = torch.load(combined_path, map_location='cpu')
            print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}")
        else:
            print("Combined embeddings file not found, loading individual files...")

            import glob

            embedding_files = glob.glob(os.path.join(self.embeddings_path, "*.pt"))
            # Exclude the combined file. The two .json suffix checks are
            # redundant for a '*.pt' glob but kept defensively.
            embedding_files = [f for f in embedding_files if not f.endswith('metadata.json') and not f.endswith('sequence_ids.json') and not f.endswith('all_peptide_embeddings.pt')]

            print(f"Found {len(embedding_files)} individual embedding files")

            # Collect only 2-D tensors; anything else is skipped with a warning.
            embeddings_list = []
            for file_path in embedding_files:
                try:
                    embedding = torch.load(file_path, map_location='cpu')
                    if embedding.dim() == 2:
                        embeddings_list.append(embedding)
                    else:
                        print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}")
                except Exception as e:
                    # Best-effort load: a corrupt file should not abort the run.
                    print(f"Warning: Could not load {file_path}: {e}")

            if not embeddings_list:
                raise ValueError("No valid embeddings found!")

            # NOTE(review): torch.stack requires every per-file tensor to share
            # the same 2-D shape — confirm upstream padding guarantees this.
            self.embeddings = torch.stack(embeddings_list)
            print(f"Loaded {len(self.embeddings)} embeddings from individual files")

    def _load_cfg_data(self) -> None:
        """Load CFG sequences/labels from a FASTA file (auto-labeled) or JSON."""
        print(f"Loading CFG data from FASTA: {self.cfg_data_path}...")

        if self.cfg_data_path.endswith('.fasta') or self.cfg_data_path.endswith('.fa'):
            # FASTA path: labels and CFG masking are produced by the parser.
            cfg_data = parse_fasta_with_amp_labels(self.cfg_data_path, self.max_seq_len)

            self.cfg_sequences = cfg_data['sequences']
            self.cfg_headers = cfg_data['headers']
            self.cfg_original_labels = np.array(cfg_data['labels'])
            self.cfg_masked_labels = np.array(cfg_data['masked_labels'])
            self.cfg_mask_indices = set(cfg_data['mask_indices'])
        else:
            # JSON path: labels are given; CFG masking is applied here.
            with open(self.cfg_data_path, 'r') as f:
                cfg_data = json.load(f)

            self.cfg_sequences = cfg_data['sequences']
            # Headers are optional in JSON input; fill with empty strings.
            self.cfg_headers = cfg_data.get('headers', [''] * len(cfg_data['sequences']))
            self.cfg_original_labels = np.array(cfg_data['labels'])

            # Randomly mask ~10% of labels to the CFG "unconditional" class (2).
            self.cfg_masked_labels = self.cfg_original_labels.copy()
            mask_probability = 0.1
            mask_indices = np.random.choice(
                len(self.cfg_original_labels),
                size=int(len(self.cfg_original_labels) * mask_probability),
                replace=False
            )
            self.cfg_masked_labels[mask_indices] = 2
            self.cfg_mask_indices = set(mask_indices)

        print(f"Loaded {len(self.cfg_sequences)} CFG sequences")
        print(f"Label distribution: {np.bincount(self.cfg_original_labels)}")
        print(f"Masked {len(self.cfg_mask_indices)} labels for CFG training")

    def _align_data(self) -> None:
        """Align AMP embeddings with CFG data based on sequence matching.

        Current strategy is simple truncation to the shorter of the two
        sources; indices pair up positionally. Masked indices beyond the
        truncation point are simply never served.
        """
        print("Aligning AMP embeddings with CFG data...")

        min_samples = min(len(self.embeddings), len(self.cfg_sequences))

        self.aligned_indices = list(range(min_samples))

        # Choose which label set __getitem__ will serve.
        if self.use_masked_labels:
            self.cfg_labels = self.cfg_masked_labels[:min_samples]
        else:
            self.cfg_labels = self.cfg_original_labels[:min_samples]

        self.aligned_embeddings = self.embeddings[:min_samples]

        print(f"Aligned {min_samples} samples")

    def __len__(self) -> int:
        return len(self.aligned_indices)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sample: embedding tensor plus CFG label tensors."""
        embedding = self.aligned_embeddings[idx]
        label = self.cfg_labels[idx]
        original_label = self.cfg_original_labels[idx]
        # Whether this index was CFG-masked upstream.
        is_masked = idx in self.cfg_mask_indices

        return {
            'embedding': embedding,
            'label': torch.tensor(label, dtype=torch.long),
            'original_label': torch.tensor(original_label, dtype=torch.long),
            'is_masked': torch.tensor(is_masked, dtype=torch.bool),
            'index': torch.tensor(idx, dtype=torch.long)
        }

    def get_embedding_stats(self) -> Dict:
        """Summary statistics (shape, mean, std, min, max) of the aligned embeddings."""
        return {
            'shape': self.aligned_embeddings.shape,
            'mean': self.aligned_embeddings.mean().item(),
            'std': self.aligned_embeddings.std().item(),
            'min': self.aligned_embeddings.min().item(),
            'max': self.aligned_embeddings.max().item()
        }
|
|
def create_cfg_dataloader(dataset: Dataset,
                          batch_size: int = 32,
                          shuffle: bool = True,
                          num_workers: int = 4) -> DataLoader:
    """Build a DataLoader that batches CFG samples into stacked tensors.

    Args:
        dataset: Dataset whose items are dicts with 'embedding', 'label',
            'original_label', 'is_masked', and 'index' tensors.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle every epoch.
        num_workers: Number of worker processes for loading.

    Returns:
        DataLoader yielding dicts of batched tensors under pluralized keys:
        'embeddings', 'labels', 'original_labels', 'is_masked', 'indices'.
    """

    # (per-sample key, batched key) pairs — each field is stacked along dim 0.
    field_map = (
        ('embedding', 'embeddings'),
        ('label', 'labels'),
        ('original_label', 'original_labels'),
        ('is_masked', 'is_masked'),
        ('index', 'indices'),
    )

    def collate_fn(samples):
        """Stack each field of the sample dicts into one batch tensor."""
        return {
            batched_key: torch.stack([sample[sample_key] for sample in samples])
            for sample_key, batched_key in field_map
        }

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=True,
    )
|
|
def test_cfg_dataset():
    """Smoke test: write a tiny JSON dataset to disk and exercise CFGUniProtDataset."""
    print("Testing CFG Dataset...")

    test_data = {
        'sequences': ['MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
                      'MKLLIVTFCLTFAAL',
                      'MKLLIVTFCLTFAALMKLLIVTFCLTFAAL'],
        'original_labels': [0, 1, 0],
        'masked_labels': [0, 2, 0],
        'mask_indices': [1]
    }

    test_path = 'test_cfg_data.json'
    with open(test_path, 'w') as f:
        json.dump(test_data, f)

    # Ensure the temp file is removed even if the dataset code raises,
    # so a failed run does not leave stale test files behind.
    try:
        dataset = CFGUniProtDataset(test_path, use_masked_labels=True)

        print(f"Dataset length: {len(dataset)}")
        for i in range(len(dataset)):
            sample = dataset[i]
            print(f"Sample {i}:")
            print(f" Sequence: {sample['sequence'][:20]}...")
            print(f" Label: {sample['label'].item()}")
            print(f" Original Label: {sample['original_label'].item()}")
            print(f" Is Masked: {sample['is_masked'].item()}")
    finally:
        os.remove(test_path)

    print("Test completed successfully!")
|
|
if __name__ == "__main__":
    # Run the smoke test when this module is executed as a script.
    test_cfg_dataset()