# tox21_grover_classifier / hp_search.py
# hiitsmeme
# initial commit
# b25d2b6
import json
import os
import numpy as np
from datetime import datetime
from src.hp_search import generate_random_search
from src.commands import finetune, predict_from_csv
from src.eval import compute_roc_auc_from_csv
# Search space handed to generate_random_search: every key is a hyperparameter
# name and every value is the list of candidates for that hyperparameter.
# NOTE(review): init_lr and final_lr only contain values >= 2, so they are
# presumably ratios relative to max_lr rather than absolute learning rates;
# ffn_hidden_size (5/7/13) likewise looks like a scaled value. Confirm against
# generate_random_search / finetune.
HYPERPARAM_GRID = dict(
    batch_size=[32],
    init_lr=[10],
    max_lr=[0.001, 0.0005, 0.0001],
    final_lr=list(range(2, 11)),
    dropout=[0.0, 0.05, 0.1, 0.2],
    attn_hidden=[128],
    attn_out=[4, 8],
    dist_coff=[0.05, 0.1, 0.15],
    bond_drop_rate=[0.0, 0.2, 0.4, 0.6],
    ffn_num_layer=[2, 3],
    ffn_hidden_size=[5, 7, 13],
)
# Build the list of trial configurations from the search space (fixed seed).
hp_grid = generate_random_search(HYPERPARAM_GRID, seed=42, num_trials=300)
print(f"Total number of configs: {len(hp_grid)}")

# Shared inputs for every trial: dataset CSVs, their .npz feature files
# (same path, different extension) and the checkpoint passed to finetune.
train_path = "tox21/tox21_train_clean.csv"
val_path = "tox21/tox21_validation_clean.csv"
train_features_path = os.path.splitext(train_path)[0] + ".npz"
val_features_path = os.path.splitext(val_path)[0] + ".npz"
checkpoint_path = "pretrained_models/grover_base.pt"

# Running record of the best trial; -1 is the starting score to beat.
best_mean_auc = -1
best_config = best_model_path = None

# Each run logs into its own directory named after the run's start time.
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
log_dir = "hp_search/logs/" + timestamp
os.makedirs(log_dir, exist_ok=True)
overall_log_path = log_dir + "/hp_search_results.txt"
best_log_path = log_dir + "/best_result.txt"
# The validation labels path and mask are identical for every trial, so they
# are prepared once up front; this also fails fast if the mask file is missing
# instead of after the first (expensive) fine-tuning run.
# NOTE(review): labels come from the un-cleaned validation CSV while the
# predictions are made on the "_clean" one; valid_mask presumably reconciles
# the two — confirm.
labels_path = "tox21/tox21_validation.csv"
valid_mask = np.load("./tox21/valid_mask_val.npy")

# 1-based number of the best trial so far; stays None if no trial is run, so
# the final summary is well-defined even for an empty hp_grid.
best_trial_num = None

# Iterate over configs: fine-tune, predict on the validation set, score, log.
for trial_num, args in enumerate(hp_grid, start=1):
    save_dir = f"hp_search/trials/Trial_{trial_num}"

    print("\n=========================================")
    print("Training with config:")
    print(args)
    print("Save dir:", save_dir)
    print("=========================================\n")

    # Fine-tune a model for this configuration; outputs go under save_dir.
    finetune(train_path, val_path, train_features_path, val_features_path,
             save_dir, checkpoint_path, args)

    # Predict on the validation set with the model written by finetune.
    finetuned_model_dir = save_dir + "/fold_0/model_0"
    preds_path = save_dir + "/predictions.csv"
    predict_from_csv(val_path, val_features_path, finetuned_model_dir, preds_path)

    # Score the predictions against the validation labels.
    auc_per_task, mean_auc = compute_roc_auc_from_csv(preds_path, labels_path, valid_mask)

    # Append this trial to the overall log so partial results survive a crash.
    with open(overall_log_path, "a", encoding="utf-8") as f:
        f.write("\n===============================\n")
        f.write(f"Trial Num: {trial_num}\n")
        f.write(f"Mean AUC: {mean_auc}\n")
        f.write(f"Config: {args}\n")
        f.write(f"Save dir: {save_dir}\n")
        f.write(f"AUC per task: {auc_per_task}\n")

    # Keep track of the best trial and rewrite the best-result file on improvement.
    if mean_auc > best_mean_auc:
        print("New BEST model found!")
        best_mean_auc = mean_auc
        best_config = args
        best_model_path = save_dir
        best_trial_num = trial_num

        with open(best_log_path, "w", encoding="utf-8") as f:
            f.write("==== BEST MODEL SO FAR ====\n")
            f.write(f"Trial Num: {best_trial_num}\n")
            f.write(f"Mean AUC: {best_mean_auc}\n")
            f.write(f"Config: {best_config}\n")
            f.write(f"Saved at: {best_model_path}\n")

# Final summary. "Trial Num" reports the best trial (matching best_result.txt),
# not merely the last trial that happened to run.
print("\n============================")
print("Hyperparameter Search DONE!")
print("Trial Num: ", best_trial_num)
print("Best mean AUC:", best_mean_auc)
print("Best model saved at:", best_model_path)
print("Best config:", best_config)
print("============================\n")