import os
from datetime import datetime

import numpy as np

from src.hp_search import generate_random_search
from src.commands import finetune, predict_from_csv
from src.eval import compute_roc_auc_from_csv

# NOTE: init_lr and final_lr look like divisors of max_lr rather than
# absolute learning rates (presumably resolved inside finetune()).
HYPERPARAM_GRID = {
    "batch_size": [32],
    "init_lr": [10],
    "max_lr": [0.001, 0.0005, 0.0001],
    "final_lr": [2, 3, 4, 5, 6, 7, 8, 9, 10],
    "dropout": [0.0, 0.05, 0.1, 0.2],
    "attn_hidden": [128],
    "attn_out": [4, 8],
    "dist_coff": [0.05, 0.1, 0.15],
    "bond_drop_rate": [0.0, 0.2, 0.4, 0.6],
    "ffn_num_layer": [2, 3],
    "ffn_hidden_size": [5, 7, 13],
}

hp_grid = generate_random_search(HYPERPARAM_GRID, num_trials=300, seed=42)
print("Total number of configs:", len(hp_grid))

# Paths shared across all trials
train_path = "tox21/tox21_train_clean.csv"
val_path = "tox21/tox21_validation_clean.csv"
train_features_path = train_path.replace(".csv", ".npz")
val_features_path = val_path.replace(".csv", ".npz")
checkpoint_path = "pretrained_models/grover_base.pt"

# Track the best model across trials
best_mean_auc = -1.0
best_config = None
best_model_path = None

# Create a timestamped directory for logs
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = f"hp_search/logs/{timestamp}"
os.makedirs(log_dir, exist_ok=True)
overall_log_path = f"{log_dir}/hp_search_results.txt"
best_log_path = f"{log_dir}/best_result.txt"

# Iterate over sampled configurations
for i, args in enumerate(hp_grid):
    save_dir = f"hp_search/trials/Trial_{i+1}"
    print("\n=========================================")
    print("Training with config:")
    print(args)
    print("Save dir:", save_dir)
    print("=========================================\n")

    # Fine-tune the pretrained model with this configuration
    finetune(train_path, val_path, train_features_path, val_features_path,
             save_dir, checkpoint_path, args)

    # Predict on the validation set
    finetuned_model_dir = f"{save_dir}/fold_0/model_0"
    preds_path = f"{save_dir}/predictions.csv"
    predict_from_csv(val_path, val_features_path, finetuned_model_dir, preds_path)

    # Evaluate: per-task ROC-AUC on the validation labels; missing labels
    # are presumably masked out via valid_mask
    labels_path = "tox21/tox21_validation.csv"
    valid_mask = np.load("./tox21/valid_mask_val.npy")
    auc_per_task, mean_auc = compute_roc_auc_from_csv(preds_path, labels_path, valid_mask)

    # Append this trial's results to the overall log
    with open(overall_log_path, "a") as f:
        f.write("\n===============================\n")
        f.write(f"Trial num: {i+1}\n")
        f.write(f"Mean AUC: {mean_auc}\n")
        f.write(f"Config: {args}\n")
        f.write(f"Save dir: {save_dir}\n")
        f.write(f"AUC per task: {auc_per_task}\n")

    # Update the running best model
    if mean_auc > best_mean_auc:
        print("New BEST model found!")
        best_mean_auc = mean_auc
        best_config = args
        best_model_path = save_dir
        with open(best_log_path, "w") as f:
            f.write("==== BEST MODEL SO FAR ====\n")
            f.write(f"Trial num: {i+1}\n")
            f.write(f"Mean AUC: {best_mean_auc}\n")
            f.write(f"Config: {best_config}\n")
            f.write(f"Saved at: {best_model_path}\n")

print("\n============================")
print("Hyperparameter search DONE!")
print("Trials run:", len(hp_grid))
print("Best mean AUC:", best_mean_auc)
print("Best model saved at:", best_model_path)
print("Best config:", best_config)
print("============================\n")
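
# ---------------------------------------------------------------------------
# Reference sketches -- assumptions, not the project code.
# The helpers imported from src/ are not defined in this file. The two
# functions below are minimal, hypothetical stand-ins that illustrate the
# behaviour this script relies on; their internals, the CSV layout, and the
# mask shape are assumptions, not the real implementations.

import random

import pandas as pd
from sklearn.metrics import roc_auc_score


def generate_random_search_sketch(grid, num_trials, seed=None):
    """Hypothetical stand-in for src.hp_search.generate_random_search:
    draw `num_trials` unique configurations uniformly at random from the
    grid, seeded for reproducibility."""
    rng = random.Random(seed)
    seen, configs = set(), []
    while len(configs) < num_trials:
        cfg = {key: rng.choice(values) for key, values in grid.items()}
        fingerprint = tuple(sorted(cfg.items()))
        if fingerprint not in seen:  # skip duplicate configurations
            seen.add(fingerprint)
            configs.append(cfg)
    return configs


def compute_roc_auc_from_csv_sketch(preds_path, labels_path, valid_mask):
    """Hypothetical stand-in for src.eval.compute_roc_auc_from_csv:
    score each Tox21 task with ROC-AUC, using `valid_mask` (assumed shape
    n_samples x n_tasks) to drop rows whose label is missing for that
    task. Assumes both CSVs share a `smiles` column followed by one
    column per task, in the same row order."""
    preds = pd.read_csv(preds_path)
    labels = pd.read_csv(labels_path)
    task_names = [c for c in preds.columns if c != "smiles"]
    auc_per_task = []
    for j, task in enumerate(task_names):
        mask = valid_mask[:, j].astype(bool)
        # roc_auc_score requires both classes among the valid rows
        auc_per_task.append(roc_auc_score(labels[task][mask], preds[task][mask]))
    return auc_per_task, sum(auc_per_task) / len(auc_per_task)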