import logging

import torch
from torch_geometric.data import Batch
from torch_geometric.utils import from_rdmol

from src.model import GIN
from src.preprocess import create_clean_mol_objects
from src.seed import set_seed

logger = logging.getLogger(__name__)

# Tox21 targets, in the order the model's output columns are mapped to them.
TARGET_NAMES = [
    "NR-AR",
    "NR-AR-LBD",
    "NR-AhR",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]

MODEL_PATH = "./assets/best_gin_model.pt"


def _load_model(device):
    """Build the GIN architecture, load the trained weights, and put it in eval mode on *device*."""
    # num_classes is derived from TARGET_NAMES so the output width and the
    # label mapping cannot drift apart (zip() would silently truncate).
    model = GIN(
        num_features=9,
        num_classes=len(TARGET_NAMES),
        dropout=0.1,
        hidden_dim=128,
        num_layers=5,
        add_or_mean="mean",
    )
    # NOTE(review): torch.load unpickles the file; only point MODEL_PATH at
    # trusted checkpoints (consider weights_only=True if the torch version allows).
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    print(f"Loaded model from {MODEL_PATH}")
    model.to(device)
    model.eval()
    return model


def _predict_single(model, smiles, device):
    """Return {target_name: probability} for one SMILES string; raises if it cannot be featurized."""
    mols, _ = create_clean_mol_objects([smiles])
    # Check explicitly instead of relying on an IndexError from mols[0].
    if len(mols) == 0:
        raise ValueError(f"SMILES could not be converted to a cleaned molecule: {smiles!r}")

    # Convert the molecule to a graph and wrap it as a batch of one.
    data = from_rdmol(mols[0]).to(device)
    batch = Batch.from_data_list([data])

    with torch.no_grad():
        logits = model(batch.x, batch.edge_index, batch.batch)
    probs = torch.sigmoid(logits).cpu().numpy().flatten()
    return {target: float(prob) for target, prob in zip(TARGET_NAMES, probs)}


def predict_from_smiles(smiles_list):
    """
    Predict toxicity targets for a list of SMILES strings.

    Args:
        smiles_list (list[str]): SMILES strings

    Returns:
        dict: {smiles: {target_name: prediction_prob}}. Duplicate SMILES share
        a single entry. If a SMILES cannot be processed, every target is 0.0
        for that entry (a warning with the traceback is logged). Note that this
        placeholder is indistinguishable from a genuine low-probability
        prediction.
    """
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Received {len(smiles_list)} SMILES strings")

    model = _load_model(device)

    predictions = {}
    for smiles in smiles_list:
        try:
            predictions[smiles] = _predict_single(model, smiles, device)
        except Exception:
            # Deliberate best-effort: one bad SMILES must not abort the whole
            # request. Record the cause instead of silently swallowing it.
            logger.warning(
                "Prediction failed for SMILES %r; returning zeros", smiles, exc_info=True
            )
            predictions[smiles] = dict.fromkeys(TARGET_NAMES, 0.0)

    return predictions