# hiitsmeme
# initial commit
# b25d2b6
import json
import os

import numpy as np

from src.commands import finetune, predict_from_csv
from src.eval import compute_roc_auc_from_csv
def load_config(path="config.json"):
    """Read a JSON configuration file and return its parsed content.

    Args:
        path: Location of the JSON file. Defaults to ``config.json`` in the
            current working directory.

    Returns:
        Whatever ``json.load`` produces for the file (a dict for a JSON
        object).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(path, "r", encoding="utf-8") as config_file:
        return json.load(config_file)
# Load the run configuration from the default path and print it.
config = load_config()
print(config)

# Paths to custom split
train_path = "tox21/tox21_train_clean.csv"
val_path = "tox21/tox21_validation_clean.csv"

# Feature file paths are derived from the CSV paths by swapping the
# extension for .npz (splitext only touches the final extension).
train_features_path = os.path.splitext(train_path)[0] + ".npz"
val_features_path = os.path.splitext(val_path)[0] + ".npz"

checkpoint_path = "pretrained_models/grover_base.pt"

# Output directory for finetuned model
save_dir = "finetune/"

# NOTE(review): the original call passed `args`, a name that is never
# defined or imported in this script (NameError at runtime). The loaded
# `config` is the only candidate in scope and was otherwise unused, so it is
# passed here -- confirm that finetune() expects this object.
finetune(
    train_path,
    val_path,
    train_features_path,
    val_features_path,
    save_dir,
    checkpoint_path,
    config,
)
# predict on val set
# os.path.join is used because save_dir already ends with "/", so plain
# string concatenation produced paths such as "finetune//fold_0/model_0".
# NOTE(review): the fold_0/model_0 layout is presumably what finetune()
# writes under save_dir -- confirm.
finetuned_model_dir = os.path.join(save_dir, "fold_0", "model_0")
output_path = os.path.join(save_dir, "predictions.csv")
predict_from_csv(val_path, val_features_path, finetuned_model_dir, output_path)
# evaluate model
# Same file the prediction step writes; joined with os.path.join so the
# trailing "/" in save_dir does not yield a double slash.
preds_path = os.path.join(save_dir, "predictions.csv")

# NOTE(review): labels are read from tox21_validation.csv, whereas the
# predictions were made on tox21_validation_clean.csv -- confirm the rows
# line up (valid_mask may be what reconciles them).
labels_path = "tox21/tox21_validation.csv"
valid_mask = np.load("./tox21/valid_mask_val.npy")

auc_per_task, mean_auc = compute_roc_auc_from_csv(preds_path, labels_path, valid_mask)