---
title: Tox21 GROVER Classifier
emoji: 🤖
colorFrom: green
colorTo: blue
sdk: docker
pinned: false
license: cc-by-nc-4.0
short_description: GROVER Classifier for Tox21
---

# Tox21 GROVER Classifier

This repository hosts a Hugging Face Space that provides an example API for submitting models to the [Tox21 Leaderboard](https://huggingface.co/spaces/ml-jku/tox21_leaderboard).

Here, the base version of [GROVER](https://arxiv.org/pdf/2007.02835) is finetuned on the Tox21 dataset, using the [code](https://github.com/tencent-ailab/grover) provided by the authors and the finetuning hyperparameters specified in the paper. The final model is provided for inference. The model input is a SMILES string of a small molecule, and the output is 12 numeric values, one for each of the toxic effects in the Tox21 dataset.


**Important:** For leaderboard submission, your Space needs to include training code. The file `train.py` should train the model using the config specified inside the `config/` folder and save the final model parameters into a file inside the `checkpoints/` folder. The model should be trained using the [Tox21_dataset](https://huggingface.co/datasets/ml-jku/tox21) provided on Hugging Face. The datasets can be loaded like this:
```python
from datasets import load_dataset

# `token` is a Hugging Face access token with read access to the dataset
ds = load_dataset("ml-jku/tox21", token=token)
train_df = ds["train"].to_pandas()
val_df = ds["validation"].to_pandas()
```
Additionally, the Space needs to implement inference in the `predict()` function inside `predict.py`. The `predict()` function must keep the provided skeleton: it takes a list of SMILES strings as input and returns a nested prediction dictionary, with SMILES strings as keys and dictionaries of target-name/prediction pairs as values. Consequently, any preprocessing of SMILES strings must be executed on-the-fly during inference.
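A minimal sketch of that skeleton (the target names and the `run_model` helper are placeholders for illustration, not the actual implementation in this Space):

```python
from typing import Dict, List

# Hypothetical target names; the real Space uses the 12 Tox21 task names.
TARGETS = [f"target{i}" for i in range(1, 13)]


def run_model(smiles: str) -> List[float]:
    # Placeholder: a real implementation would preprocess the SMILES
    # on-the-fly and run the finetuned GROVER checkpoint.
    return [0.0] * len(TARGETS)


def predict(smiles_list: List[str]) -> Dict[str, Dict[str, float]]:
    """Return {smiles: {target_name: prediction}} for each input SMILES."""
    results = {}
    for smiles in smiles_list:
        scores = run_model(smiles)  # 12 values, one per Tox21 target
        results[smiles] = dict(zip(TARGETS, scores))
    return results
```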

# Repository Structure
- `predict.py` - Defines the `predict()` function required by the leaderboard (entry point for inference).
- `app.py` - FastAPI application wrapper (can be used as-is).
- `main.py` - The provided GROVER main script.
- `evaluate.py` - Predicts outputs of a given model on a dataset and computes the AUC.
- `generate_features.py` - Generates the features used as model input, given a CSV containing SMILES.
- `hp_search.py` - Finetunes and evaluates 300 configs randomly drawn from the parameter grid specified in the paper.
- `prepare_data.py` - Cleans the SMILES in a given CSV and saves a mask so that uncleanable SMILES are accounted for during evaluation.
- `train.py` - Finetunes and saves a model using the config in the `config/` folder.

- `config/` - The config file used by `train.py`.
- `checkpoint/` - The saved model used by `predict.py`.
- `grover/` - The [GROVER](https://github.com/tencent-ailab/grover) repository, with slight changes to file structure and import paths.
- `predictions/` - [GROVER](https://github.com/tencent-ailab/grover) saves its prediction results as CSV files here.
- `pretrained/` - The pretrained GROVER models provided by the authors.
- `tox21/` - All masks, generated features, and cleaned data CSV files are saved here.

- `src/` - Core model & preprocessing logic:
    - `preprocess.py` - SMILES preprocessing pipeline and dataset creation
    - `commands.py` - GROVER commands
    - `eval.py` - Computes the evaluation metric
    - `hp_search.py` - Generates configs for the hyperparameter search

# Quickstart with Spaces

You can easily adapt this project to your own Hugging Face account:

- Open this Space on Hugging Face.

- Click "Duplicate this Space" (top-right corner).

- Create a `.env` according to `.example.env`.

- Modify `src/` for your preprocessing pipeline and model class.

- Modify `predict()` inside `predict.py` to perform model inference while keeping the function skeleton unchanged to remain compatible with the leaderboard.

- Modify `train.py` according to your model and preprocessing pipeline.

- Modify the file inside `config/` to contain all hyperparameters that are set in `train.py`.

That's it: your model will be available as an API endpoint for the Tox21 Leaderboard.
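The config format is specific to each Space; purely as an illustration, a hyperparameter file for a finetuning run might look like the fragment below (all keys are hypothetical, not the actual GROVER config):

```yaml
# Hypothetical example; the real keys are whatever train.py reads
learning_rate: 0.0001
batch_size: 32
epochs: 20
dropout: 0.1
checkpoint_path: checkpoint/model.pt
```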

# Installation
To run the GROVER classifier, clone the repository and install dependencies:

```bash
git clone https://huggingface.co/spaces/ml-jku/tox21_grover_classifier
cd tox21_grover_classifier
conda env create -f environment.yaml
```

# Training


To train the GROVER model from scratch, download the [Tox21](https://huggingface.co/datasets/ml-jku/tox21/tree/main) CSV files and place them in the `tox21/` folder.

Then run:

```bash
python prepare_data.py
python generate_features.py
python train.py
```

These commands will:
1. Load and preprocess the Tox21 training dataset
2. Generate and save the features used as GROVER inputs
3. Finetune the GROVER base model
4. Store the resulting model in the `finetune/` directory

# Inference

For inference, you only need `predict.py`.

Example usage inside Python:

```python
from predict import predict

smiles_list = ["CCO", "c1ccccc1", "CC(=O)O"]
results = predict(smiles_list)

print(results)
```

The output will be a nested dictionary in the format:

```python
{
    "CCO": {"target1": 0, "target2": 1, ..., "target12": 0},
    "c1ccccc1": {"target1": 1, "target2": 0, ..., "target12": 1},
    "CC(=O)O": {"target1": 0, "target2": 0, ..., "target12": 0}
}
```
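For downstream analysis, this nested dictionary can be flattened into a `pandas` DataFrame. A small usage sketch with two hypothetical targets:

```python
import pandas as pd

# Example predictions in the leaderboard output format
results = {
    "CCO": {"target1": 0, "target2": 1},
    "c1ccccc1": {"target1": 1, "target2": 0},
}

# Rows are SMILES strings, columns are the Tox21 targets
df = pd.DataFrame.from_dict(results, orient="index")
```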

# Notes

- Adapting `predict.py`, `train.py`, `config/`, and `checkpoints/` is required for leaderboard submission.

- Preprocessing (here inside `src/preprocess.py`) must be done inside `predict.py`, not just in `train.py`.
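To illustrate what on-the-fly preprocessing can look like, here is a sketch in which `canonicalize` stands in for the real cleaning step (e.g. RDKit canonicalization in `src/preprocess.py`); it is not the actual pipeline of this Space:

```python
from typing import List, Optional


def canonicalize(smiles: str) -> Optional[str]:
    # Stand-in for a real cleaner; here it only strips whitespace
    # and rejects empty strings.
    cleaned = smiles.strip()
    return cleaned or None


def preprocess_batch(smiles_list: List[str]) -> List[str]:
    # Called inside predict(), so raw SMILES from the leaderboard
    # are cleaned at inference time rather than ahead of time.
    return [s for s in (canonicalize(s) for s in smiles_list) if s is not None]
```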