Spaces:
Sleeping
Sleeping
hiitsmeme
committed on
Commit
·
f986893
1
Parent(s):
b25d2b6
added grover code, hf api files
Browse files- Dockerfile +16 -0
- generate_features.py +2 -1
- grover +0 -1
- grover/data/__init__.py +7 -0
- grover/data/dist_sampler.py +137 -0
- grover/data/groverdataset.py +247 -0
- grover/data/moldataset.py +245 -0
- grover/data/molfeaturegenerator.py +146 -0
- grover/data/molgraph.py +378 -0
- grover/data/scaler.py +70 -0
- grover/data/task_labels.py +116 -0
- grover/data/torchvocab.py +190 -0
- grover/model/layers.py +902 -0
- grover/model/models.py +506 -0
- grover/util/metrics.py +122 -0
- grover/util/multi_gpu_wrapper.py +110 -0
- grover/util/nn_utils.py +96 -0
- grover/util/parsing.py +487 -0
- grover/util/scheduler.py +97 -0
- grover/util/utils.py +797 -0
- prepare_data.py +2 -1
- requirements.txt +82 -0
- scripts/__init__.py +0 -0
- scripts/build_vocab.py +41 -0
- scripts/save_features.py +127 -0
- scripts/split_data.py +87 -0
- src/commands.py +10 -0
- task/__init__.py +0 -0
- task/cross_validate.py +69 -0
- task/fingerprint.py +79 -0
- task/grovertrainer.py +279 -0
- task/predict.py +316 -0
- task/pretrain.py +241 -0
- task/run_evaluation.py +157 -0
- task/train.py +454 -0
Dockerfile
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
|
| 2 |
+
# you will also find guides on how best to write your Dockerfile
|
| 3 |
+
|
| 4 |
+
FROM python:3.11.4
|
| 5 |
+
|
| 6 |
+
RUN useradd -m -u 1000 user
|
| 7 |
+
USER user
|
| 8 |
+
ENV PATH="/home/user/.local/bin:$PATH"
|
| 9 |
+
|
| 10 |
+
WORKDIR /app
|
| 11 |
+
|
| 12 |
+
COPY --chown=user ./requirements.txt requirements.txt
|
| 13 |
+
RUN pip install --no-cache-dir --upgrade -r requirements.txt
|
| 14 |
+
|
| 15 |
+
COPY --chown=user . /app
|
| 16 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
generate_features.py
CHANGED
|
@@ -6,4 +6,5 @@ from src.commands import generate_features
|
|
| 6 |
TRAIN_CSV = "./tox21/tox21_train_clean.csv"
|
| 7 |
VAL_CSV = "./tox21/tox21_validation_clean.csv"
|
| 8 |
|
| 9 |
-
generate_features(
|
|
|
|
|
|
| 6 |
TRAIN_CSV = "./tox21/tox21_train_clean.csv"
|
| 7 |
VAL_CSV = "./tox21/tox21_validation_clean.csv"
|
| 8 |
|
| 9 |
+
generate_features(TRAIN_CSV, TRAIN_CSV.replace('.csv', '.npz'))
|
| 10 |
+
generate_features(VAL_CSV, VAL_CSV.replace('.csv', '.npz'))
|
grover
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
Subproject commit 3f280d7d3419a781d303b1500c7039e37a1d87a2
|
|
|
|
|
|
grover/data/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from grover.data.molfeaturegenerator import get_available_features_generators, get_features_generator
|
| 2 |
+
from grover.data.molgraph import BatchMolGraph, get_atom_fdim, get_bond_fdim, mol2graph
|
| 3 |
+
from grover.data.molgraph import MolGraph, BatchMolGraph, MolCollator
|
| 4 |
+
from grover.data.moldataset import MoleculeDataset, MoleculeDatapoint
|
| 5 |
+
from grover.data.scaler import StandardScaler
|
| 6 |
+
|
| 7 |
+
# from .utils import load_features, save_features
|
grover/data/dist_sampler.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The re-implemented distributed sampler for the distributed training of GROVER.
|
| 3 |
+
"""
|
| 4 |
+
import math
|
| 5 |
+
import time
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data.sampler import Sampler
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DistributedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        shuffle (optional): Whether the indices selected for this rank are
            returned in a random order.
        sample_per_file (optional): When given, indices are grouped into
            consecutive chunks of this size and whole chunks are assigned
            to ranks (see :meth:`sub_indices_of_rank`).
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, sample_per_file=None):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.sample_per_file = sample_per_file
        self.shuffle = shuffle

    @staticmethod
    def _pad_cyclic(items, target_len):
        """Extend ``items`` in place, cycling from its start, until it has ``target_len`` entries.

        An empty list or a list that is already long enough is left untouched.
        """
        missing = target_len - len(items)
        if missing > 0 and items:
            # Repeat the list as often as needed so that even a list shorter
            # than the padding amount can be filled completely.
            repeats = int(math.ceil(missing * 1.0 / len(items)))
            items += (items * repeats)[:missing]
        return items

    def get_indices(self):
        """
        Build the list of dataset indices served by this rank.

        :return: a list of integer indices, shuffled when ``self.shuffle`` is set.
        """
        indices = list(range(len(self.dataset)))

        if self.sample_per_file is not None:
            indices = self.sub_indices_of_rank(indices)
        else:
            # add extra samples to make it evenly divisible
            self._pad_cyclic(indices, self.total_size)
            assert len(indices) == self.total_size
            # subsample: every rank takes one contiguous slice
            s = self.rank * self.num_samples
            e = min((self.rank + 1) * self.num_samples, len(indices))
            indices = indices[s:e]

        if self.shuffle:
            g = torch.Generator()
            # The seed mixes epoch, rank and the wall clock. manual_seed only
            # accepts integers, so the float product has to be truncated.
            g.manual_seed(int((self.epoch + 1) * (self.rank + 1) * time.time()))
            idx = torch.randperm(len(indices), generator=g).tolist()
            indices = [indices[i] for i in idx]

        # len(indices) == self.num_samples is not asserted here because
        # sub_indices_of_rank may hand out a different number of samples.
        return indices

    def sub_indices_of_rank(self, indices):
        """
        Select the indices of this rank by distributing whole "files" over the ranks.

        The index list is cut into consecutive chunks of ``self.sample_per_file``
        entries. The chunks are permuted with a seed that depends only on the
        epoch, so every rank derives the same permutation. ``self.num_samples``
        is updated to the number of indices returned.

        :param indices: the full list of dataset indices.
        :return: the indices belonging to the chunks assigned to this rank.
        """
        # fix generator for each epoch
        g = torch.Generator()
        # All data should be loaded in each epoch.
        g.manual_seed((self.epoch + 1) * 2 + 3)

        # the fake file indices to cache
        f_indices = list(range(int(math.ceil(len(indices) * 1.0 / self.sample_per_file))))
        idx = torch.randperm(len(f_indices), generator=g).tolist()
        f_indices = [f_indices[i] for i in idx]

        file_per_rank = int(math.ceil(len(f_indices) * 1.0 / self.num_replicas))
        # add extra fake file to make it evenly divisible
        self._pad_cyclic(f_indices, file_per_rank * self.num_replicas)

        # divide index by rank
        rank_s = self.rank * file_per_rank
        rank_e = min((self.rank + 1) * file_per_rank, len(f_indices))

        # get file index for this rank
        f_indices = f_indices[rank_s:rank_e]
        res_indices = []
        for fi in f_indices:
            # get real indices for this rank; the last chunk may be shorter
            si = fi * self.sample_per_file
            ei = min((fi + 1) * self.sample_per_file, len(indices))
            res_indices += indices[si:ei]

        self.num_samples = len(res_indices)
        return res_indices

    def __iter__(self):
        return iter(self.get_indices())

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        """
        Record the current epoch; it feeds the seeds used for permutations.

        :param epoch: the epoch number.
        """
        self.epoch = epoch
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
    # Manual smoke check: split a dummy dataset over two ranks in chunks of
    # 777 samples, print how many indices each rank receives and how many
    # distinct indices both ranks cover together.
    dataset = [1] * 190001
    res = []
    for current_rank in (0, 1):
        ds = DistributedSampler(dataset, num_replicas=2, rank=current_rank,
                                shuffle=True, sample_per_file=777)
        res.extend(ds.get_indices())
        print(len(ds.get_indices()))
    print(len(set(res)))
    print("hello")
|
grover/data/groverdataset.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The dataset used in training GROVER.
|
| 3 |
+
"""
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import csv
|
| 7 |
+
from typing import Union, List
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data.dataset import Dataset
|
| 11 |
+
from rdkit import Chem
|
| 12 |
+
|
| 13 |
+
import grover.util.utils as feautils
|
| 14 |
+
from grover.data import mol2graph
|
| 15 |
+
from grover.data.moldataset import MoleculeDatapoint
|
| 16 |
+
from grover.data.task_labels import atom_to_vocab, bond_to_vocab
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_data(data_path, logger=None):
    """
    Load data from the data_path.

    The directory is expected to contain ``summary.txt`` plus the
    sub-directories ``graph`` (``<i>.csv``) and ``feature`` (``<i>.npz``).
    The first three lines of the summary are read as the number of files,
    the number of samples and the samples per file; on each line the value
    is the integer after the last ``:``.

    :param data_path: the data_path.
    :param logger: the logger; ``print`` is used when it is None.
    :return: a tuple ``(BatchMolDataset, sample_per_file)``.
    """
    debug = logger.debug if logger is not None else print
    summary_path = os.path.join(data_path, "summary.txt")
    smiles_path = os.path.join(data_path, "graph")
    feature_path = os.path.join(data_path, "feature")

    # Read the summary inside a context manager so the handle is always closed.
    with open(summary_path) as fin:
        n_files = int(fin.readline().strip().split(":")[-1])
        n_samples = int(fin.readline().strip().split(":")[-1])
        sample_per_file = int(fin.readline().strip().split(":")[-1])
    debug("Loading data:")
    debug("Number of files: %d" % n_files)
    debug("Number of samples: %d" % n_samples)
    debug("Samples/file: %d" % sample_per_file)

    datapoints = []
    for i in range(n_files):
        smiles_path_i = os.path.join(smiles_path, str(i) + ".csv")
        feature_path_i = os.path.join(feature_path, str(i) + ".npz")
        if i != (n_files - 1):
            n_samples_i = sample_per_file
        else:
            # The last file holds whatever is left after the full files.
            # A plain modulo would report 0 when the total divides evenly.
            n_samples_i = n_samples - sample_per_file * (n_files - 1)
        datapoints.append(BatchDatapoint(smiles_path_i, feature_path_i, n_samples_i))
    return BatchMolDataset(datapoints), sample_per_file
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def split_data(data,
               split_type='random',
               sizes=(0.8, 0.1, 0.1),
               seed=0,
               logger=None):
    """
    Split data with given train/validation/test ratio.

    :param data: the dataset to split; it must provide ``shuffle(seed=...)``
        and a sliceable ``data`` attribute.
    :param split_type: the split strategy; only ``'random'`` is supported.
    :param sizes: the three ratios for train/validation/test, summing to 1.
    :param seed: the seed passed to ``data.shuffle``.
    :param logger: unused.
    :return: a tuple of three BatchMolDataset objects (train, val, test).
    :raises NotImplementedError: if ``split_type`` is not ``'random'``.
    """
    # Compare with a tolerance: ratios such as (0.7, 0.2, 0.1) do not sum to
    # exactly 1.0 in floating point arithmetic.
    assert len(sizes) == 3 and math.isclose(sum(sizes), 1.0)

    if split_type == "random":
        data.shuffle(seed=seed)
        data = data.data

        train_size = int(sizes[0] * len(data))
        train_val_size = int((sizes[0] + sizes[1]) * len(data))

        train = data[:train_size]
        val = data[train_size:train_val_size]
        test = data[train_val_size:]

        return BatchMolDataset(train), BatchMolDataset(val), BatchMolDataset(test)
    else:
        raise NotImplementedError("Do not support %s splits" % split_type)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class BatchDatapoint:
    """
    The molecules of one SMILES file and its feature file, loaded on demand.

    ``datapoints`` is None until :meth:`load_datapoints` has been called.
    """

    def __init__(self, smiles_file, feature_file, n_samples):
        """
        :param smiles_file: path of the CSV file with the molecules.
        :param feature_file: path of the file with the matching features.
        :param n_samples: the number of molecules expected in the file
            (the last file may hold fewer than the others).
        """
        self.smiles_file = smiles_file
        self.feature_file = feature_file
        self.n_samples = n_samples
        self.datapoints = None

    def load_datapoints(self):
        """Read both files and build one MoleculeDatapoint per CSV row after the header."""
        features = self.load_feature()
        self.datapoints = []

        with open(self.smiles_file) as handle:
            rows = csv.reader(handle)
            # the first row is the header
            next(rows)
            self.datapoints.extend(
                MoleculeDatapoint(line=row, features=features[row_no])
                for row_no, row in enumerate(rows)
            )

        assert len(self.datapoints) == self.n_samples

    def load_feature(self):
        """Return the features stored in ``self.feature_file``."""
        return feautils.load_features(self.feature_file)

    def shuffle(self):
        """Do nothing; present for interface compatibility."""
        pass

    def clean_cache(self):
        """Drop the loaded datapoints so that they can be garbage collected."""
        del self.datapoints
        self.datapoints = None

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        assert self.datapoints is not None
        return self.datapoints[idx]

    def is_loaded(self):
        """Return True when the datapoints are currently held in memory."""
        return self.datapoints is not None
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class BatchMolDataset(Dataset):
    """A dataset that addresses the molecules of several BatchDatapoint files with one flat index."""

    def __init__(self, data: "List[BatchDatapoint]",
                 graph_per_file=None):
        """
        :param data: the per-file containers.
        :param graph_per_file: the number of samples per file; when None it
            is taken from the length of the first container (None if ``data``
            is empty).
        """
        self.data = data

        # total number of samples over all files
        self.len = sum(len(chunk) for chunk in self.data)

        if graph_per_file is None:
            self.sample_per_file = None if len(self.data) == 0 else len(self.data[0])
        else:
            self.sample_per_file = graph_per_file

    def shuffle(self, seed: int = None):
        """Do nothing; present for interface compatibility."""
        pass

    def clean_cache(self):
        """Ask every container to drop its loaded datapoints."""
        for chunk in self.data:
            chunk.clean_cache()

    def __len__(self) -> int:
        return self.len

    def __getitem__(self, idx) -> "Union[MoleculeDatapoint, List[MoleculeDatapoint]]":
        # flat index -> (container, position inside the container)
        file_no = int(idx / self.sample_per_file)
        offset = idx % self.sample_per_file
        return self.data[file_no][offset]

    def load_data(self, idx):
        """
        Make sure the container holding the flat index ``idx`` is loaded.

        :param idx: a flat sample index.
        """
        target = self.data[int(idx / self.sample_per_file)]
        if target.is_loaded():
            return
        target.load_datapoints()

    def count_loaded_datapoints(self):
        """Return how many containers are currently loaded."""
        return sum(1 for chunk in self.data if chunk.is_loaded())
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
class GroverCollator(object):
    """Collate function object that turns a list of datapoints into a batch with its training targets."""

    def __init__(self, shared_dict, atom_vocab, bond_vocab, args):
        """
        :param shared_dict: passed through to ``mol2graph``.
        :param atom_vocab: vocabulary used to label atoms (``stoi``, ``other_index``).
        :param bond_vocab: vocabulary used to label bonds (``stoi``, ``other_index``).
        :param args: passed through to ``mol2graph``.
        """
        self.shared_dict = shared_dict
        self.atom_vocab = atom_vocab
        self.bond_vocab = bond_vocab
        self.args = args

    def atom_random_mask(self, smiles_batch):
        """
        Perform the random mask operation on atoms.

        In every molecule 15% of the atoms (rounded up) are picked at random
        and labelled with their vocabulary index; all other atoms get 0.

        :param smiles_batch: the SMILES strings of the batch.
        :return: The corresponding atom labels, preceded by one 0 for the padding entry.
        """
        mask_ratio = 0.15
        # There is a zero padding.
        labels = [0]
        for smiles in smiles_batch:
            mol = Chem.MolFromSmiles(smiles)
            n_atoms = mol.GetNumAtoms()
            mol_labels = [0] * n_atoms
            n_mask = math.ceil(n_atoms * mask_ratio)
            chosen = np.random.permutation(n_atoms)[:n_mask]
            for atom_idx in chosen:
                atom = mol.GetAtomWithIdx(int(atom_idx))
                token = atom_to_vocab(mol, atom)
                mol_labels[atom_idx] = self.atom_vocab.stoi.get(token, self.atom_vocab.other_index)

            labels.extend(mol_labels)
        return labels

    def bond_random_mask(self, smiles_batch):
        """
        Perform the random mask operation on bonds.

        Bonds are enumerated by walking over all atom pairs ``(a1, a2)`` with
        ``a1 < a2``. 15% of them (rounded up) are picked at random and
        labelled with their vocabulary index; all other bonds get 0.

        :param smiles_batch: the SMILES strings of the batch.
        :return: The corresponding bond labels, preceded by one 0 for the padding entry.
        """
        mask_ratio = 0.15
        # There is a zero padding.
        labels = [0]
        for smiles in smiles_batch:
            mol = Chem.MolFromSmiles(smiles)
            n_atoms = mol.GetNumAtoms()
            n_bonds = mol.GetNumBonds()
            n_mask = math.ceil(n_bonds * mask_ratio)
            chosen = set(np.random.permutation(n_bonds)[:n_mask].tolist())
            mol_labels = []
            bond_no = 0
            for a1 in range(n_atoms):
                for a2 in range(a1 + 1, n_atoms):
                    bond = mol.GetBondBetweenAtoms(a1, a2)
                    if bond is None:
                        continue
                    if bond_no in chosen:
                        token = bond_to_vocab(mol, bond)
                        mol_labels.append(self.bond_vocab.stoi.get(token, self.bond_vocab.other_index))
                    else:
                        mol_labels.append(0)
                    bond_no += 1
            # todo: might need to consider bond_drop_rate
            # todo: double check reverse bond
            labels.extend(mol_labels)
        return labels

    def __call__(self, batch):
        """
        Build the model input and the targets for one batch.

        :param batch: a list of datapoints providing ``smiles`` and ``features``.
        :return: a dict with the keys ``graph_input`` and ``targets``.
        """
        smiles_batch = [item.smiles for item in batch]
        graph_components = mol2graph(smiles_batch, self.shared_dict, self.args).get_components()

        # atom labels are drawn before bond labels
        atom_labels = torch.Tensor(self.atom_random_mask(smiles_batch)).long()
        bond_labels = torch.Tensor(self.bond_random_mask(smiles_batch)).long()
        fgroup_labels = torch.Tensor([item.features for item in batch]).float()
        # may be some mask here
        targets = {"av_task": atom_labels,
                   "bv_task": bond_labels,
                   "fg_task": fgroup_labels}
        return {"graph_input": batchgraph_or(graph_components),
                "targets": targets} if False else {"graph_input": graph_components,
                                                   "targets": targets}
|
grover/data/moldataset.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The molecule dataset for finetuning.
|
| 3 |
+
This implementation is adapted from
|
| 4 |
+
https://github.com/chemprop/chemprop/blob/master/chemprop/data/data.py
|
| 5 |
+
"""
|
| 6 |
+
import random
|
| 7 |
+
from argparse import Namespace
|
| 8 |
+
from typing import Callable, List, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
from rdkit import Chem
|
| 12 |
+
from torch.utils.data.dataset import Dataset
|
| 13 |
+
|
| 14 |
+
from grover.data.molfeaturegenerator import get_features_generator
|
| 15 |
+
from grover.data.scaler import StandardScaler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class MoleculeDatapoint:
    """One molecule (as SMILES) together with its optional features and its targets."""

    def __init__(self,
                 line: List[str],
                 args: Namespace = None,
                 features: np.ndarray = None,
                 use_compound_names: bool = False):
        """
        Builds the datapoint from one row of a data CSV file.

        :param line: The fields of the row: optionally the compound name, then the SMILES, then the targets.
        :param args: Arguments; ``features_generator`` (and ``num_bits`` for Morgan generators) are read if present.
        :param features: A numpy array containing additional features (ex. Morgan fingerprint).
        :param use_compound_names: Whether the first field of the row is the compound name.
        :raises ValueError: If both loaded features and a features generator are given.
        """
        self.args = args if args is not None else None
        if args is not None and hasattr(args, "features_generator"):
            self.features_generator = args.features_generator
        else:
            self.features_generator = None

        has_generator = self.features_generator is not None
        if has_generator and features is not None:
            raise ValueError('Currently cannot provide both loaded features and a features generator.')

        self.features = features

        # Strip the optional leading name so that the SMILES is always first.
        self.compound_name = None
        if use_compound_names:
            self.compound_name = line[0]  # str
            line = line[1:]

        self.smiles = line[0]  # str

        # Generate additional features if given a generator
        if has_generator:
            generated = []
            self.features = generated
            mol = Chem.MolFromSmiles(self.smiles)
            for name in self.features_generator:
                generator = get_features_generator(name)
                # unparsable or empty molecules contribute no features
                if mol is None or mol.GetNumHeavyAtoms() <= 0:
                    continue
                if name in ['morgan', 'morgan_count']:
                    generated.extend(generator(mol, num_bits=args.num_bits))
                else:
                    generated.extend(generator(mol))

            self.features = np.array(generated)

        # Fix nans in features
        if self.features is not None:
            self.features = np.where(np.isnan(self.features), 0, self.features)

        # Create targets; an empty field means that the target is missing
        self.targets = [None if field == '' else float(field) for field in line[1:]]

    def set_features(self, features: np.ndarray):
        """
        Replaces the features of the molecule.

        :param features: A 1-D numpy array of features for the molecule.
        """
        self.features = features

    def num_tasks(self) -> int:
        """
        Counts the prediction tasks, i.e. the targets of this molecule.

        :return: The number of tasks.
        """
        return len(self.targets)

    def set_targets(self, targets: List[float]):
        """
        Replaces the targets of the molecule.

        :param targets: A list of floats containing the targets.
        """
        self.targets = targets
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class MoleculeDataset(Dataset):
    """A MoleculeDataset contains a list of molecules and their associated features and targets."""

    def __init__(self, data: "List[MoleculeDatapoint]"):
        """
        Initializes a MoleculeDataset, which contains a list of MoleculeDatapoints (i.e. a list of molecules).

        :param data: A list of MoleculeDatapoints.
        """
        self.data = data
        # The arguments of the first datapoint are used for the whole dataset.
        self.args = self.data[0].args if len(self.data) > 0 else None
        self.scaler = None

    def compound_names(self) -> List[str]:
        """
        Returns the compound names associated with the molecule (if they exist).

        :return: A list of compound names or None if the dataset does not contain compound names.
        """
        if len(self.data) == 0 or self.data[0].compound_name is None:
            return None

        return [d.compound_name for d in self.data]

    def smiles(self) -> List[str]:
        """
        Returns the smiles strings associated with the molecules.

        :return: A list of smiles strings.
        """
        return [d.smiles for d in self.data]

    def features(self) -> List[np.ndarray]:
        """
        Returns the features associated with each molecule (if they exist).

        :return: A list of 1D numpy arrays containing the features for each molecule or None if there are no features.
        """
        if len(self.data) == 0 or self.data[0].features is None:
            return None

        return [d.features for d in self.data]

    def targets(self) -> List[List[float]]:
        """
        Returns the targets associated with each molecule.

        :return: A list of lists of floats containing the targets.
        """
        return [d.targets for d in self.data]

    def num_tasks(self) -> int:
        """
        Returns the number of prediction tasks.

        For a ``'multiclass'`` dataset this is the largest value found in the
        first target column plus one; otherwise it is the number of targets
        of the first molecule.

        :return: The number of tasks, or None for an empty non-multiclass dataset.
        """
        # self.args is None for an empty dataset and for datapoints created
        # without arguments; treat both as "not multiclass" instead of failing.
        if getattr(self.args, 'dataset_type', None) == 'multiclass':
            return int(max([i[0] for i in self.targets()])) + 1
        else:
            return self.data[0].num_tasks() if len(self.data) > 0 else None

    def features_size(self) -> int:
        """
        Returns the size of the features array associated with each molecule.

        :return: The size of the features, or None if there are no molecules or no features.
        """
        return len(self.data[0].features) if len(self.data) > 0 and self.data[0].features is not None else None

    def shuffle(self, seed: int = None):
        """
        Shuffles the dataset in place.

        :param seed: Optional random seed. If given, the global ``random`` module is re-seeded with it.
        """
        if seed is not None:
            random.seed(seed)
        random.shuffle(self.data)

    def normalize_features(self, scaler: "StandardScaler" = None, replace_nan_token: int = 0) -> "StandardScaler":
        """
        Normalizes the features of the dataset using a StandardScaler (subtract mean, divide by standard deviation).

        If a scaler is provided, uses that scaler to perform the normalization. Otherwise fits a scaler to the
        features in the dataset and then performs the normalization.

        :param scaler: A fitted StandardScaler. Used if provided. Otherwise a StandardScaler is fit on
        this dataset and is then used.
        :param replace_nan_token: What to replace nans with.
        :return: A fitted StandardScaler. If a scaler is provided, this is the same scaler. Otherwise, this is
        a scaler fit on this dataset. None if the dataset is empty or has no features.
        """
        if len(self.data) == 0 or self.data[0].features is None:
            return None

        if scaler is not None:
            self.scaler = scaler

        elif self.scaler is None:
            # No scaler given and none fitted earlier: fit one on this dataset.
            features = np.vstack([d.features for d in self.data])
            self.scaler = StandardScaler(replace_nan_token=replace_nan_token)
            self.scaler.fit(features)

        for d in self.data:
            # transform is applied to a single-row 2-D array; take the row back out
            d.set_features(self.scaler.transform(d.features.reshape(1, -1))[0])

        return self.scaler

    def set_targets(self, targets: List[List[float]]):
        """
        Sets the targets for each molecule in the dataset. Assumes the targets are aligned with the datapoints.

        :param targets: A list of lists of floats containing targets for each molecule. This must be the
        same length as the underlying dataset.
        """
        assert len(self.data) == len(targets)
        for d, molecule_targets in zip(self.data, targets):
            d.set_targets(molecule_targets)

    def sort(self, key: Callable):
        """
        Sorts the dataset in place using the provided key.

        :param key: A function on a MoleculeDatapoint to determine the sorting order.
        """
        self.data.sort(key=key)

    def __len__(self) -> int:
        """
        Returns the length of the dataset (i.e. the number of molecules).

        :return: The length of the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx) -> "Union[MoleculeDatapoint, List[MoleculeDatapoint]]":
        """
        Gets one or more MoleculeDatapoints via an index or slice.

        :param idx: An index (int) or a slice object.
        :return: A MoleculeDatapoint if an int is provided or a list of MoleculeDatapoints if a slice is provided.
        """
        return self.data[idx]
|
grover/data/molfeaturegenerator.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The registered feature generator for molecules.
|
| 3 |
+
This implementation is adapted from
|
| 4 |
+
https://github.com/chemprop/chemprop/blob/master/chemprop/features/features_generators.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Callable, List, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from rdkit import Chem, DataStructs
|
| 11 |
+
from rdkit.Chem import AllChem
|
| 12 |
+
|
| 13 |
+
# A molecule may be supplied either as a SMILES string or as an RDKit molecule object.
Molecule = Union[str, Chem.Mol]

# Signature shared by every features generator: one molecule in, one numpy array out.
FeaturesGenerator = Callable[[Molecule], np.ndarray]

# Maps a registered name to its features generator; populated by register_features_generator.
FEATURES_GENERATOR_REGISTRY = {}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def register_features_generator(features_generator_name: str) -> Callable[[FeaturesGenerator], FeaturesGenerator]:
    """
    Builds a decorator that files a features generator under the given name.

    :param features_generator_name: The name to call the FeaturesGenerator.
    :return: A decorator which stores the decorated function in FEATURES_GENERATOR_REGISTRY under
             that name and hands the function back unchanged.
    """
    def add_to_registry(generator: FeaturesGenerator) -> FeaturesGenerator:
        # A later registration under the same name replaces the earlier entry.
        FEATURES_GENERATOR_REGISTRY[features_generator_name] = generator
        return generator

    return add_to_registry
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_features_generator(features_generator_name: str) -> FeaturesGenerator:
    """
    Looks up a previously registered FeaturesGenerator.

    :param features_generator_name: The name of the FeaturesGenerator.
    :return: The desired FeaturesGenerator.
    :raises ValueError: If nothing has been registered under that name.
    """
    if features_generator_name in FEATURES_GENERATOR_REGISTRY:
        return FEATURES_GENERATOR_REGISTRY[features_generator_name]

    raise ValueError(f'Features generator "{features_generator_name}" could not be found. '
                     f'If this generator relies on rdkit features, you may need to install descriptastorus.')
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_available_features_generators() -> List[str]:
    """Lists the names under which features generators are currently registered."""
    registered_names = list(FEATURES_GENERATOR_REGISTRY)
    return registered_names
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Default settings for the Morgan fingerprint generators in this module.
MORGAN_RADIUS = 2  # default fingerprint radius
MORGAN_NUM_BITS = 2048  # default fingerprint length
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@register_features_generator('morgan')
def morgan_binary_features_generator(mol: Molecule,
                                     radius: int = MORGAN_RADIUS,
                                     num_bits: int = MORGAN_NUM_BITS) -> np.ndarray:
    """
    Generates a binary Morgan fingerprint for a molecule.

    :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule).
    :param radius: Morgan fingerprint radius.
    :param num_bits: Number of bits in Morgan fingerprint.
    :return: A 1-D numpy array containing the binary Morgan fingerprint.
    :raises ValueError: If ``mol`` is a SMILES string that RDKit cannot parse.
    """
    # isinstance (rather than an exact type comparison) so that str subclasses are also
    # treated as SMILES instead of being passed to RDKit as if they were molecules.
    if isinstance(mol, str):
        smiles = mol
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None.
            raise ValueError(f'Could not parse SMILES string "{smiles}".')

    features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=num_bits)
    # One-element placeholder: this relies on ConvertToNumpyArray resizing the destination
    # array to the fingerprint length.
    features = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(features_vec, features)

    return features
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@register_features_generator('morgan_count')
def morgan_counts_features_generator(mol: Molecule,
                                     radius: int = MORGAN_RADIUS,
                                     num_bits: int = MORGAN_NUM_BITS) -> np.ndarray:
    """
    Generates a counts-based Morgan fingerprint for a molecule.

    :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule).
    :param radius: Morgan fingerprint radius.
    :param num_bits: Number of bits in Morgan fingerprint.
    :return: A 1D numpy array containing the counts-based Morgan fingerprint.
    :raises ValueError: If ``mol`` is a SMILES string that RDKit cannot parse.
    """
    # isinstance (rather than an exact type comparison) so that str subclasses are also
    # treated as SMILES instead of being passed to RDKit as if they were molecules.
    if isinstance(mol, str):
        smiles = mol
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None.
            raise ValueError(f'Could not parse SMILES string "{smiles}".')

    features_vec = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=num_bits)
    # One-element placeholder: this relies on ConvertToNumpyArray resizing the destination
    # array to the fingerprint length.
    features = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(features_vec, features)

    return features
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
try:
    from descriptastorus.descriptors import rdDescriptors, rdNormalizedDescriptors

    @register_features_generator('rdkit_2d')
    def rdkit_2d_features_generator(mol: Molecule) -> np.ndarray:
        """
        Generates RDKit 2D features for a molecule.

        :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule).
        :return: The RDKit 2D features, i.e. everything after the first entry of what the
                 descriptastorus generator's ``process`` returns.
        """
        # isinstance so that str subclasses are recognised as SMILES too.
        smiles = mol if isinstance(mol, str) else Chem.MolToSmiles(mol, isomericSmiles=True)
        generator = rdDescriptors.RDKit2D()
        # NOTE(review): the first entry returned by process() is discarded; presumably it is a
        # status flag rather than a descriptor value -- confirm against descriptastorus.
        features = generator.process(smiles)[1:]

        return features

    @register_features_generator('rdkit_2d_normalized')
    def rdkit_2d_features_normalized_generator(mol: Molecule) -> np.ndarray:
        """
        Generates RDKit 2D normalized features for a molecule.

        :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule).
        :return: The RDKit 2D normalized features, i.e. everything after the first entry of what
                 the descriptastorus generator's ``process`` returns.
        """
        # isinstance so that str subclasses are recognised as SMILES too.
        smiles = mol if isinstance(mol, str) else Chem.MolToSmiles(mol, isomericSmiles=True)
        generator = rdNormalizedDescriptors.RDKit2DNormalized()
        # NOTE(review): see rdkit_2d_features_generator regarding the dropped first entry.
        features = generator.process(smiles)[1:]
        return features
except ImportError:
    # descriptastorus is an optional dependency. Without it the two 'rdkit_2d*' generators are
    # simply never registered, and get_features_generator reports the missing package when one
    # of them is requested.
    pass
|
| 126 |
+
|
| 127 |
+
"""
|
| 128 |
+
Custom features generator template.
|
| 129 |
+
|
| 130 |
+
Note: The name you use to register the features generator is the name
|
| 131 |
+
you will specify on the command line when using the --features_generator <name> flag.
|
| 132 |
+
Ex. python train.py ... --features_generator custom ...
|
| 133 |
+
|
| 134 |
+
@register_features_generator('custom')
|
| 135 |
+
def custom_features_generator(mol: Molecule) -> np.ndarray:
|
| 136 |
+
# If you want to use the SMILES string
|
| 137 |
+
smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol
|
| 138 |
+
|
| 139 |
+
# If you want to use the RDKit molecule
|
| 140 |
+
mol = Chem.MolFromSmiles(mol) if type(mol) == str else mol
|
| 141 |
+
|
| 142 |
+
# Replace this with code which generates features from the molecule
|
| 143 |
+
features = np.array([0, 0, 1])
|
| 144 |
+
|
| 145 |
+
return features
|
| 146 |
+
"""
|
grover/data/molgraph.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The data structure of Molecules.
|
| 3 |
+
This implementation is adapted from
|
| 4 |
+
https://github.com/chemprop/chemprop/blob/master/chemprop/features/featurization.py
|
| 5 |
+
"""
|
| 6 |
+
from argparse import Namespace
|
| 7 |
+
from typing import List, Tuple, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from rdkit import Chem
|
| 12 |
+
|
| 13 |
+
# Atom feature sizes
MAX_ATOMIC_NUM = 100

# Allowed values of each categorical atom feature. Every feature is one-hot encoded by
# onek_encoding_unk, which adds one extra slot on top of the values listed here.
ATOM_FEATURES = {
    'atomic_num': list(range(MAX_ATOMIC_NUM)),
    'degree': list(range(6)),
    # The order of these values matters: because a negative value is present,
    # onek_encoding_unk uses the charge itself as the position in the encoding.
    'formal_charge': [-1, -2, 1, 2, 0],
    'chiral_tag': list(range(4)),
    'num_Hs': list(range(5)),
    'hybridization': [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
    ],
}

# One slot per listed choice, one extra slot per feature for uncommon values,
# and 2 more entries at the end for IsAromatic and mass.
ATOM_FDIM = sum(map(len, ATOM_FEATURES.values())) + len(ATOM_FEATURES) + 2
# 7 leading bond entries plus a 7-slot stereo encoding (see MolGraph.bond_features).
BOND_FDIM = 14
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_atom_fdim() -> int:
    """
    Gets the dimensionality of atom features.

    :return: ATOM_FDIM plus the 18 entries that MolGraph.atom_features appends after the base
             features: an 8-slot implicit-valence encoding, 4 substructure-match flags and
             6 ring-size flags.
    """
    extra_atom_fdim = 18
    return ATOM_FDIM + extra_atom_fdim
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_bond_fdim() -> int:
    """
    Gets the dimensionality of bond features.

    :return: The module-level BOND_FDIM constant (bond-only features, without the atom part).
    """
    bond_fdim = BOND_FDIM
    return bond_fdim
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def onek_encoding_unk(value: int, choices: List[int]) -> List[int]:
    """
    Creates a one-hot encoding with one extra slot.

    Two modes are used, depending on ``choices``:

    - If no choice is negative, the slot at ``choices.index(value)`` is set, or the final slot
      when ``value`` is not among the choices.
    - If any choice is negative, ``value`` itself is used as the position in the encoding
      (negative values count from the end, as with normal list indexing). A value that is not a
      valid position sets the final slot instead of raising ``IndexError``.

    NOTE(review): the second mode does not look up ``value`` in ``choices`` at all, which looks
    unintentional. It is kept as is because changing it would alter the feature vectors
    produced for inputs that are handled today.

    :param value: The value for which the encoding should be one.
    :param choices: A list of possible values.
    :return: A list of length len(choices) + 1 containing a single 1.
    """
    encoding = [0] * (len(choices) + 1)

    # ``choices and`` keeps min() from failing on an empty list of choices.
    if choices and min(choices) < 0:
        in_range = -len(encoding) <= value < len(encoding)
        index = value if in_range else -1
    else:
        try:
            index = choices.index(value)
        except ValueError:
            index = -1

    encoding[index] = 1

    return encoding
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class MolGraph:
    """
    A MolGraph represents the graph structure and featurization of a single molecule.

    A MolGraph computes the following attributes:
    - smiles: Smiles string.
    - n_atoms: The number of atoms in the molecule.
    - n_bonds: The number of directed bonds (two for every chemical bond that was kept).
    - f_atoms: A mapping from an atom index to a list atom features.
    - f_bonds: A mapping from a bond index to a list of combined (source atom + bond) features.
    - a2b: A mapping from an atom index to a list of incoming bond indices.
    - b2a: A mapping from a bond index to the index of the atom the bond originates from.
    - b2revb: A mapping from a bond index to the index of the reverse bond.
    """

    # Substructure queries used by atom_features. They do not depend on the molecule, so they
    # are compiled once here instead of once per MolGraph.
    _HYDROGEN_DONOR = Chem.MolFromSmarts("[$([N;!H0;v3,v4&+1]),$([O,S;H1;+0]),n&H1&+0]")
    _HYDROGEN_ACCEPTOR = Chem.MolFromSmarts(
        "[$([O,S;H1;v2;!$(*-*=[O,N,P,S])]),$([O,S;H0;v2]),$([O,S;-]),$([N;v3;!$(N-*=[O,N,P,S])]),"
        "n&H0&+0,$([o,s;+0;!$([o,s]:n);!$([o,s]:c:n)])]")
    _ACIDIC = Chem.MolFromSmarts("[$([C,S](=[O,S,P])-[O;H1,-1])]")
    _BASIC = Chem.MolFromSmarts(
        "[#7;+,$([N;H2&+0][$([C,a]);!$([C,a](=O))]),$([N;H1&+0]([$([C,a]);!$([C,a](=O))])[$([C,a]);"
        "!$([C,a](=O))]),$([N;H0&+0]([C;!$(C(=O))])([C;!$(C(=O))])[C;!$(C(=O))])]")

    def __init__(self, smiles: str, args: Namespace):
        """
        Computes the graph structure and featurization of a molecule.

        :param smiles: A smiles string.
        :param args: Arguments. ``args.bond_drop_rate`` is read for each bond; when it is
                     positive, each bond is dropped at random with that probability.
        :raises ValueError: If RDKit cannot parse ``smiles``.
        """
        self.smiles = smiles
        self.args = args
        self.n_atoms = 0  # number of atoms
        self.n_bonds = 0  # number of bonds
        self.f_atoms = []  # mapping from atom index to atom features
        self.f_bonds = []  # mapping from bond index to concat(in_atom, bond) features
        self.a2b = []  # mapping from atom index to incoming bond indices
        self.b2a = []  # mapping from bond index to the index of the atom the bond is coming from
        self.b2revb = []  # mapping from bond index to the index of the reverse bond

        # Convert smiles to molecule
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # MolFromSmiles signals a parse failure by returning None; fail with a clear
            # message instead of an AttributeError further down.
            raise ValueError(f'Could not parse SMILES string "{smiles}".')

        self.hydrogen_donor = self._HYDROGEN_DONOR
        self.hydrogen_acceptor = self._HYDROGEN_ACCEPTOR
        self.acidic = self._ACIDIC
        self.basic = self._BASIC

        # Each *_match is a flat tuple with the atom indices of all matches concatenated.
        self.hydrogen_donor_match = sum(mol.GetSubstructMatches(self.hydrogen_donor), ())
        self.hydrogen_acceptor_match = sum(mol.GetSubstructMatches(self.hydrogen_acceptor), ())
        self.acidic_match = sum(mol.GetSubstructMatches(self.acidic), ())
        self.basic_match = sum(mol.GetSubstructMatches(self.basic), ())
        self.ring_info = mol.GetRingInfo()

        self.n_atoms = mol.GetNumAtoms()

        # Get atom features
        self.f_atoms = [self.atom_features(atom) for atom in mol.GetAtoms()]
        self.a2b = [[] for _ in range(self.n_atoms)]

        # Collect the bonds as (lower atom index, higher atom index, bond) and order them by
        # those indices. This visits the bonds in the same order as scanning every atom pair
        # (a1, a2) with a1 < a2, without querying pairs that are not bonded.
        bonds = []
        for bond in mol.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            bonds.append((min(begin, end), max(begin, end), bond))
        bonds.sort(key=lambda entry: entry[:2])

        # Get bond features
        for a1, a2, bond in bonds:
            if args.bond_drop_rate > 0:
                if np.random.binomial(1, args.bond_drop_rate):
                    continue

            f_bond = self.bond_features(bond)

            # Always treat the bond as directed.
            self.f_bonds.append(self.f_atoms[a1] + f_bond)
            self.f_bonds.append(self.f_atoms[a2] + f_bond)

            # Update index mappings
            b1 = self.n_bonds
            b2 = b1 + 1
            self.a2b[a2].append(b1)  # b1 = a1 --> a2
            self.b2a.append(a1)
            self.a2b[a1].append(b2)  # b2 = a2 --> a1
            self.b2a.append(a2)
            self.b2revb.append(b2)
            self.b2revb.append(b1)
            self.n_bonds += 2

    def atom_features(self, atom: Chem.rdchem.Atom) -> List[Union[bool, int, float]]:
        """
        Builds a feature vector for an atom.

        The vector consists of the one-hot encodings listed in ATOM_FEATURES, an aromaticity
        flag and the scaled mass, followed by an implicit-valence encoding, four
        substructure-match flags (acceptor, donor, acidic, basic) and six ring-size flags
        (ring sizes 3 to 8).

        :param atom: An RDKit atom.
        :return: A list containing the atom features.
        """
        features = onek_encoding_unk(atom.GetAtomicNum() - 1, ATOM_FEATURES['atomic_num']) + \
            onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree']) + \
            onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge']) + \
            onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']) + \
            onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_Hs']) + \
            onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization']) + \
            [1 if atom.GetIsAromatic() else 0] + \
            [atom.GetMass() * 0.01]
        atom_idx = atom.GetIdx()
        features = features + \
            onek_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) + \
            [atom_idx in self.hydrogen_acceptor_match] + \
            [atom_idx in self.hydrogen_donor_match] + \
            [atom_idx in self.acidic_match] + \
            [atom_idx in self.basic_match] + \
            [self.ring_info.IsAtomInRingOfSize(atom_idx, 3),
             self.ring_info.IsAtomInRingOfSize(atom_idx, 4),
             self.ring_info.IsAtomInRingOfSize(atom_idx, 5),
             self.ring_info.IsAtomInRingOfSize(atom_idx, 6),
             self.ring_info.IsAtomInRingOfSize(atom_idx, 7),
             self.ring_info.IsAtomInRingOfSize(atom_idx, 8)]
        return features

    def bond_features(self, bond: Chem.rdchem.Bond
                      ) -> List[Union[bool, int, float]]:
        """
        Builds a feature vector for a bond.

        :param bond: A RDKit bond, or None for a missing bond.
        :return: A list containing the bond features. For ``None`` only the first entry is set.
        """

        if bond is None:
            fbond = [1] + [0] * (BOND_FDIM - 1)
        else:
            bt = bond.GetBondType()
            fbond = [
                0,  # bond is not None
                bt == Chem.rdchem.BondType.SINGLE,
                bt == Chem.rdchem.BondType.DOUBLE,
                bt == Chem.rdchem.BondType.TRIPLE,
                bt == Chem.rdchem.BondType.AROMATIC,
                (bond.GetIsConjugated() if bt is not None else 0),
                (bond.IsInRing() if bt is not None else 0)
            ]
            fbond += onek_encoding_unk(int(bond.GetStereo()), list(range(6)))
        return fbond
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class BatchMolGraph:
    """
    A BatchMolGraph represents the graph structure and featurization of a batch of molecules.

    A BatchMolGraph contains the attributes of a MolGraph plus:
    - smiles_batch: A list of smiles strings.
    - n_mols: The number of molecules in the batch.
    - atom_fdim: The dimensionality of the atom features.
    - bond_fdim: The dimensionality of the bond features (technically the combined atom/bond features).
    - a_scope: A tensor with one (start_atom_index, num_atoms) row for each molecule.
    - b_scope: A tensor with one (start_bond_index, num_bonds) row for each molecule.
    - max_num_bonds: The maximum number of bonds neighboring an atom in this batch.
    - b2b: (Optional) A mapping from a bond index to incoming bond indices.
    - a2a: A mapping from an atom index to neighboring atom indices.

    Index 0 of the atom and bond arrays is a padding entry, so real atoms and bonds start at 1.
    """

    def __init__(self, mol_graphs: List["MolGraph"], args: Namespace):
        """
        Merges the given molecular graphs into one set of batch-level tensors.

        :param mol_graphs: The MolGraphs to combine, in batch order.
        :param args: Arguments. Not read by this constructor.
        """
        self.smiles_batch = [mol_graph.smiles for mol_graph in mol_graphs]
        self.n_mols = len(self.smiles_batch)

        self.atom_fdim = get_atom_fdim()
        self.bond_fdim = get_bond_fdim() + self.atom_fdim

        # Start n_atoms and n_bonds at 1 b/c zero padding
        self.n_atoms = 1  # number of atoms (start at 1 b/c need index 0 as padding)
        self.n_bonds = 1  # number of bonds (start at 1 b/c need index 0 as padding)
        self.a_scope = []  # list of tuples indicating (start_atom_index, num_atoms) for each molecule
        self.b_scope = []  # list of tuples indicating (start_bond_index, num_bonds) for each molecule

        # All start with zero padding so that indexing with zero padding returns zeros
        f_atoms = [[0] * self.atom_fdim]  # atom features
        f_bonds = [[0] * self.bond_fdim]  # combined atom/bond features
        a2b = [[]]  # mapping from atom index to incoming bond indices
        b2a = [0]  # mapping from bond index to the index of the atom the bond is coming from
        b2revb = [0]  # mapping from bond index to the index of the reverse bond

        for mol_graph in mol_graphs:
            f_atoms.extend(mol_graph.f_atoms)
            f_bonds.extend(mol_graph.f_bonds)

            # Shift the molecule-local indices by the number of atoms/bonds already in the batch.
            for a in range(mol_graph.n_atoms):
                a2b.append([b + self.n_bonds for b in mol_graph.a2b[a]])

            for b in range(mol_graph.n_bonds):
                b2a.append(self.n_atoms + mol_graph.b2a[b])
                b2revb.append(self.n_bonds + mol_graph.b2revb[b])

            self.a_scope.append((self.n_atoms, mol_graph.n_atoms))
            self.b_scope.append((self.n_bonds, mol_graph.n_bonds))
            self.n_atoms += mol_graph.n_atoms
            self.n_bonds += mol_graph.n_bonds

        # max with 1 to fix a crash in rare case of all single-heavy-atom mols
        self.max_num_bonds = max(1, max(len(in_bonds) for in_bonds in a2b))

        self.f_atoms = torch.FloatTensor(f_atoms)
        self.f_bonds = torch.FloatTensor(f_bonds)
        # Pad every atom's incoming-bond list with the padding index 0 to a common width.
        self.a2b = torch.LongTensor(
            [in_bonds + [0] * (self.max_num_bonds - len(in_bonds)) for in_bonds in a2b])
        self.b2a = torch.LongTensor(b2a)
        self.b2revb = torch.LongTensor(b2revb)
        self.b2b = None  # try to avoid computing b2b b/c O(n_atoms^3)
        self.a2a = self.b2a[self.a2b]  # only needed if using atom messages
        self.a_scope = torch.LongTensor(self.a_scope)
        self.b_scope = torch.LongTensor(self.b_scope)

    def set_new_atom_feature(self, f_atoms):
        """
        Set the new atom feature. Do not update bond feature.

        :param f_atoms: The atom features that replace the current ``f_atoms``.
        """
        self.f_atoms = f_atoms

    def get_components(self) -> Tuple[torch.FloatTensor, torch.FloatTensor,
                                      torch.LongTensor, torch.LongTensor, torch.LongTensor,
                                      torch.LongTensor, torch.LongTensor, torch.LongTensor]:
        """
        Returns the components of the BatchMolGraph.

        :return: The 8-tuple (f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a): the atom
                 features, bond features and graph structure, the per-molecule atom and bond scopes
                 (i.e. which molecules the atoms and bonds belong to), and the atom-to-neighbor mapping.
        """
        return self.f_atoms, self.f_bonds, self.a2b, self.b2a, self.b2revb, self.a_scope, self.b_scope, self.a2a

    def get_b2b(self) -> torch.LongTensor:
        """
        Computes (if necessary) and returns a mapping from each bond index to all the incoming bond indices.

        The result is cached in ``self.b2b`` after the first call.

        :return: A PyTorch tensor containing the mapping from each bond index to all the incoming bond indices.
        """

        if self.b2b is None:
            b2b = self.a2b[self.b2a]  # num_bonds x max_num_bonds
            # b2b includes reverse edge for each bond so need to mask out.
            # The (num_bonds, 1) column broadcasts against every column of b2b.
            revmask = (b2b != self.b2revb.unsqueeze(1)).long()  # num_bonds x max_num_bonds
            self.b2b = b2b * revmask

        return self.b2b

    def get_a2a(self) -> torch.LongTensor:
        """
        Computes (if necessary) and returns a mapping from each atom index to all neighboring atom indices.

        :return: A PyTorch tensor containing the mapping from each atom index to its neighboring atom indices.
        """
        if self.a2a is None:
            # b = a1 --> a2
            # a2b maps a2 to all incoming bonds b
            # b2a maps each bond b to the atom it comes from a1
            # thus b2a[a2b] maps atom a2 to neighboring atoms a1
            self.a2a = self.b2a[self.a2b]  # num_atoms x max_num_bonds

        return self.a2a
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def mol2graph(smiles_batch: List[str], shared_dict,
              args: Namespace) -> BatchMolGraph:
    """
    Converts a list of SMILES strings to a BatchMolGraph containing the batch of molecular graphs.

    :param smiles_batch: A list of SMILES strings.
    :param shared_dict: A mapping from SMILES string to MolGraph used as a cache. It is consulted
                        for every molecule and, unless ``args.no_cache`` is set, updated with
                        every MolGraph built here.
    :param args: Arguments.
    :return: A BatchMolGraph containing the combined molecular graph for the molecules
    """
    def graph_for(smiles):
        # Reuse a cached graph when one exists; otherwise build it and cache it if allowed.
        if smiles in shared_dict:
            return shared_dict[smiles]
        built = MolGraph(smiles, args)
        if not args.no_cache:
            shared_dict[smiles] = built
        return built

    return BatchMolGraph([graph_for(smiles) for smiles in smiles_batch], args)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class MolCollator(object):
    """
    Collator for pytorch dataloader

    :param shared_dict: a shared dict of multiprocess.
    :param args: Arguments.
    """
    def __init__(self, shared_dict, args):
        self.shared_dict = shared_dict
        self.args = args

    def __call__(self, batch):
        """
        Turns a list of datapoints into the pieces consumed downstream.

        :param batch: Datapoints exposing ``smiles``, ``features`` and ``targets``.
        :return: The tuple (smiles, graph components, features, mask, targets), where the mask
                 marks which targets are not None and missing targets are replaced by 0.
        """
        smiles_batch = [datapoint.smiles for datapoint in batch]
        features_batch = [datapoint.features for datapoint in batch]
        target_batch = [datapoint.targets for datapoint in batch]

        components = mol2graph(smiles_batch, self.shared_dict, self.args).get_components()

        mask_rows = []
        target_rows = []
        for molecule_targets in target_batch:
            mask_rows.append([target is not None for target in molecule_targets])
            target_rows.append([0 if target is None else target for target in molecule_targets])
        mask = torch.Tensor(mask_rows)
        targets = torch.Tensor(target_rows)

        return smiles_batch, components, features_batch, mask, targets
|
grover/data/scaler.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The scaler for the regression task.
|
| 3 |
+
This implementation is adapted from
|
| 4 |
+
https://github.com/chemprop/chemprop/blob/master/chemprop/data/scaler.py
|
| 5 |
+
"""
|
| 6 |
+
from typing import Any, List
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class StandardScaler:
    """A StandardScaler normalizes a dataset.

    Fitting learns the mean and standard deviation of each column (axis 0), ignoring NaNs.
    Transforming subtracts those means and divides by those standard deviations.
    """

    def __init__(self, means: np.ndarray = None, stds: np.ndarray = None, replace_nan_token: Any = None):
        """
        Initialize StandardScaler, optionally with means and standard deviations precomputed.

        :param means: An optional 1D numpy array of precomputed means.
        :param stds: An optional 1D numpy array of precomputed standard deviations.
        :param replace_nan_token: The token to use in place of nans.
        """
        self.means = means
        self.stds = stds
        self.replace_nan_token = replace_nan_token

    def fit(self, X: List[List[float]]) -> 'StandardScaler':
        """
        Learns means and standard deviations across the 0th axis.

        :param X: A list of lists of floats.
        :return: The fitted StandardScaler.
        """
        data = np.array(X).astype(float)
        self.means = np.nanmean(data, axis=0)
        self.stds = np.nanstd(data, axis=0)

        # A column without any usable value gets mean 0.
        self.means = np.where(np.isnan(self.means), 0.0, self.means)
        # A missing or zero standard deviation becomes 1 so that dividing by it is harmless.
        unusable_std = np.isnan(self.stds) | (self.stds == 0)
        self.stds = np.where(unusable_std, 1.0, self.stds)

        return self

    def transform(self, X: List[List[float]]):
        """
        Transforms the data by subtracting the means and dividing by the standard deviations.

        :param X: A list of lists of floats.
        :return: The transformed data, with NaNs replaced by ``replace_nan_token``.
        """
        data = np.array(X).astype(float)
        scaled = (data - self.means) / self.stds

        return np.where(np.isnan(scaled), self.replace_nan_token, scaled)

    def inverse_transform(self, X: List[List[float]]):
        """
        Performs the inverse transformation by multiplying by the standard deviations and adding the means.

        :param X: A list of lists of floats.
        :return: The inverse transformed data, with NaNs replaced by ``replace_nan_token``.
        """
        # Only lists and numpy arrays are converted; any other input is used as given.
        if isinstance(X, (np.ndarray, list)):
            X = np.array(X).astype(float)
        restored = X * self.stds + self.means

        return np.where(np.isnan(restored), self.replace_nan_token, restored)
|
grover/data/task_labels.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The label generator for the pretraining.
|
| 3 |
+
"""
|
| 4 |
+
from collections import Counter
|
| 5 |
+
from typing import Callable, Union
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from rdkit import Chem
|
| 9 |
+
from descriptastorus.descriptors import rdDescriptors
|
| 10 |
+
|
| 11 |
+
from grover.data.molfeaturegenerator import register_features_generator
|
| 12 |
+
|
| 13 |
+
Molecule = Union[str, Chem.Mol]
|
| 14 |
+
FeaturesGenerator = Callable[[Molecule], np.ndarray]
|
| 15 |
+
|
| 16 |
+
# The functional group descriptors in RDKit.
|
| 17 |
+
RDKIT_PROPS = ['fr_Al_COO', 'fr_Al_OH', 'fr_Al_OH_noTert', 'fr_ArN',
|
| 18 |
+
'fr_Ar_COO', 'fr_Ar_N', 'fr_Ar_NH', 'fr_Ar_OH', 'fr_COO', 'fr_COO2',
|
| 19 |
+
'fr_C_O', 'fr_C_O_noCOO', 'fr_C_S', 'fr_HOCCN', 'fr_Imine', 'fr_NH0',
|
| 20 |
+
'fr_NH1', 'fr_NH2', 'fr_N_O', 'fr_Ndealkylation1', 'fr_Ndealkylation2',
|
| 21 |
+
'fr_Nhpyrrole', 'fr_SH', 'fr_aldehyde', 'fr_alkyl_carbamate', 'fr_alkyl_halide',
|
| 22 |
+
'fr_allylic_oxid', 'fr_amide', 'fr_amidine', 'fr_aniline', 'fr_aryl_methyl',
|
| 23 |
+
'fr_azide', 'fr_azo', 'fr_barbitur', 'fr_benzene', 'fr_benzodiazepine',
|
| 24 |
+
'fr_bicyclic', 'fr_diazo', 'fr_dihydropyridine', 'fr_epoxide', 'fr_ester',
|
| 25 |
+
'fr_ether', 'fr_furan', 'fr_guanido', 'fr_halogen', 'fr_hdrzine', 'fr_hdrzone',
|
| 26 |
+
'fr_imidazole', 'fr_imide', 'fr_isocyan', 'fr_isothiocyan', 'fr_ketone',
|
| 27 |
+
'fr_ketone_Topliss', 'fr_lactam', 'fr_lactone', 'fr_methoxy', 'fr_morpholine',
|
| 28 |
+
'fr_nitrile', 'fr_nitro', 'fr_nitro_arom', 'fr_nitro_arom_nonortho',
|
| 29 |
+
'fr_nitroso', 'fr_oxazole', 'fr_oxime', 'fr_para_hydroxylation', 'fr_phenol',
|
| 30 |
+
'fr_phenol_noOrthoHbond', 'fr_phos_acid', 'fr_phos_ester', 'fr_piperdine',
|
| 31 |
+
'fr_piperzine', 'fr_priamide', 'fr_prisulfonamd', 'fr_pyridine', 'fr_quatN',
|
| 32 |
+
'fr_sulfide', 'fr_sulfonamd', 'fr_sulfone', 'fr_term_acetylene', 'fr_tetrazole',
|
| 33 |
+
'fr_thiazole', 'fr_thiocyan', 'fr_thiophene', 'fr_unbrch_alkane', 'fr_urea']
|
| 34 |
+
|
| 35 |
+
BOND_FEATURES = ['BondType', 'Stereo', 'BondDir']
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# BOND_FEATURES = ['BondType', 'Stereo']
|
| 39 |
+
# BOND_FEATURES = ['Stereo']
|
| 40 |
+
|
| 41 |
+
@register_features_generator('fgtasklabel')
def rdkit_functional_group_label_features_generator(mol: Molecule) -> np.ndarray:
    """
    Generates the functional group labels of a molecule using RDKit.

    The descriptors listed in RDKIT_PROPS are computed for the molecule and then binarized.

    :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule).
    :return: A 1D numpy array in which every non-zero descriptor value has been replaced by 1.
    """
    # The descriptor generator is fed a SMILES string, so RDKit molecules are serialized first.
    smiles = mol if isinstance(mol, str) else Chem.MolToSmiles(mol, isomericSmiles=True)
    generator = rdDescriptors.RDKit2D(RDKIT_PROPS)
    # The first element of the result is dropped.
    # NOTE(review): presumably a status flag prepended by descriptastorus - confirm.
    features = np.array(generator.process(smiles)[1:])
    # Only the presence of a functional group matters, not how often it occurs.
    features[features != 0] = 1
    return features
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def atom_to_vocab(mol, atom):
    """
    Build the vocabulary token of an atom together with its one-hop context.

    The token starts with the atom symbol. For every distinct "<neighbor symbol>-<bond type>"
    pair around the atom, "_<pair><count>" is appended, with the pairs in sorted order.

    :param mol: the molecule the atom belongs to.
    :param atom: the target atom.
    :return: the generated atom vocabulary token with its contexts.
    """
    center_idx = atom.GetIdx()
    context = Counter()
    for neighbor in atom.GetNeighbors():
        link = mol.GetBondBetweenAtoms(center_idx, neighbor.GetIdx())
        context["%s-%s" % (neighbor.GetSymbol(), link.GetBondType())] += 1

    pieces = [atom.GetSymbol()]
    pieces.extend("%s%d" % (key, context[key]) for key in sorted(context))
    return "_".join(pieces)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def bond_to_vocab(mol, bond):
    """
    Build the vocabulary token of a bond together with its one-hop context.

    The token starts with the feature name of the bond itself. Every other bond attached to
    one of the two end atoms contributes a "<end atom symbol>-<bond feature name>" key; for
    each distinct key, "_<key><count>" is appended in sorted key order.

    :param mol: the molecule the bond belongs to.
    :param bond: the target bond.
    :return: the generated bond vocabulary token with its contexts.
    """
    endpoints = (bond.GetBeginAtom(), bond.GetEndAtom())
    endpoint_ids = [end_atom.GetIdx() for end_atom in endpoints]

    context = Counter()
    for end_atom in endpoints:
        end_idx = end_atom.GetIdx()
        for other in end_atom.GetNeighbors():
            other_idx = other.GetIdx()
            # Neighbors that are themselves end atoms belong to the target bond, not its context.
            if other_idx not in endpoint_ids:
                side_bond = mol.GetBondBetweenAtoms(end_idx, other_idx)
                context["%s-%s" % (end_atom.GetSymbol(), get_bond_feature_name(side_bond))] += 1

    pieces = [get_bond_feature_name(bond)]
    pieces.extend("%s%d" % (key, context[key]) for key in sorted(context))
    return "_".join(pieces)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def get_bond_feature_name(bond):
    """
    Return the string format of bond features.

    For every name in BOND_FEATURES the matching ``bond.Get<Name>()`` accessor is called.
    The stringified results are joined with '-' and surrounded with (), giving a string of
    the form "(A-B-C)" for three features.

    :param bond: the bond to describe; it must expose one ``Get<Name>`` method per feature.
    :return: the bond feature string.
    """
    # getattr resolves the accessor by name without evaluating dynamically built code.
    values = [str(getattr(bond, f"Get{bond_feature}")()) for bond_feature in BOND_FEATURES]
    return '(' + '-'.join(values) + ')'
|
grover/data/torchvocab.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The contextual property.
|
| 3 |
+
"""
|
| 4 |
+
import pickle
|
| 5 |
+
from collections import Counter
|
| 6 |
+
from multiprocessing import Pool
|
| 7 |
+
|
| 8 |
+
import tqdm
|
| 9 |
+
from rdkit import Chem
|
| 10 |
+
|
| 11 |
+
from grover.data.task_labels import atom_to_vocab
|
| 12 |
+
from grover.data.task_labels import bond_to_vocab
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class TorchVocab(object):
    """
    Defines the vocabulary for atoms/bonds in molecular.

    ``itos`` maps an index to its token and ``stoi`` is the reverse mapping. The special
    tokens come first; ``pad_index`` (0) and ``other_index`` (1) match the positions of the
    default specials '<pad>' and '<other>'.
    """

    def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>', '<other>'), vocab_type='atom'):
        """
        Build the vocabulary from token frequencies.

        :param counter: mapping from token to frequency. It is stored as ``self.freqs``
                        without copying, so ``extend`` updates the caller's object too.
        :param max_size: maximum number of non-special tokens to keep, or None for no limit.
        :param min_freq: tokens seen fewer times than this are dropped (values below 1 count as 1).
        :param specials: tokens placed at the start of the vocabulary.
        :param vocab_type: 'atom': atom atom_vocab; 'bond': bond atom_vocab.
        :raises ValueError: if vocab_type is neither 'atom' nor 'bond'.
        """
        if vocab_type not in ('atom', 'bond'):
            raise ValueError('Wrong input for vocab_type!')
        self.vocab_type = vocab_type
        self.freqs = counter
        min_freq = max(min_freq, 1)
        self.itos = list(specials)

        # The size limit applies to regular tokens only, so make room for the specials.
        max_size = None if max_size is None else max_size + len(self.itos)
        # sort by frequency (descending), then alphabetically
        ranked = sorted(counter.items(), key=lambda item: (-item[1], item[0]))

        for word, freq in ranked:
            if freq < min_freq or len(self.itos) == max_size:
                break
            self.itos.append(word)
        # stoi is simply a reverse dict for itos
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}
        self.other_index = 1
        self.pad_index = 0

    def __eq__(self, other):
        """Two vocabularies are equal when frequencies and both index mappings match."""
        # Let Python fall back to its default comparison for unrelated types
        # instead of failing on a missing attribute.
        if not isinstance(other, TorchVocab):
            return NotImplemented
        return self.freqs == other.freqs and self.stoi == other.stoi and self.itos == other.itos

    def __len__(self):
        """Number of tokens in the vocabulary, specials included."""
        return len(self.itos)

    def vocab_rerank(self):
        """Rebuild ``stoi`` from the current order of ``itos``."""
        self.stoi = {word: i for i, word in enumerate(self.itos)}

    def extend(self, v, sort=False):
        """
        Merge the tokens and frequencies of another vocabulary into this one.

        :param v: the vocabulary to merge in.
        :param sort: if True, the tokens of ``v`` are visited in sorted order.
        """
        words = sorted(v.itos) if sort else v.itos
        for w in words:
            if w not in self.stoi:
                self.itos.append(w)
                self.stoi[w] = len(self.itos) - 1
                self.freqs[w] = 0
            self.freqs[w] += v.freqs[w]

    def mol_to_seq(self, mol, with_len=False):
        """
        Convert a molecule into the index sequence of its atom (or bond) tokens.

        :param mol: a SMILES string or an RDKit molecule.
        :param with_len: if True, also return the sequence length.
        :return: the index list, or ``(index list, length)`` when with_len is set.
                 Tokens missing from the vocabulary map to ``other_index``.
        """
        mol = Chem.MolFromSmiles(mol) if isinstance(mol, str) else mol
        if self.vocab_type == 'atom':
            seq = [self.stoi.get(atom_to_vocab(mol, atom), self.other_index) for atom in mol.GetAtoms()]
        else:
            seq = [self.stoi.get(bond_to_vocab(mol, bond), self.other_index) for bond in mol.GetBonds()]
        return (seq, len(seq)) if with_len else seq

    @staticmethod
    def load_vocab(vocab_path: str) -> 'TorchVocab':
        """
        Load a pickled vocabulary.

        :param vocab_path: path of the file written by ``save_vocab``.
        :return: the unpickled vocabulary.
        """
        # SECURITY: pickle executes arbitrary code while loading; only open trusted files.
        with open(vocab_path, "rb") as f:
            return pickle.load(f)

    def save_vocab(self, vocab_path):
        """
        Pickle this vocabulary to ``vocab_path``.

        :param vocab_path: destination file path.
        """
        with open(vocab_path, "wb") as f:
            pickle.dump(self, f)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class MolVocab(TorchVocab):
    """
    Vocabulary of atom or bond context tokens built from SMILES.

    The constructor reads a SMILES file. ``from_smiles`` builds the same kind of vocabulary
    from an in-memory collection of SMILES strings; it was previously written as a first
    ``__init__`` that the file-based ``__init__`` silently replaced.
    """

    def __init__(self, file_path, max_size=None, min_freq=1, num_workers=1, total_lines=None, vocab_type='atom'):
        """
        Build the vocabulary from a SMILES file, counting tokens in worker processes.

        :param file_path: path of the SMILES file (its first line is skipped by the readers).
        :param max_size: maximum number of non-special tokens to keep, or None for no limit.
        :param min_freq: minimum frequency of a token to be kept.
        :param num_workers: number of worker processes.
        :param total_lines: number of lines in the file; counted when None.
        :param vocab_type: 'atom' or 'bond'.
        :raises ValueError: if vocab_type is neither 'atom' nor 'bond'.
        """
        if vocab_type not in ('atom', 'bond'):
            raise ValueError('Wrong input for vocab_type!')
        self.vocab_type = vocab_type
        print("Building %s vocab from file: %s" % (self.vocab_type, file_path))

        from rdkit import RDLogger
        lg = RDLogger.logger()
        lg.setLevel(RDLogger.CRITICAL)

        if total_lines is None:
            with open(file_path) as f:
                # An empty file is treated as one line, as before.
                total_lines = max(sum(1 for _ in f), 1)

        counter = Counter()
        pbar = tqdm.tqdm(total=total_lines)
        pool = Pool(num_workers)
        res = []
        batch = 50000
        callback = lambda a: pbar.update(batch)
        # Each task handles one contiguous slice of at most `batch` lines.
        for i in range(int(total_lines / batch + 1)):
            start = int(batch * i)
            end = min(total_lines, batch * (i + 1))
            res.append(pool.apply_async(MolVocab.read_smiles_from_file,
                                        args=(file_path, start, end, vocab_type,),
                                        callback=callback))
        pool.close()
        pool.join()
        pbar.close()
        for r in res:
            # r.get() re-raises any exception that occurred in the worker.
            counter.update(r.get())
        super().__init__(counter, max_size=max_size, min_freq=min_freq, vocab_type=vocab_type)

    @classmethod
    def from_smiles(cls, smiles, max_size=None, min_freq=1, vocab_type='atom'):
        """
        Build the vocabulary from an in-memory collection of SMILES strings.

        :param smiles: a sized iterable of SMILES strings.
        :param max_size: maximum number of non-special tokens to keep, or None for no limit.
        :param min_freq: minimum frequency of a token to be kept.
        :param vocab_type: 'atom' or 'bond'.
        :return: the new vocabulary.
        :raises ValueError: if vocab_type is neither 'atom' nor 'bond'.
        """
        if vocab_type not in ('atom', 'bond'):
            raise ValueError('Wrong input for vocab_type!')
        print("Building %s vocab from smiles: %d" % (vocab_type, len(smiles)))

        counter = Counter()
        for smi in tqdm.tqdm(smiles):
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                # SMILES that RDKit cannot parse are skipped.
                continue
            cls._count_tokens(mol, vocab_type, counter)

        # Bypass the file-based __init__ and initialise the base vocabulary directly.
        vocab = cls.__new__(cls)
        TorchVocab.__init__(vocab, counter, max_size=max_size, min_freq=min_freq, vocab_type=vocab_type)
        return vocab

    @staticmethod
    def _count_tokens(mol, vocab_type, counter):
        """Add the atom or bond tokens of one molecule to ``counter``."""
        if vocab_type == 'atom':
            counter.update(atom_to_vocab(mol, atom) for atom in mol.GetAtoms())
        else:
            counter.update(bond_to_vocab(mol, bond) for bond in mol.GetBonds())

    @staticmethod
    def read_smiles_from_file(file_path, start, end, vocab_type):
        """
        Count the tokens of the molecules on lines ``start`` (inclusive) to ``end`` (exclusive).

        Line indices are counted after the first line of the file, which is always skipped.

        :param file_path: path of the SMILES file.
        :param start: index of the first line to process.
        :param end: index one past the last line to process.
        :param vocab_type: 'atom' or 'bond'.
        :return: a Counter of the tokens found in that slice.
        """
        sub_counter = Counter()
        with open(file_path, "r") as smiles:
            # NOTE(review): the first line is discarded, presumably a header - confirm.
            smiles.readline()
            for i, smi in enumerate(smiles):
                if i < start:
                    continue
                if i >= end:
                    break
                mol = Chem.MolFromSmiles(smi)
                if mol is None:
                    # SMILES that RDKit cannot parse are skipped instead of crashing the worker.
                    continue
                MolVocab._count_tokens(mol, vocab_type, sub_counter)
        return sub_counter

    @staticmethod
    def load_vocab(vocab_path: str) -> 'MolVocab':
        """
        Load a pickled vocabulary.

        :param vocab_path: path of the pickled vocabulary file.
        :return: the unpickled vocabulary.
        """
        # SECURITY: pickle executes arbitrary code while loading; only open trusted files.
        with open(vocab_path, "rb") as f:
            return pickle.load(f)
|
grover/model/layers.py
ADDED
|
@@ -0,0 +1,902 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The basic building blocks in model.
|
| 3 |
+
"""
|
| 4 |
+
import math
|
| 5 |
+
from argparse import Namespace
|
| 6 |
+
from typing import Union
|
| 7 |
+
|
| 8 |
+
import numpy
|
| 9 |
+
import scipy.stats as stats
|
| 10 |
+
import torch
|
| 11 |
+
from torch import nn as nn
|
| 12 |
+
from torch.nn import LayerNorm, functional as F
|
| 13 |
+
|
| 14 |
+
from grover.util.nn_utils import get_activation_function, select_neighbor_and_aggregate
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SelfAttention(nn.Module):
    r"""
    Self SelfAttention Layer
    Given $X\in \mathbb{R}^{n \times in_feature}$, the attention is calculated by: $a=Softmax(W_2tanh(W_1X))$, where
    $W_1 \in \mathbb{R}^{hidden \times in_feature}$, $W_2 \in \mathbb{R}^{out_feature \times hidden}$.
    The final output is: $out=aX$, which is unrelated with input $n$.
    """

    def __init__(self, *, hidden, in_feature, out_feature):
        """
        The init function.
        :param hidden: the hidden dimension, can be viewed as the number of experts.
        :param in_feature: the input feature dimension.
        :param out_feature: the output feature dimension.
        """
        super(SelfAttention, self).__init__()
        # Allocated uninitialised; reset_parameters() fills both matrices right away.
        self.w1 = torch.nn.Parameter(torch.empty(hidden, in_feature))
        self.w2 = torch.nn.Parameter(torch.empty(out_feature, hidden))
        self.reset_parameters()

    def reset_parameters(self):
        """
        Use xavier_normal method to initialize parameters.
        """
        nn.init.xavier_normal_(self.w1)
        nn.init.xavier_normal_(self.w2)

    def forward(self, X):
        r"""
        The forward function.
        :param X: The input feature map. $X \in \mathbb{R}^{n \times in_feature}$.
        :return: The final embeddings (out_feature x in_feature) and the attention matrix (out_feature x n).
        """
        x = torch.tanh(torch.matmul(self.w1, X.transpose(1, 0)))  # hidden x n
        x = torch.matmul(self.w2, x)  # out_feature x n
        # Normalise over the n input rows, so every output row is a weighted average of X.
        attn = torch.nn.functional.softmax(x, dim=-1)
        x = torch.matmul(attn, X)
        return x, attn
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class Readout(nn.Module):
    """The readout function. Convert the node embeddings to the graph embeddings."""

    def __init__(self,
                 rtype: str = "none",
                 hidden_size: int = 0,
                 attn_hidden: int = None,
                 attn_out: int = None,
                 ):
        """
        The readout function.
        :param rtype: readout type. "self_attention" enables the attention readout; any other
                      value falls back to the "mean" readout.
        :param hidden_size: input hidden size
        :param attn_hidden: only valid if rtype == "self_attention". The attention hidden size.
        :param attn_out: only valid if rtype == "self_attention". The attention out size.
        """
        super(Readout, self).__init__()
        # Cached zeros
        self.cached_zero_vector = nn.Parameter(torch.zeros(hidden_size), requires_grad=False)
        self.rtype = "mean"

        if rtype == "self_attention":
            self.attn = SelfAttention(hidden=attn_hidden,
                                      in_feature=hidden_size,
                                      out_feature=attn_out)
            self.rtype = "self_attention"

    def forward(self, embeddings, scope):
        r"""
        The forward function, given a batch node/edge embedding and a scope list,
        produce the graph-level embedding by a scope.
        :param embeddings: The embedding matrix, num_atoms or num_bonds \times hidden_size.
        :param scope: a list, in which the element is a pair (start, size): the first row of a
                      molecule in `embeddings` and the number of rows that belong to it.
        :return: the stacked graph embeddings, one row per scope entry. The row width is
                 hidden_size for the mean readout and attn_out * hidden_size for self-attention.
        """
        use_attention = self.rtype == "self_attention"
        if use_attention:
            # The attention readout is flattened to attn_out * hidden_size values, so the
            # placeholder for empty molecules needs that width for torch.stack to work.
            zero_vec = embeddings.new_zeros(self.attn.w2.size(0) * embeddings.size(1))
        else:
            zero_vec = self.cached_zero_vector

        # Readout
        mol_vecs = []
        self.attns = []
        for a_start, a_size in scope:
            if a_size == 0:
                mol_vecs.append(zero_vec)
                continue
            cur_hiddens = embeddings.narrow(0, a_start, a_size)
            if use_attention:
                cur_hiddens, attn = self.attn(cur_hiddens)
                cur_hiddens = cur_hiddens.flatten()
                # Temporarily disable. Enable it if you want to save attentions.
                # self.attns.append(attn.cpu().detach().numpy())
            else:
                cur_hiddens = cur_hiddens.sum(dim=0) / a_size
            mol_vecs.append(cur_hiddens)

        mol_vecs = torch.stack(mol_vecs, dim=0)  # (num_molecules, readout width)
        return mol_vecs
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class MPNEncoder(nn.Module):
    """A message passing neural network for encoding a molecule."""

    def __init__(self, args: Namespace,
                 atom_messages: bool,
                 init_message_dim: int,
                 attached_fea_fdim: int,
                 hidden_size: int,
                 bias: bool,
                 depth: int,
                 dropout: float,
                 undirected: bool,
                 dense: bool,
                 aggregate_to_atom: bool,
                 attach_fea: bool,
                 input_layer="fc",
                 dynamic_depth='none'
                 ):
        """
        Initializes the MPNEncoder.
        :param args: the arguments. Only ``args.activation`` is read here.
        :param atom_messages: enables atom_messages or not.
        :param init_message_dim:  the initial input message dimension.
        :param attached_fea_fdim:  the attached feature dimension.
        :param hidden_size: the output message dimension during message passing.
        :param bias: the bias in the message passing.
        :param depth: the message passing depth.
        :param dropout: the dropout rate.
        :param undirected: the message passing is undirected or not.
        :param dense: enables the dense connections.
        :param aggregate_to_atom: stored on the module; not used by this class itself.
        :param attach_fea: enables the feature attachment during the message passing process.
        :param input_layer: "fc" projects the initial messages with a linear layer, "none"
                            passes them through unchanged.
        :param dynamic_depth: enables the dynamic depth. Possible choices: "none", "uniform" and "truncnorm"
        """
        super(MPNEncoder, self).__init__()
        self.init_message_dim = init_message_dim
        self.attached_fea_fdim = attached_fea_fdim
        self.hidden_size = hidden_size
        self.bias = bias
        self.depth = depth
        self.dropout = dropout
        self.input_layer = input_layer
        self.layers_per_message = 1
        self.undirected = undirected
        self.atom_messages = atom_messages
        self.dense = dense
        # The attribute name keeps its historical spelling so existing users keep working.
        self.aggreate_to_atom = aggregate_to_atom
        self.attached_fea = attach_fea
        self.dynamic_depth = dynamic_depth

        # Dropout
        self.dropout_layer = nn.Dropout(p=self.dropout)

        # Activation
        self.act_func = get_activation_function(args.activation)

        # Input
        if self.input_layer == "fc":
            input_dim = self.init_message_dim
            self.W_i = nn.Linear(input_dim, self.hidden_size, bias=self.bias)

        if self.attached_fea:
            w_h_input_size = self.hidden_size + self.attached_fea_fdim
        else:
            w_h_input_size = self.hidden_size

        # Shared weight matrix across depths (default)
        self.W_h = nn.Linear(w_h_input_size, self.hidden_size, bias=self.bias)

    def forward(self,
                init_messages,
                init_attached_features,
                a2nei,
                a2attached,
                b2a=None,
                b2revb=None,
                adjs=None
                ) -> torch.FloatTensor:
        """
        The forward function.
        :param init_messages:  initial messages, can be atom features or bond features.
        :param init_attached_features: initial attached_features.
        :param a2nei: the relation of item to its neighbors. For the atom message passing, a2nei = a2a. For bond
        messages a2nei = a2b
        :param a2attached: the relation of item to the attached features during message passing. For the atom message
        passing, a2attached = a2b. For the bond message passing a2attached = a2a
        :param b2a: index used to gather the aggregated message for every bond; needed for
        bond message passing.
        :param b2revb: index of the reverse bond of every bond; needed for bond message passing
        and for undirected message passing.
        :param adjs: unused.
        :return: the messages after the last message passing step.
        NOTE(review): presumably num_atoms x hidden with atom_messages and num_bonds x hidden otherwise - confirm.
        :raises ValueError: if input_layer is neither "fc" nor "none".
        """

        # Input
        if self.input_layer == 'fc':
            init_hidden = self.W_i(init_messages)  # num_bonds x hidden_size # f_bond
            message = self.act_func(init_hidden)  # num_bonds x hidden_size
        elif self.input_layer == 'none':
            init_hidden = init_messages
            message = init_hidden
        else:
            raise ValueError('Unsupported input_layer "%s", expected "fc" or "none".' % self.input_layer)

        attached_fea = init_attached_features  # f_atom / f_bond

        # dynamic depth, only works in training.
        if self.training and self.dynamic_depth != "none":
            if self.dynamic_depth == "uniform":
                # uniform sampling from depth - 3 to depth + 2 (randint excludes the upper bound)
                ndepth = numpy.random.randint(self.depth - 3, self.depth + 3)
            else:
                # truncnorm
                mu = self.depth
                sigma = 1
                lower = mu - 3 * sigma
                upper = mu + 3 * sigma
                X = stats.truncnorm((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
                # rvs(1) returns a one-element array; take the element before truncating to int.
                ndepth = int(X.rvs(1)[0])
        else:
            ndepth = self.depth

        # Message passing
        for _ in range(ndepth - 1):
            if self.undirected:
                # two directions should be the same
                message = (message + message[b2revb]) / 2

            nei_message = select_neighbor_and_aggregate(message, a2nei)
            a_message = nei_message
            if self.attached_fea:
                attached_nei_fea = select_neighbor_and_aggregate(attached_fea, a2attached)
                a_message = torch.cat((nei_message, attached_nei_fea), dim=1)

            if not self.atom_messages:
                rev_message = message[b2revb]
                if self.attached_fea:
                    atom_rev_message = attached_fea[b2a[b2revb]]
                    rev_message = torch.cat((rev_message, atom_rev_message), dim=1)
                # Except reverse bond its-self(w) ! \sum_{k\in N(u) \ w}
                message = a_message[b2a] - rev_message  # num_bonds x hidden
            else:
                message = a_message

            message = self.W_h(message)

            # BUG here, by default MPNEncoder use the dense connection in the message passing step.
            # The correct form should if not self.dense
            # (deliberately left as is: flipping the condition would change the trained behaviour)
            if self.dense:
                message = self.act_func(message)  # num_bonds x hidden_size
            else:
                message = self.act_func(init_hidden + message)
            message = self.dropout_layer(message)  # num_bonds x hidden

        output = message

        return output  # num_atoms x hidden
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class PositionwiseFeedForward(nn.Module):
    """Feed-forward block: linear layer, activation, dropout, linear layer."""

    def __init__(self, d_model, d_ff, activation="PReLU", dropout=0.1, d_out=None):
        """Set up the two linear layers, the dropout and the activation.

        :param d_model: the input dimension.
        :param d_ff: the hidden dimension.
        :param activation: name of the activation function.
        :param dropout: the dropout rate.
        :param d_out: the output dimension; d_model is used when it is None.
        """
        super().__init__()
        out_dim = d_model if d_out is None else d_out
        # Both linear layers keep their default bias.
        self.W_1 = nn.Linear(d_model, d_ff)
        self.W_2 = nn.Linear(d_ff, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.act_func = get_activation_function(activation)

    def forward(self, x):
        """
        Apply the feed-forward block.
        :param x: input tensor whose last dimension is d_model.
        :return: tensor whose last dimension is the output dimension.
        """
        hidden = self.W_1(x)
        hidden = self.act_func(hidden)
        hidden = self.dropout(hidden)
        return self.W_2(hidden)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class SublayerConnection(nn.Module):
    """
    Layer norm and dropout on a sublayer output, followed by an optional residual connection.
    """

    def __init__(self, size, dropout):
        """Create the layer norm and the dropout.

        :param size: the input dimension.
        :param dropout: the dropout ratio.
        """
        super().__init__()
        self.norm = LayerNorm(size, elementwise_affine=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, inputs, outputs):
        """Normalize `outputs`, apply dropout and add `inputs` unless it is None."""
        regularized = self.dropout(self.norm(outputs))
        return regularized if inputs is None else inputs + regularized
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
class Attention(nn.Module):
    """
    Compute scaled dot-product attention.
    """

    def forward(self, query, key, value, mask=None, dropout=None):
        """Attend ``value`` with weights derived from ``query`` and ``key``.

        :param query: query tensor; its last dimension is the scaling size.
        :param key: key tensor, matrix-multiplied (transposed) with ``query``.
        :param value: value tensor, weighted by the attention probabilities.
        :param mask: optional tensor; positions where ``mask == 0`` are
            excluded from the softmax.
        :param dropout: optional callable applied to the attention
            probabilities (e.g. an ``nn.Dropout``).
        :return: tuple ``(attended values, attention probabilities)``.
        """
        scale = math.sqrt(query.size(-1))
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale

        if mask is not None:
            # -1e9 is not representable in float16, so clamp the fill value
            # to the dtype's finite minimum. For float32/float64/bfloat16 the
            # value stays exactly -1e9.
            fill_value = max(-1e9, torch.finfo(scores.dtype).min)
            scores = scores.masked_fill(mask == 0, fill_value)

        p_attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            p_attn = dropout(p_attn)

        return torch.matmul(p_attn, value), p_attn
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class MultiHeadedAttention(nn.Module):
    """
    Multi-head attention: project query/key/value, attend per head, merge.
    """

    def __init__(self, h, d_model, dropout=0.1, bias=False):
        """Initialization.

        :param h: the number of attention heads; must divide ``d_model``.
        :param d_model: the model (feature) dimension.
        :param dropout: dropout probability applied to attention weights.
        :param bias: whether the output projection has a bias term. The
            three input projections always use nn.Linear's default bias.
        """
        super().__init__()
        assert d_model % h == 0

        # d_v is assumed to always equal d_k.
        self.d_k = d_model // h
        self.h = h  # number of heads

        # One projection each for query, key and value.
        projections = [nn.Linear(d_model, d_model) for _ in range(3)]
        self.linear_layers = nn.ModuleList(projections)
        self.output_linear = nn.Linear(d_model, d_model, bias)
        self.attention = Attention()

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention.

        :param query: query tensor; dimension 0 is treated as the batch.
        :param key: key tensor.
        :param value: value tensor.
        :param mask: optional mask forwarded to the attention module.
        :return: the output projection of the merged heads.
        """
        n_batch = query.size(0)

        def split_heads(layer, tensor):
            # d_model => h x d_k, with the head axis moved in front.
            return layer(tensor).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)

        # 1) Linear projections, split into heads.
        q_layer, k_layer, v_layer = self.linear_layers
        query = split_heads(q_layer, query)
        key = split_heads(k_layer, key)
        value = split_heads(v_layer, value)

        # 2) Attention over all heads at once.
        attended, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" the heads with a view, then apply the final linear.
        merged = attended.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)

        return self.output_linear(merged)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
class Head(nn.Module):
    """
    One head for multi-headed attention.

    Three separate message passing encoders produce the query, key and value.
    :return: (query, key, value)
    """

    def __init__(self, args, hidden_size, atom_messages=False):
        """
        Initialization.
        :param args: The argument; ``bias``, ``depth``, ``dropout``,
            ``undirected`` and ``dense`` are read from it.
        :param hidden_size: the dimension of hidden layer in Head.
        :param atom_messages: the MPNEncoder type.
        """
        super().__init__()
        self.atom_messages = atom_messages
        # Atom and bond features both have size hidden_size here, so the
        # message and attached-feature dimensions coincide for either view.
        init_message_dim = hidden_size
        attached_fea_dim = hidden_size

        def build_encoder():
            # Every encoder (q, k, v) is configured identically.
            return MPNEncoder(args=args,
                              atom_messages=atom_messages,
                              init_message_dim=init_message_dim,
                              attached_fea_fdim=attached_fea_dim,
                              hidden_size=hidden_size,
                              bias=args.bias,
                              depth=args.depth,
                              dropout=args.dropout,
                              undirected=args.undirected,
                              dense=args.dense,
                              aggregate_to_atom=False,
                              attach_fea=False,
                              input_layer="none",
                              dynamic_depth="truncnorm")

        # Here we use the message passing network as query, key and value.
        self.mpn_q = build_encoder()
        self.mpn_k = build_encoder()
        self.mpn_v = build_encoder()

    def forward(self, f_atoms, f_bonds, a2b, a2a, b2a, b2revb):
        """
        The forward function.
        :param f_atoms: the atom features, num_atoms * atom_dim
        :param f_bonds: the bond features, num_bonds * bond_dim
        :param a2b: mapping from atom index to incoming bond indices.
        :param a2a: mapping from atom index to its neighbors. num_atoms * max_num_bonds
        :param b2a: mapping from bond index to the index of the atom the bond is coming from.
        :param b2revb: mapping from bond index to the index of the reverse bond.
        :return: the tuple (query, key, value) from the three encoders.
        """
        if self.atom_messages:
            messages, attached = f_atoms, f_bonds
            a2nei, a2attached = a2a, a2b
        else:
            messages, attached = f_bonds, f_atoms
            a2nei, a2attached = a2b, a2a

        encoder_inputs = dict(init_messages=messages,
                              init_attached_features=attached,
                              a2nei=a2nei,
                              a2attached=a2attached,
                              b2a=b2a,
                              b2revb=b2revb)

        q = self.mpn_q(**encoder_inputs)
        k = self.mpn_k(**encoder_inputs)
        v = self.mpn_v(**encoder_inputs)
        return q, k, v
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
class MTBlock(nn.Module):
    """
    The Multi-headed attention block.
    """

    def __init__(self,
                 args,
                 num_attn_head,
                 input_dim,
                 hidden_size,
                 activation="ReLU",
                 dropout=0.0,
                 bias=True,
                 atom_messages=False,
                 cuda=True,
                 res_connection=False):
        """

        :param args: the arguments.
        :param num_attn_head: the number of attention head.
        :param input_dim: the input dimension.
        :param hidden_size: the hidden size of the model.
        :param activation: the activation function.
        :param dropout: the dropout ratio
        :param bias: if true: all linear layer contains bias term.
        :param atom_messages: the MPNEncoder type
        :param cuda: if true, the model run with GPU.
        :param res_connection: enables the skip-connection in MTBlock.
        """
        super().__init__()
        self.atom_messages = atom_messages
        self.hidden_size = hidden_size
        self.heads = nn.ModuleList()
        self.input_dim = input_dim
        # NOTE: this instance attribute shadows nn.Module.cuda() on the block.
        self.cuda = cuda
        self.res_connection = res_connection
        self.act_func = get_activation_function(activation)
        self.dropout_layer = nn.Dropout(p=dropout)
        # Note: elementwise_affine has to be consistent with the pre-training phase
        self.layernorm = nn.LayerNorm(hidden_size, elementwise_affine=True)

        self.W_i = nn.Linear(input_dim, hidden_size, bias=bias)
        self.attn = MultiHeadedAttention(h=num_attn_head,
                                         d_model=hidden_size,
                                         bias=bias,
                                         dropout=dropout)
        self.W_o = nn.Linear(hidden_size * num_attn_head, hidden_size, bias=bias)
        self.sublayer = SublayerConnection(hidden_size, dropout)
        self.heads.extend([Head(args, hidden_size=hidden_size, atom_messages=atom_messages)
                           for _ in range(num_attn_head)])

    def _project_input(self, features):
        """Map raw features to ``hidden_size``; already-sized features pass through."""
        if features.shape[1] == self.hidden_size:
            return features
        projected = self.W_i(features)
        return self.dropout_layer(self.layernorm(self.act_func(projected)))

    def forward(self, batch, features_batch=None):
        """

        :param batch: the graph batch generated by GroverCollator.
        :param features_batch: the additional features of molecules. (deprecated)
        :return: the batch with the active view's features replaced, and
            ``features_batch`` unchanged.
        """
        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a = batch

        # Only add linear transformation in the input feature.
        if self.atom_messages:
            f_atoms = self._project_input(f_atoms)
        else:  # bond messages
            f_bonds = self._project_input(f_bonds)

        # Each head yields (q, k, v); stack them along a new head axis.
        per_head = [head(f_atoms, f_bonds, a2b, a2a, b2a, b2revb) for head in self.heads]
        queries = torch.stack([qkv[0] for qkv in per_head], dim=1)
        keys = torch.stack([qkv[1] for qkv in per_head], dim=1)
        values = torch.stack([qkv[2] for qkv in per_head], dim=1)

        x_out = self.attn(queries, keys, values)  # multi-headed attention
        x_out = self.W_o(x_out.view(x_out.shape[0], -1))

        # support no residual connection in MTBlock.
        if not self.res_connection:
            x_in = None
        elif self.atom_messages:
            x_in = f_atoms
        else:
            x_in = f_bonds

        updated = self.sublayer(x_in, x_out)
        if self.atom_messages:
            f_atoms = updated
        else:
            f_bonds = updated

        batch = f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a
        return batch, features_batch
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
class GTransEncoder(nn.Module):
    """
    Graph transformer encoder.

    Runs a stack of atom-view MTBlocks and a stack of bond-view MTBlocks over
    the same graph batch and, depending on ``atom_emb_output``, aggregates the
    resulting hidden states to atom and/or bond embeddings.
    """

    def __init__(self,
                 args,
                 hidden_size,
                 edge_fdim,
                 node_fdim,
                 dropout=0.0,
                 activation="ReLU",
                 num_mt_block=1,
                 num_attn_head=4,
                 atom_emb_output: Union[bool, str] = False,  # options: True, False, None, "atom", "bond", "both"
                 bias=False,
                 cuda=True,
                 res_connection=False):
        """

        :param args: the arguments.
        :param hidden_size: the hidden size of the model.
        :param edge_fdim: the dimension of additional feature for edge/bond.
        :param node_fdim: the dimension of additional feature for node/atom.
        :param dropout: the dropout ratio
        :param activation: the activation function
        :param num_mt_block: the number of mt block.
        :param num_attn_head: the number of attention head.
        :param atom_emb_output: enable the output aggregation after message passing.
                                              atom_messages:      True                      False
        -False: no aggregating to atom. output size:     (num_atoms, hidden_size)    (num_bonds, hidden_size)
        -True:  aggregating to atom.    output size:     (num_atoms, hidden_size)    (num_atoms, hidden_size)
        -None:                         same as False
        -"atom":                       same as True
        -"bond": aggregating to bond.   output size:     (num_bonds, hidden_size)    (num_bonds, hidden_size)
        -"both": aggregating to atom&bond. output size:  (num_atoms, hidden_size)    (num_bonds, hidden_size)
                                                         (num_bonds, hidden_size)    (num_atoms, hidden_size)
        :param bias: enable bias term in all linear layers.
        :param cuda: run with cuda.
        :param res_connection: enables the skip-connection in MTBlock.
        """
        super(GTransEncoder, self).__init__()

        # For the compatibility issue.
        if atom_emb_output is False:
            atom_emb_output = None
        if atom_emb_output is True:
            atom_emb_output = 'atom'

        self.hidden_size = hidden_size
        self.dropout = dropout
        self.activation = activation
        # NOTE: this instance attribute shadows nn.Module.cuda() on the encoder.
        self.cuda = cuda
        self.bias = bias
        self.res_connection = res_connection
        self.edge_blocks = nn.ModuleList()
        self.node_blocks = nn.ModuleList()

        for i in range(num_mt_block):
            # Only the first block consumes the raw feature dimensions; later
            # blocks operate on hidden_size-wide features.
            edge_input_dim_i = edge_fdim if i == 0 else self.hidden_size
            node_input_dim_i = node_fdim if i == 0 else self.hidden_size
            # res_connection is forwarded so the documented switch actually
            # reaches the blocks (it used to be stored but never applied).
            self.edge_blocks.append(MTBlock(args=args,
                                            num_attn_head=num_attn_head,
                                            input_dim=edge_input_dim_i,
                                            hidden_size=self.hidden_size,
                                            activation=activation,
                                            dropout=dropout,
                                            bias=self.bias,
                                            atom_messages=False,
                                            cuda=cuda,
                                            res_connection=res_connection))
            self.node_blocks.append(MTBlock(args=args,
                                            num_attn_head=num_attn_head,
                                            input_dim=node_input_dim_i,
                                            hidden_size=self.hidden_size,
                                            activation=activation,
                                            dropout=dropout,
                                            bias=self.bias,
                                            atom_messages=True,
                                            cuda=cuda,
                                            res_connection=res_connection))

        self.atom_emb_output = atom_emb_output

        # The feed-forward inputs are the aggregated hidden state concatenated
        # with the original atom (node_fdim) or bond (edge_fdim) features.
        self.ffn_atom_from_atom = self._build_ffn(node_fdim)
        self.ffn_atom_from_bond = self._build_ffn(node_fdim)
        self.ffn_bond_from_atom = self._build_ffn(edge_fdim)
        self.ffn_bond_from_bond = self._build_ffn(edge_fdim)

        self.atom_from_atom_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)
        self.atom_from_bond_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)
        self.bond_from_atom_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)
        self.bond_from_bond_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)

        self.act_func_node = get_activation_function(self.activation)
        self.act_func_edge = get_activation_function(self.activation)

        self.dropout_layer = nn.Dropout(p=args.dropout)

    def _build_ffn(self, feature_dim):
        """Create one output feed-forward layer for features of width ``feature_dim``."""
        return PositionwiseFeedForward(self.hidden_size + feature_dim,
                                       self.hidden_size * 4,
                                       activation=self.activation,
                                       dropout=self.dropout,
                                       d_out=self.hidden_size)

    def pointwise_feed_forward_to_atom_embedding(self, emb_output, atom_fea, index, ffn_layer):
        """
        The point-wise feed forward and long-range residual connection for atom view.
        aggregate to atom.
        :param emb_output: the output embedding from the previous multi-head attentions.
        :param atom_fea: the atom/node feature embedding.
        :param index: the index of neighborhood relations.
        :param ffn_layer: the feed forward layer
        :return: tuple (feed-forward output, aggregated embedding).
        """
        aggr_output = select_neighbor_and_aggregate(emb_output, index)
        aggr_outputx = torch.cat([atom_fea, aggr_output], dim=1)
        return ffn_layer(aggr_outputx), aggr_output

    def pointwise_feed_forward_to_bond_embedding(self, emb_output, bond_fea, a2nei, b2revb, ffn_layer):
        """
        The point-wise feed forward and long-range residual connection for bond view.
        aggregate to bond.
        :param emb_output: the output embedding from the previous multi-head attentions.
        :param bond_fea: the bond/edge feature embedding.
        :param a2nei: the index of neighborhood relations.
        :param b2revb: index used to select the reverse message that is removed.
        :param ffn_layer: the feed forward layer
        :return: tuple (feed-forward output, aggregated embedding).
        """
        aggr_output = select_neighbor_and_aggregate(emb_output, a2nei)
        # remove rev bond / atom --- need for bond view
        aggr_output = self.remove_rev_bond_message(emb_output, aggr_output, b2revb)
        aggr_outputx = torch.cat([bond_fea, aggr_output], dim=1)
        return ffn_layer(aggr_outputx), aggr_output

    @staticmethod
    def remove_rev_bond_message(orginal_message, aggr_message, b2revb):
        """
        Subtract the message selected by ``b2revb`` from the aggregated message.

        :param orginal_message: the messages before aggregation.
        :param aggr_message: the aggregated messages.
        :param b2revb: index into ``orginal_message`` of the message to remove.
        :return: ``aggr_message - orginal_message[b2revb]``.
        """
        rev_message = orginal_message[b2revb]
        return aggr_message - rev_message

    def atom_bond_transform(self,
                            to_atom=True,  # False: to bond
                            atomwise_input=None,
                            bondwise_input=None,
                            original_f_atoms=None,
                            original_f_bonds=None,
                            a2a=None,
                            a2b=None,
                            b2a=None,
                            b2revb=None
                            ):
        """
        Transfer the output of atom/bond multi-head attention to the final atom/bond output.
        :param to_atom: if true, the output is atom embedding, otherwise, the output is bond embedding.
        :param atomwise_input: the input embedding of atom/node.
        :param bondwise_input: the input embedding of bond/edge.
        :param original_f_atoms: the initial atom features.
        :param original_f_bonds: the initial bond features.
        :param a2a: mapping from atom index to its neighbors. num_atoms * max_num_bonds
        :param a2b: mapping from atom index to incoming bond indices.
        :param b2a: mapping from bond index to the index of the atom the bond is coming from.
        :param b2revb: mapping from bond index to the index of the reverse bond.
        :return: tuple (embedding from the atom view, embedding from the bond view).
        """

        if to_atom:
            # atom input to atom output
            atomwise_input, _ = self.pointwise_feed_forward_to_atom_embedding(atomwise_input, original_f_atoms, a2a,
                                                                              self.ffn_atom_from_atom)
            atom_in_atom_out = self.atom_from_atom_sublayer(None, atomwise_input)
            # bond to atom
            bondwise_input, _ = self.pointwise_feed_forward_to_atom_embedding(bondwise_input, original_f_atoms, a2b,
                                                                              self.ffn_atom_from_bond)
            bond_in_atom_out = self.atom_from_bond_sublayer(None, bondwise_input)
            return atom_in_atom_out, bond_in_atom_out

        # to bond embeddings

        # atom input to bond output
        atom_list_for_bond = torch.cat([b2a.unsqueeze(dim=1), a2a[b2a]], dim=1)
        atomwise_input, _ = self.pointwise_feed_forward_to_bond_embedding(atomwise_input, original_f_bonds,
                                                                          atom_list_for_bond,
                                                                          b2a[b2revb], self.ffn_bond_from_atom)
        atom_in_bond_out = self.bond_from_atom_sublayer(None, atomwise_input)
        # bond input to bond output
        bond_list_for_bond = a2b[b2a]
        bondwise_input, _ = self.pointwise_feed_forward_to_bond_embedding(bondwise_input, original_f_bonds,
                                                                          bond_list_for_bond,
                                                                          b2revb, self.ffn_bond_from_bond)
        bond_in_bond_out = self.bond_from_bond_sublayer(None, bondwise_input)
        return atom_in_bond_out, bond_in_bond_out

    def forward(self, batch, features_batch=None):
        """
        Encode a graph batch.

        :param batch: the 8-tuple (f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a).
        :param features_batch: passed through the blocks; not used here.
        :return: depends on ``atom_emb_output``: the raw (atom, bond) hidden
            states for ``None``; a pair of embeddings for 'atom' or 'bond';
            a pair of pairs otherwise ('both').
        """
        f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a = batch
        # Only move inputs to the GPU when a CUDA runtime actually exists, so a
        # stale ``cuda=True`` flag does not crash CPU-only hosts.
        if torch.cuda.is_available() and (self.cuda or next(self.parameters()).is_cuda):
            f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda(), a2b.cuda(), b2a.cuda(), b2revb.cuda()
            a2a = a2a.cuda()

        node_batch = f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a
        edge_batch = f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a

        # opt pointwise_feed_forward
        original_f_atoms, original_f_bonds = f_atoms, f_bonds

        # Note: features_batch is not used here.
        for nb in self.node_blocks:  # atom messages. Multi-headed attention
            node_batch, features_batch = nb(node_batch, features_batch)
        for eb in self.edge_blocks:  # bond messages. Multi-headed attention
            edge_batch, features_batch = eb(edge_batch, features_batch)

        atom_output = node_batch[0]  # atom hidden states
        bond_output = edge_batch[1]  # bond hidden states

        if self.atom_emb_output is None:
            # output the embedding from multi-head attention directly.
            return atom_output, bond_output

        transform_inputs = dict(atomwise_input=atom_output,
                                bondwise_input=bond_output,
                                original_f_atoms=original_f_atoms,
                                original_f_bonds=original_f_bonds,
                                a2a=a2a,
                                a2b=a2b,
                                b2a=b2a,
                                b2revb=b2revb)

        if self.atom_emb_output == 'atom':
            return self.atom_bond_transform(to_atom=True, **transform_inputs)
        if self.atom_emb_output == 'bond':
            return self.atom_bond_transform(to_atom=False, **transform_inputs)

        # 'both'
        atom_embeddings = self.atom_bond_transform(to_atom=True, **transform_inputs)
        bond_embeddings = self.atom_bond_transform(to_atom=False, **transform_inputs)
        # Notice: need to be consistent with output format of DualMPNN encoder
        return ((atom_embeddings[0], bond_embeddings[0]),
                (atom_embeddings[1], bond_embeddings[1]))
|
grover/model/models.py
ADDED
|
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The GROVER models for pretraining, finetuning and fingerprint generating.
|
| 3 |
+
"""
|
| 4 |
+
from argparse import Namespace
|
| 5 |
+
from typing import List, Dict, Callable
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn as nn
|
| 10 |
+
|
| 11 |
+
from grover.data import get_atom_fdim, get_bond_fdim
|
| 12 |
+
from grover.model.layers import Readout, GTransEncoder
|
| 13 |
+
from grover.util.nn_utils import get_activation_function
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class GROVEREmbedding(nn.Module):
    """
    The GROVER embedding module: a thin wrapper around a GTransEncoder that
    names the encoder outputs in a dict.
    """

    def __init__(self, args: Namespace):
        """
        Initialize the GROVEREmbedding class.
        :param args: the arguments; ``args.backbone`` is set to "gtrans"
            when it is missing.
        """
        super().__init__()
        self.embedding_output_type = args.embedding_output_type
        edge_dim = get_bond_fdim() + get_atom_fdim()
        node_dim = get_atom_fdim()
        if not hasattr(args, "backbone"):
            print("No backbone specified in args, use gtrans backbone.")
            args.backbone = "gtrans"
        # dualtrans is the old name of the gtrans backbone.
        if args.backbone in ("gtrans", "dualtrans"):
            self.encoders = GTransEncoder(args,
                                          hidden_size=args.hidden_size,
                                          edge_fdim=edge_dim,
                                          node_fdim=node_dim,
                                          dropout=args.dropout,
                                          activation=args.activation,
                                          num_mt_block=args.num_mt_block,
                                          num_attn_head=args.num_attn_head,
                                          atom_emb_output=self.embedding_output_type,
                                          bias=args.bias,
                                          cuda=args.cuda)

    def forward(self, graph_batch: List) -> Dict:
        """
        Encode ``graph_batch`` and label the encoder outputs.

        Which keys carry values is decided by self.embedding_output_type;
        for any type other than 'atom', 'bond' or "both" nothing is returned.

        :param graph_batch: the input graph batch generated by MolCollator.
        :return: a dict containing the embedding results.
        """
        output = self.encoders(graph_batch)
        output_type = self.embedding_output_type

        if output_type == 'atom':
            # output is (atom_from_atom, atom_from_bond)
            return {"atom_from_atom": output[0], "atom_from_bond": output[1],
                    "bond_from_atom": None, "bond_from_bond": None}

        if output_type == 'bond':
            # output is (bond_from_atom, bond_from_bond)
            return {"atom_from_atom": None, "atom_from_bond": None,
                    "bond_from_atom": output[0], "bond_from_bond": output[1]}

        if output_type == "both":
            from_atom_view, from_bond_view = output[0], output[1]
            return {"atom_from_atom": from_atom_view[0], "bond_from_atom": from_atom_view[1],
                    "atom_from_bond": from_bond_view[0], "bond_from_bond": from_bond_view[1]}

        return None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class AtomVocabPrediction(nn.Module):
    """
    The atom-wise vocabulary prediction task: a linear layer followed by a
    log-softmax over the vocabulary.
    """
    def __init__(self, args, vocab_size, hidden_size=None):
        """
        :param args: the argument; ``args.hidden_size`` is used when
            ``hidden_size`` is not given.
        :param vocab_size: the size of atom vocabulary.
        :param hidden_size: the input feature dimension (optional).
        """
        super().__init__()
        in_dim = hidden_size if hidden_size else args.hidden_size
        self.linear = nn.Linear(in_dim, vocab_size)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, embeddings):
        """
        If embeddings is None: do not go through forward pass.
        :param embeddings: the atom embeddings, num_atom X fea_dim.
        :return: the prediction for each atom, num_atom X vocab_size.
        """
        if embeddings is None:
            return None
        logits = self.linear(embeddings)
        return self.logsoftmax(logits)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class BondVocabPrediction(nn.Module):
    """
    The bond-wise vocabulary prediction task. Rows at odd and even indices of
    the bond embeddings are combined into one prediction each.
    """
    def __init__(self, args, vocab_size, hidden_size=None):
        """
        Might need to use different architecture for bond vocab prediction.
        :param args: the argument; ``args.hidden_size`` is used when
            ``hidden_size`` is not given.
        :param vocab_size: size of bond vocab.
        :param hidden_size: hidden size
        """
        super().__init__()
        in_dim = hidden_size if hidden_size else args.hidden_size
        self.linear = nn.Linear(in_dim, vocab_size)

        # ad-hoc here
        # If TWO_FC_4_BOND_VOCAB, we will use two distinct fc layer to deal with the bond and rev bond.
        self.TWO_FC_4_BOND_VOCAB = True
        if self.TWO_FC_4_BOND_VOCAB:
            self.linear_rev = nn.Linear(in_dim, vocab_size)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, embeddings):
        """
        If embeddings is None: do not go through forward pass.
        :param embeddings: the bond embeddings, num_bond X fea_dim.
        :return: log-probabilities over the vocabulary, one row per selected
            index pair.
        """
        if embeddings is None:
            return None

        num_bonds = embeddings.shape[0]  # must be an odd number
        # The bond and rev bond have odd and even ids respectively. See definition in molgraph.
        # Row 0 appears in both index lists so the two selections line up.
        odd_ids = [0, *range(1, num_bonds, 2)]
        even_ids = list(range(0, num_bonds, 2))
        odd_rows = embeddings[odd_ids]
        even_rows = embeddings[even_ids]

        if self.TWO_FC_4_BOND_VOCAB:
            logits = self.linear(odd_rows) + self.linear_rev(even_rows)
        else:
            logits = self.linear(odd_rows + even_rows)

        return self.logsoftmax(logits)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class FunctionalGroupPrediction(nn.Module):
    """
    The functional group (semantic motifs) prediction task. This is a graph-level task:
    each embedding branch is pooled per molecule and projected to ``fg_size`` logits.
    """
    def __init__(self, args, fg_size):
        """
        :param args: The arguments; only ``args.hidden_size`` is read.
        :param fg_size: The size of semantic motifs (number of output logits per branch).
        """
        super().__init__()
        in_dim = args.hidden_size

        # In order to retain maximal information in the encoder, we use a simple readout function here.
        self.readout = Readout(rtype="mean", hidden_size=args.hidden_size)
        # One linear head per branch; inputs with fewer than four branches are fine.
        # Only logits are produced because the loss is BCEWithLogitsLoss.
        self.linear_atom_from_atom = nn.Linear(in_dim, fg_size)
        self.linear_atom_from_bond = nn.Linear(in_dim, fg_size)
        self.linear_bond_from_atom = nn.Linear(in_dim, fg_size)
        self.linear_bond_from_bond = nn.Linear(in_dim, fg_size)

    def forward(self, embeddings: Dict, ascope: List, bscope: List) -> Dict:
        """
        The forward function of semantic motif prediction. Every non-None branch of
        ``embeddings`` is pooled with the shared readout and fed to its own linear head.

        :param embeddings: dict with keys "atom_from_atom", "atom_from_bond", "bond_from_atom",
                           "bond_from_bond"; a None value disables that branch.
        :param ascope: The scope for atoms (used for the two "atom_*" branches).
                       Please refer BatchMolGraph for more details.
        :param bscope: The scope for bonds (used for the two "bond_*" branches).
                       Please refer BatchMolGraph for more details.
        :return: a dict with the same four keys holding the predicted logits, or None for
                 branches that were skipped.
        """
        branches = (
            ("bond_from_atom", self.linear_bond_from_atom, bscope),
            ("bond_from_bond", self.linear_bond_from_bond, bscope),
            ("atom_from_atom", self.linear_atom_from_atom, ascope),
            ("atom_from_bond", self.linear_atom_from_bond, ascope),
        )

        logits = {}
        for key, head, scope in branches:
            branch_embedding = embeddings[key]
            if branch_embedding is None:
                logits[key] = None
            else:
                logits[key] = head(self.readout(branch_embedding, scope))

        return {"atom_from_atom": logits["atom_from_atom"], "atom_from_bond": logits["atom_from_bond"],
                "bond_from_atom": logits["bond_from_atom"], "bond_from_bond": logits["bond_from_bond"]}
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
class GroverTask(nn.Module):
    """
    The pretrain module.

    Wraps the encoder ``grover`` and attaches the atom-vocabulary, bond-vocabulary and
    functional-group prediction heads to its four embedding branches.
    """
    def __init__(self, args, grover, atom_vocab_size, bond_vocab_size, fg_size):
        """
        :param args: the arguments; ``args.embedding_output_type`` is stored, the rest is
                     forwarded to the prediction heads.
        :param grover: the encoder module; called with the graph batch in ``forward``.
        :param atom_vocab_size: output size of the two atom-vocabulary heads.
        :param bond_vocab_size: output size of the two bond-vocabulary heads.
        :param fg_size: output size of the functional-group head.
        """
        super().__init__()
        self.grover = grover
        self.av_task_atom = AtomVocabPrediction(args, atom_vocab_size)
        self.av_task_bond = AtomVocabPrediction(args, atom_vocab_size)
        self.bv_task_atom = BondVocabPrediction(args, bond_vocab_size)
        self.bv_task_bond = BondVocabPrediction(args, bond_vocab_size)

        self.fg_task_all = FunctionalGroupPrediction(args, fg_size)

        self.embedding_output_type = args.embedding_output_type

    @staticmethod
    def get_loss_func(args: Namespace) -> Callable:
        """
        The loss function generator.
        :param args: the arguments; ``args.dist_coff`` becomes the default disagreement coefficient.
        :return: the loss function for GroverTask.
        """
        def loss_func(preds, targets, dist_coff=args.dist_coff):
            """
            The loss function for GroverTask.
            :param preds: the predictions, as returned by ``GroverTask.forward``: "av_task" and
                          "bv_task" map to a pair of (possibly None) predictions, "fg_task" maps
                          to a dict of per-branch logits.
            :param targets: the targets, indexed by "av_task", "bv_task" and "fg_task".
            :param dist_coff: the default disagreement coefficient for the distances between different branches.
            :return: the tuple (overall_loss, av_loss, bv_loss, fg_loss, av_dist_loss,
                     bv_dist_loss, fg_dist_loss). Terms whose branch is missing stay 0.0.
            """
            vocab_nll = nn.NLLLoss(ignore_index=0, reduction="mean")  # same for av and bv
            fg_bce = nn.BCEWithLogitsLoss(reduction="mean")
            vocab_mse = nn.MSELoss(reduction="mean")
            fg_mse = nn.MSELoss(reduction="mean")
            squash = nn.Sigmoid()

            def family_losses(task, fg_key_first, fg_key_second):
                # Losses for one vocabulary task ("av_task" or "bv_task") together with the
                # functional-group losses of the two embedding branches that feed it.
                first_loss, second_loss = 0.0, 0.0
                fg_first_loss, fg_second_loss = 0.0, 0.0
                vocab_dist, fg_dist = 0.0, 0.0

                first_pred = preds[task][0]
                second_pred = preds[task][1]

                if first_pred is not None:
                    first_loss = vocab_nll(first_pred, targets[task])
                    fg_first_loss = fg_bce(preds["fg_task"][fg_key_first], targets["fg_task"])

                if second_pred is not None:
                    second_loss = vocab_nll(second_pred, targets[task])
                    fg_second_loss = fg_bce(preds["fg_task"][fg_key_second], targets["fg_task"])

                # The disagreement terms need both branches to be present.
                if first_pred is not None and second_pred is not None:
                    vocab_dist = vocab_mse(first_pred, second_pred)
                    fg_dist = fg_mse(squash(preds["fg_task"][fg_key_first]),
                                     squash(preds["fg_task"][fg_key_second]))

                return first_loss + second_loss, fg_first_loss + fg_second_loss, vocab_dist, fg_dist

            av_loss, fg_atom_loss, av_dist_loss, fg_atom_dist_loss = \
                family_losses("av_task", "atom_from_atom", "atom_from_bond")
            bv_loss, fg_bond_loss, bv_dist_loss, fg_bond_dist_loss = \
                family_losses("bv_task", "bond_from_atom", "bond_from_bond")

            fg_loss = fg_atom_loss + fg_bond_loss
            fg_dist_loss = fg_atom_dist_loss + fg_bond_dist_loss

            # Note: only the two vocabulary disagreement terms are scaled by dist_coff.
            overall_loss = av_loss + bv_loss + fg_loss + dist_coff * av_dist_loss + \
                           dist_coff * bv_dist_loss + fg_dist_loss

            return overall_loss, av_loss, bv_loss, fg_loss, av_dist_loss, bv_dist_loss, fg_dist_loss

        return loss_func

    def forward(self, graph_batch: List):
        """
        The forward function.
        :param graph_batch: an 8-element batch; elements 5 and 6 are the atom scope and the
                            bond scope, the whole batch is passed to the encoder.
        :return: a dict with "av_task" and "bv_task" (pairs of predictions, one per embedding
                 branch; an entry is None when its branch was not produced) and "fg_task"
                 (the output of the functional-group head).
        """
        _, _, _, _, _, a_scope, b_scope, _ = graph_batch
        # Only the atom scope is converted to a plain Python list here.
        a_scope = a_scope.data.cpu().numpy().tolist()

        embeddings = self.grover(graph_batch)

        atom_vocab_preds = (self.av_task_atom(embeddings["atom_from_atom"]),
                            self.av_task_bond(embeddings["atom_from_bond"]))
        bond_vocab_preds = (self.bv_task_atom(embeddings["bond_from_atom"]),
                            self.bv_task_bond(embeddings["bond_from_bond"]))
        fg_preds = self.fg_task_all(embeddings, a_scope, b_scope)

        return {"av_task": atom_vocab_preds,
                "bv_task": bond_vocab_preds,
                "fg_task": fg_preds}
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class GroverFpGeneration(nn.Module):
    """
    GroverFpGeneration class.
    It wraps a GROVEREmbedding encoder and a mean readout and turns their output into
    per-molecule fingerprints.
    """
    def __init__(self, args):
        """
        Init function.
        :param args: the arguments; ``fingerprint_source``, ``cuda`` and ``hidden_size`` are
                     read here, and ``args`` is forwarded to GROVEREmbedding.
        """
        super().__init__()

        self.fingerprint_source = args.fingerprint_source
        self.iscuda = args.cuda

        self.grover = GROVEREmbedding(args)
        self.readout = Readout(rtype="mean", hidden_size=args.hidden_size)

    def forward(self, batch, features_batch):
        """
        The forward function.
        It takes graph batch and molecular feature batch as input and produce the fingerprints of this molecules.
        :param batch: an 8-element graph batch; elements 5 and 6 are the atom and bond scopes.
        :param features_batch: per-molecule extra features; they are appended to the
                               fingerprint unless the first entry is None.
        :return: the fingerprints, one row per molecule. The readouts of the selected
                 branches ("atom", "bond", otherwise both) are concatenated along dim 1,
                 followed by the extra features when present.
        """
        _, _, _, _, _, a_scope, b_scope, _ = batch

        output = self.grover(batch)
        # Share readout. The atom branches are always pooled, whatever the source is.
        mol_atom_from_bond_output = self.readout(output["atom_from_bond"], a_scope)
        mol_atom_from_atom_output = self.readout(output["atom_from_atom"], a_scope)

        source = self.fingerprint_source
        if source in ("bond", "both"):
            mol_bond_from_atom_output = self.readout(output["bond_from_atom"], b_scope)
            mol_bond_from_bond_output = self.readout(output["bond_from_bond"], b_scope)

        extra_features = None
        if features_batch[0] is not None:
            extra_features = torch.from_numpy(np.stack(features_batch)).float()
            if self.iscuda:
                extra_features = extra_features.cuda()
            # Match the device/dtype of the encoder output.
            extra_features = extra_features.to(output["atom_from_atom"])
            if extra_features.dim() == 1:
                extra_features = extra_features.unsqueeze(0)

        if source == "atom":
            pieces = [mol_atom_from_atom_output, mol_atom_from_bond_output]
        elif source == "bond":
            pieces = [mol_bond_from_atom_output, mol_bond_from_bond_output]
        else:
            # the both case.
            pieces = [mol_atom_from_atom_output, mol_atom_from_bond_output,
                      mol_bond_from_atom_output, mol_bond_from_bond_output]

        fp = torch.cat(pieces, 1)
        if extra_features is not None:
            fp = torch.cat([fp, extra_features], 1)
        return fp
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
class GroverFinetuneTask(nn.Module):
    """
    The finetune module.

    A GROVEREmbedding encoder, a shared readout, and two feed-forward heads: one fed with
    the "atom_from_atom" branch and one with the "atom_from_bond" branch.
    """
    def __init__(self, args):
        """
        :param args: the arguments. ``hidden_size``, ``cuda``, ``self_attention`` (plus
                     ``attn_hidden``/``attn_out`` when it is set) and ``dataset_type`` are read
                     here; ``args`` is also forwarded to GROVEREmbedding and ``create_ffn``.
        """
        super(GroverFinetuneTask, self).__init__()

        self.hidden_size = args.hidden_size
        self.iscuda = args.cuda

        self.grover = GROVEREmbedding(args)

        if args.self_attention:
            self.readout = Readout(rtype="self_attention", hidden_size=self.hidden_size,
                                   attn_hidden=args.attn_hidden,
                                   attn_out=args.attn_out)
        else:
            self.readout = Readout(rtype="mean", hidden_size=self.hidden_size)

        self.mol_atom_from_atom_ffn = self.create_ffn(args)
        self.mol_atom_from_bond_ffn = self.create_ffn(args)

        self.classification = args.dataset_type == 'classification'
        if self.classification:
            self.sigmoid = nn.Sigmoid()

    def create_ffn(self, args: Namespace):
        """
        Creates the feed-forward network for the model.

        :param args: Arguments. Reads ``features_only``, ``features_size``, ``features_dim``,
                     ``self_attention``, ``attn_out``, ``hidden_size``, ``dropout``,
                     ``activation``, ``ffn_num_layers``, ``ffn_hidden_size`` and ``output_size``.
        :return: an ``nn.Sequential`` with ``args.ffn_num_layers`` linear layers.
        """
        # Note: args.features_dim is set according the real loaded features data
        if args.features_only:
            first_linear_dim = args.features_size + args.features_dim
        else:
            if args.self_attention:
                first_linear_dim = args.hidden_size * args.attn_out
                # TODO: Ad-hoc!
                # if args.use_input_features:
                first_linear_dim += args.features_dim
            else:
                first_linear_dim = args.hidden_size + args.features_dim

        # The same dropout/activation module instances are reused at every position.
        dropout = nn.Dropout(args.dropout)
        activation = get_activation_function(args.activation)
        # TODO: ffn_hidden_size
        # Create FFN layers
        if args.ffn_num_layers == 1:
            ffn = [
                dropout,
                nn.Linear(first_linear_dim, args.output_size)
            ]
        else:
            ffn = [
                dropout,
                nn.Linear(first_linear_dim, args.ffn_hidden_size)
            ]
            for _ in range(args.ffn_num_layers - 2):
                ffn.extend([
                    activation,
                    dropout,
                    nn.Linear(args.ffn_hidden_size, args.ffn_hidden_size),
                ])
            ffn.extend([
                activation,
                dropout,
                nn.Linear(args.ffn_hidden_size, args.output_size),
            ])

        # Create FFN model
        return nn.Sequential(*ffn)

    @staticmethod
    def get_loss_func(args):
        """
        The loss function generator.

        :param args: the arguments; ``args.dataset_type`` and ``args.dist_coff`` become the
                     defaults of the returned function's ``dt`` and ``dist_coff`` parameters.
        :return: the loss function for GroverFinetuneTask.
        """
        def loss_func(preds, targets,
                      dt=args.dataset_type,
                      dist_coff=args.dist_coff):
            """
            Element-wise (unreduced) loss.

            :param preds: a single prediction tensor (eval mode output) or a tuple of two
                          prediction tensors (train mode output).
            :param targets: the targets.
            :param dt: 'classification' (BCE with logits) or 'regression' (MSE).
            :param dist_coff: weight of the disagreement term between the two predictions.
            :return: the unreduced loss tensor.
            :raises ValueError: if ``dt`` is not a supported dataset type.
            """
            if dt == 'classification':
                pred_loss = nn.BCEWithLogitsLoss(reduction='none')
            elif dt == 'regression':
                pred_loss = nn.MSELoss(reduction='none')
            else:
                # Report the value that was actually checked, not args.dataset_type.
                raise ValueError(f'Dataset type "{dt}" not supported.')

            # TODO: Here, should we need to involve the model status? Using the type of preds is just a hack.
            # The exact-type check is kept on purpose so tuple subclasses behave as before.
            if type(preds) is not tuple:
                # in eval mode.
                return pred_loss(preds, targets)

            # in train mode.
            dist_loss = nn.MSELoss(reduction='none')

            dist = dist_loss(preds[0], preds[1])
            pred_loss1 = pred_loss(preds[0], targets)
            pred_loss2 = pred_loss(preds[1], targets)
            return pred_loss1 + pred_loss2 + dist_coff * dist

        return loss_func

    def forward(self, batch, features_batch):
        """
        The forward function.

        :param batch: an 8-element graph batch; element 5 is the atom scope.
        :param features_batch: per-molecule extra features; they are appended to both
                               pooled representations unless the first entry is None.
        :return: in train mode the tuple (atom_ffn_output, bond_ffn_output); in eval mode
                 the average of the two outputs, each passed through a sigmoid first when
                 the task is classification.
        """
        _, _, _, _, _, a_scope, _, _ = batch

        output = self.grover(batch)
        # Share readout
        mol_atom_from_bond_output = self.readout(output["atom_from_bond"], a_scope)
        mol_atom_from_atom_output = self.readout(output["atom_from_atom"], a_scope)

        if features_batch[0] is not None:
            features_batch = torch.from_numpy(np.stack(features_batch)).float()
            if self.iscuda:
                features_batch = features_batch.cuda()
            # Match the device/dtype of the encoder output.
            features_batch = features_batch.to(output["atom_from_atom"])
            if len(features_batch.shape) == 1:
                features_batch = features_batch.view([1, features_batch.shape[0]])
        else:
            features_batch = None

        if features_batch is not None:
            mol_atom_from_atom_output = torch.cat([mol_atom_from_atom_output, features_batch], 1)
            mol_atom_from_bond_output = torch.cat([mol_atom_from_bond_output, features_batch], 1)

        # Both heads run identically in train and eval mode; only the post-processing differs.
        atom_ffn_output = self.mol_atom_from_atom_ffn(mol_atom_from_atom_output)
        bond_ffn_output = self.mol_atom_from_bond_ffn(mol_atom_from_bond_output)
        if self.training:
            return atom_ffn_output, bond_ffn_output

        if self.classification:
            atom_ffn_output = self.sigmoid(atom_ffn_output)
            bond_ffn_output = self.sigmoid(bond_ffn_output)
        return (atom_ffn_output + bond_ffn_output) / 2
|
grover/util/metrics.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The evaluation metrics.
|
| 3 |
+
"""
|
| 4 |
+
import math
|
| 5 |
+
from typing import List, Callable, Union
|
| 6 |
+
|
| 7 |
+
from sklearn.metrics import accuracy_score, mean_squared_error, roc_auc_score, mean_absolute_error, r2_score, \
|
| 8 |
+
precision_recall_curve, auc, recall_score, confusion_matrix
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def accuracy(targets: List[int], preds: List[float], threshold: float = 0.5) -> float:
    """
    Accuracy of a binary prediction task after thresholding the predicted probabilities.

    :param targets: A list of binary targets.
    :param preds: A list of prediction probabilities.
    :param threshold: A prediction strictly above this value becomes 1; otherwise it becomes 0.
    :return: The computed accuracy.
    """
    binarized = [int(prob > threshold) for prob in preds]
    return accuracy_score(targets, binarized)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def recall(targets: List[int], preds: List[float], threshold: float = 0.5) -> float:
    """
    Recall of a binary prediction task after thresholding the predicted probabilities.

    :param targets: A list of binary targets.
    :param preds: A list of prediction probabilities.
    :param threshold: A prediction strictly above this value becomes 1; otherwise it becomes 0.
    :return: The computed recall.
    """
    binarized = [int(prob > threshold) for prob in preds]
    return recall_score(targets, binarized)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def sensitivity(targets: List[int], preds: List[float], threshold: float = 0.5) -> float:
    """
    Sensitivity of a binary prediction task; this simply delegates to ``recall``.

    :param targets: A list of binary targets.
    :param preds: A list of prediction probabilities.
    :param threshold: A prediction strictly above this value becomes 1; otherwise it becomes 0.
    :return: The computed sensitivity.
    """
    return recall(targets, preds, threshold=threshold)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def specificity(targets: List[int], preds: List[float], threshold: float = 0.5) -> float:
    """
    Computes the specificity of a binary prediction task using a given threshold for generating hard predictions.

    :param targets: A list of binary targets.
    :param preds: A list of prediction probabilities.
    :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0
    :return: The computed specificity, tn / (tn + fp); NaN when there are no negative targets.
    """
    hard_preds = [1 if p > threshold else 0 for p in preds]
    # Pin the label set so the matrix is always 2x2. Without it, inputs that contain a
    # single class yield a 1x1 matrix and the four-way unpacking below fails.
    tn, fp, _, _ = confusion_matrix(targets, hard_preds, labels=[0, 1]).ravel()
    negatives = tn + fp
    if negatives == 0:
        # Specificity is undefined without negative targets.
        return float('nan')
    return tn / float(negatives)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def rmse(targets: List[float], preds: List[float]) -> float:
    """
    Root mean squared error between targets and predictions.

    :param targets: A list of targets.
    :param preds: A list of predictions.
    :return: The computed rmse.
    """
    mse = mean_squared_error(targets, preds)
    return math.sqrt(mse)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_metric_func(metric: str) -> Callable[[Union[List[int], List[float]], List[float]], float]:
    """
    Looks up the metric function registered under a metric name.

    :param metric: Metric name.
    :return: A metric function which takes a list of targets and a list of predictions and
             returns the metric value.
    :raises ValueError: if the metric name is not supported.
    """
    # Note: If you want to add a new metric, please also update the parser argument --metric in parsing.py.
    metric_table = {
        'auc': roc_auc_score,
        'prc-auc': prc_auc,
        'rmse': rmse,
        'mae': mean_absolute_error,
        'r2': r2_score,
        'accuracy': accuracy,
        'recall': recall,
        'sensitivity': sensitivity,
        'specificity': specificity,
    }

    if metric in metric_table:
        return metric_table[metric]

    raise ValueError(f'Metric "{metric}" not supported.')
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def prc_auc(targets: List[int], preds: List[float]) -> float:
    """
    Area under the precision-recall curve.

    :param targets: A list of binary targets.
    :param preds: A list of prediction probabilities.
    :return: The computed prc-auc.
    """
    curve = precision_recall_curve(targets, preds)
    precisions, recalls = curve[0], curve[1]
    return auc(recalls, precisions)
|
grover/util/multi_gpu_wrapper.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Wrapper for multi-GPU training.
|
| 3 |
+
"""
|
| 4 |
+
# use Horovod for multi-GPU PyTorch training
|
| 5 |
+
try:
|
| 6 |
+
import horovod.torch as mgw
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
print('using Horovod for multi-GPU training')
|
| 10 |
+
except ImportError:
|
| 11 |
+
print('[WARNING] Horovod cannot be imported; multi-GPU training is unsupported')
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class MultiGpuWrapper(object):
    """Thin class-level facade over the ``mgw`` (Horovod) module for multi-GPU training.

    Every method forwards to the ``mgw`` function of the same name and raises
    ``NameError('module <mgw> not imported')`` when a NameError occurs during the call,
    which is what happens when the ``mgw`` import did not succeed.
    """

    def __init__(self):
        """Constructor function."""

    @classmethod
    def _forward_call(cls, attr, *args, **kwargs):
        """Call ``mgw.<attr>(*args, **kwargs)``, translating NameError into the common message."""
        try:
            return getattr(mgw, attr)(*args, **kwargs)
        except NameError:
            raise NameError('module <mgw> not imported')

    @classmethod
    def init(cls, *args):
        """Initialization."""
        return cls._forward_call('init', *args)

    @classmethod
    def size(cls, *args):
        """Get the number of workers at all nodes."""
        return cls._forward_call('size', *args)

    @classmethod
    def rank(cls, *args):
        """Get the rank of current worker at all nodes."""
        return cls._forward_call('rank', *args)

    @classmethod
    def local_size(cls, *args):
        """Get the number of workers at the current node."""
        return cls._forward_call('local_size', *args)

    @classmethod
    def local_rank(cls, *args):
        """Get the rank of current worker at the current node."""
        return cls._forward_call('local_rank', *args)

    @classmethod
    def DistributedOptimizer(cls, *args, **kwargs):
        """Get a distributed optimizer from the base optimizer."""
        return cls._forward_call('DistributedOptimizer', *args, **kwargs)

    @classmethod
    def broadcast_parameters(cls, *args, **kwargs):
        """Get an operation to broadcast all the parameters."""
        return cls._forward_call('broadcast_parameters', *args, **kwargs)

    @classmethod
    def broadcast_optimizer_state(cls, *args, **kwargs):
        """Get an operation to broadcast all the optimizer state."""
        return cls._forward_call('broadcast_optimizer_state', *args, **kwargs)

    @classmethod
    def broadcast(cls, *args, **kwargs):
        """Forward to ``mgw.broadcast``."""
        return cls._forward_call('broadcast', *args, **kwargs)

    @classmethod
    def barrier(cls):
        """Add a barrier to synchronize different processes"""
        # Kept as its own try block: ``mgw`` must be looked up before ``torch.tensor(0)`` is
        # evaluated so that a missing import is reported with the common message.
        try:
            return mgw.allreduce(torch.tensor(0), name='barrier')
        except NameError:
            raise NameError('module <mgw> not imported')
|
grover/util/nn_utils.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The utility function for model construction.
|
| 3 |
+
This implementation is adapted from
|
| 4 |
+
https://github.com/chemprop/chemprop/blob/master/chemprop/nn_utils.py
|
| 5 |
+
"""
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn as nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def param_count(model: nn.Module) -> int:
    """
    Counts the scalar entries of all trainable parameters.
    :param model: An nn.Module.
    :return: The number of trainable parameters (those with ``requires_grad`` set).
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def index_select_nd(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    """
    Gathers rows of source for every entry of index, keeping the shape of index.

    :param source: A tensor of shape (num_bonds, hidden_size) containing message features.
    :param index: A tensor of shape (num_atoms/num_bonds, max_num_bonds) containing the atom or bond
    indices to select from source.
    :return: A tensor of shape (num_atoms/num_bonds, max_num_bonds, hidden_size) containing the message
    features corresponding to the atoms/bonds specified in index.
    """
    # Output shape = shape of index followed by the trailing dims of source.
    out_shape = index.size() + source.size()[1:]

    # Select with the flattened index, then restore the index layout.
    flat_rows = torch.index_select(source, 0, index.view(-1))
    return flat_rows.view(out_shape)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_activation_function(activation: str) -> nn.Module:
    """
    Gets an activation function module given the name of the activation.

    :param activation: The name of the activation function. One of 'ReLU', 'LeakyReLU',
                       'PReLU', 'tanh', 'SELU', 'ELU' or 'Linear' (identity mapping).
    :return: A freshly constructed activation function module.
    :raises ValueError: If the name is not one of the supported activations.
    """
    factories = {
        'ReLU': nn.ReLU,
        'LeakyReLU': lambda: nn.LeakyReLU(0.1),
        'PReLU': nn.PReLU,
        'tanh': nn.Tanh,
        'SELU': nn.SELU,
        'ELU': nn.ELU,
        # nn.Identity instead of a bare lambda: the result is a real nn.Module, as the
        # return annotation promises, and it has no parameters or buffers.
        'Linear': nn.Identity,
    }
    if not isinstance(activation, str) or activation not in factories:
        raise ValueError(f'Activation "{activation}" not supported.')
    return factories[activation]()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def initialize_weights(model: nn.Module, distinct_init=False, model_idx=0):
    """
    Re-initializes every parameter of a model in place.

    One-dimensional parameters are set to zero. All other parameters are drawn from
    Xavier-normal, unless distinct_init is set, in which case the scheme is picked
    from a fixed list of four initializers using model_idx modulo 4.

    :param model: An nn.Module.
    :param distinct_init: Whether to choose the initializer based on model_idx.
    :param model_idx: Index used to pick the initializer when distinct_init is set.
    """
    schemes = (nn.init.kaiming_normal_, nn.init.kaiming_uniform_,
               nn.init.xavier_normal_, nn.init.xavier_uniform_)

    def _fill(weight):
        # Default path: every multi-dimensional parameter gets Xavier-normal.
        if not distinct_init:
            nn.init.xavier_normal_(weight)
            return
        scheme = schemes[model_idx % 4]
        if 'kaiming' in scheme.__name__:
            # Kaiming initializers are told which nonlinearity to assume.
            scheme(weight, nonlinearity='relu')
        else:
            scheme(weight)

    for weight in model.parameters():
        if weight.dim() == 1:
            nn.init.constant_(weight, 0)
        else:
            _fill(weight)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def select_neighbor_and_aggregate(feature, index):
    """
    The basic operation in message passing: gather neighbor features and sum them.

    Caution (carried over from the original authors): the underlying index_select_nd
    may cause reproducibility issues when training on CUDA.
    See: https://pytorch.org/docs/stable/notes/randomness.html

    :param feature: the candidate feature for aggregate. (n_nodes, hidden)
    :param index: the selected index (neighbor indexes).
    :return: the gathered features summed over dimension 1.
    """
    gathered = index_select_nd(feature, index)
    return torch.sum(gathered, dim=1)
|
grover/util/parsing.py
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The parsing functions for the argument input.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import pickle
|
| 6 |
+
from argparse import ArgumentParser, Namespace
|
| 7 |
+
from tempfile import TemporaryDirectory
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from grover.data.molfeaturegenerator import get_available_features_generators
|
| 12 |
+
from grover.util.utils import makedirs
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def add_common_args(parser: ArgumentParser):
    """
    Registers the options shared by the sub-commands that call this helper.

    :param parser: An ArgumentParser.
    """
    # (flag, keyword arguments) pairs, registered in exactly this order.
    shared_options = (
        # store_true combined with default=True: the value is True whether or not
        # the flag is given.
        ('--no_cache', dict(action='store_true', default=True,
                            help='Turn off caching mol2graph computation')),
        # The accepted indices are enumerated from the visible CUDA devices at
        # registration time.
        ('--gpu', dict(type=int, default=0,
                       choices=list(range(torch.cuda.device_count())),
                       help='Which GPU to use')),
        ('--no_cuda', dict(action='store_true', default=False,
                           help='Turn off cuda')),
        ('--batch_size', dict(type=int, default=32,
                              help='Batch size')),
    )
    for flag, options in shared_options:
        parser.add_argument(flag, **options)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def add_predict_args(parser: ArgumentParser):
    """
    Adds predict arguments to an ArgumentParser.

    :param parser: An ArgumentParser.
    """
    add_common_args(parser)

    parser.add_argument('--data_path', type=str,
                        help='Path to CSV file containing testing data for which predictions will be made')

    parser.add_argument('--output_path', type=str,
                        help='Path to CSV file where predictions will be saved')
    # The two help fragments are joined by implicit string concatenation, so the
    # separating space has to be part of the first literal.
    parser.add_argument('--checkpoint_dir', type=str,
                        help='Directory from which to load model checkpoints '
                             '(walks directory and ensembles all models that are found)')

    parser.add_argument('--features_generator', type=str, nargs='*',
                        choices=get_available_features_generators(),
                        help='Method of generating additional features')
    parser.add_argument('--features_path', type=str, nargs='*',
                        help='Path to features to use in FNN (instead of features_generator)')
    parser.add_argument('--no_features_scaling', action='store_true', default=False,
                        help='Turn off scaling of features')
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def add_fingerprint_args(parser):
    """
    Adds the fingerprint-generation arguments to an ArgumentParser.

    :param parser: An ArgumentParser.
    """
    add_common_args(parser)

    # Options specific to fingerprint generation, registered in this order.
    fingerprint_options = (
        ('--data_path', dict(type=str, help='Input csv file which contains SMILES')),
        ('--output_path', dict(type=str,
                               help='Path to npz file where predictions will be saved')),
        ('--features_path', dict(type=str, nargs='*',
                                 help='Path to features to use in FNN (instead of features_generator)')),
        ('--fingerprint_source', dict(type=str,
                                      choices=['atom', 'bond', 'both'], default='both',
                                      help='The source to generate the fingerprints.')),
        ('--checkpoint_path', dict(type=str, help='model path')),
    )
    for flag, options in fingerprint_options:
        parser.add_argument(flag, **options)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def add_finetune_args(parser: ArgumentParser):
    """
    Adds training (fine-tuning) arguments to an ArgumentParser.

    Multi-line help texts are built with implicit string concatenation; each
    fragment ends with a space so the rendered help reads as separate words
    and sentences.

    :param parser: An ArgumentParser.
    """

    # General arguments
    add_common_args(parser)
    parser.add_argument('--tensorboard', action='store_true', default=False, help='Add tensorboard logger')

    # Data arguments
    parser.add_argument('--data_path', type=str,
                        help='Path to data CSV file.')
    parser.add_argument('--use_compound_names', action='store_true', default=False,
                        help='Use when test data file contains compound names in addition to SMILES strings')
    parser.add_argument('--max_data_size', type=int,
                        help='Maximum number of data points to load')
    # NOTE: a '--test' option (skip training and only test the model) is
    # intentionally not registered; it was disabled due to some bugs.
    parser.add_argument('--features_only', action='store_true', default=False,
                        help='Use only the additional features in an FFN, no graph network')
    parser.add_argument('--features_generator', type=str, nargs='*',
                        choices=get_available_features_generators(),
                        help='Method of generating additional features.')
    parser.add_argument('--features_path', type=str, nargs='*',
                        help='Path to features to use in FNN (instead of features_generator).')
    parser.add_argument('--save_dir', type=str, default=None,
                        help='Directory where model checkpoints will be saved')
    parser.add_argument('--save_smiles_splits', action='store_true', default=False,
                        help='Save smiles for each train/val/test splits for prediction convenience later')
    parser.add_argument('--checkpoint_dir', type=str, default=None,
                        help='Directory from which to load model checkpoints '
                             '(walks directory and ensembles all models that are found)')
    parser.add_argument('--checkpoint_path', type=str, default=None,
                        help='Path to model checkpoint (.pt file)')

    # Data splitting.
    parser.add_argument('--dataset_type', type=str,
                        choices=['classification', 'regression'], default='classification',
                        help='Type of dataset, e.g. classification or regression. '
                             'This determines the loss function used during training.')
    parser.add_argument('--separate_val_path', type=str,
                        help='Path to separate val set, optional')
    parser.add_argument('--separate_val_features_path', type=str, nargs='*',
                        help='Path to file with features for separate val set')
    parser.add_argument('--separate_test_path', type=str,
                        help='Path to separate test set, optional')
    parser.add_argument('--separate_test_features_path', type=str, nargs='*',
                        help='Path to file with features for separate test set')
    parser.add_argument('--split_type', type=str, default='random',
                        choices=['random', 'scaffold_balanced', 'predetermined', 'crossval', 'index_predetermined'],
                        help='Method of splitting the data into train/val/test')
    parser.add_argument('--split_sizes', type=float, nargs=3, default=[0.8, 0.1, 0.1],
                        help='Split proportions for train/validation/test sets')
    parser.add_argument('--num_folds', type=int, default=1,
                        help='Number of folds when performing cross validation')
    parser.add_argument('--folds_file', type=str, default=None,
                        help='Optional file of fold labels')
    parser.add_argument('--val_fold_index', type=int, default=None,
                        help='Which fold to use as val for leave-one-out cross val')
    parser.add_argument('--test_fold_index', type=int, default=None,
                        help='Which fold to use as test for leave-one-out cross val')
    parser.add_argument('--crossval_index_dir', type=str,
                        help='Directory in which to find cross validation index files')
    parser.add_argument('--crossval_index_file', type=str,
                        help='Indices of files to use as train/val/test. '
                             'Overrides --num_folds and --seed.')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed to use when splitting data into train/val/test sets. '
                             'When `num_folds` > 1, the first fold uses this seed and all '
                             'subsequent folds add 1 to the seed.')

    # Metric
    parser.add_argument('--metric', type=str, default=None,
                        choices=['auc',
                                 'prc-auc',
                                 'rmse',
                                 'mae',
                                 'r2',
                                 'accuracy',
                                 'recall',
                                 'sensitivity',
                                 'specificity',
                                 'matthews_corrcoef'],
                        help='Metric to use during evaluation. '
                             'Note: Does NOT affect loss function used during training '
                             '(loss is determined by the `dataset_type` argument). '
                             'Note: Defaults to "auc" for classification and "rmse" for regression.')
    parser.add_argument('--show_individual_scores', action='store_true', default=False,
                        help='Show all scores for individual targets, not just average, at the end')

    # Training arguments
    parser.add_argument('--epochs', type=int, default=30,
                        help='Number of epochs to run')
    parser.add_argument('--warmup_epochs', type=float, default=2.0,
                        help='Number of epochs during which learning rate increases linearly from '
                             'init_lr to max_lr. Afterwards, learning rate decreases exponentially '
                             'from max_lr to final_lr.')
    parser.add_argument('--init_lr', type=float, default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('--max_lr', type=float, default=1e-3,
                        help='Maximum learning rate')
    parser.add_argument('--final_lr', type=float, default=1e-4,
                        help='Final learning rate')
    parser.add_argument('--no_features_scaling', action='store_true', default=False,
                        help='Turn off scaling of features')
    parser.add_argument('--early_stop_epoch', type=int, default=1000,
                        help='If val loss did not drop in '
                             'this epochs, stop running')

    # Model arguments
    parser.add_argument('--ensemble_size', type=int, default=1,
                        help='Number of models for ensemble prediction.')
    parser.add_argument('--dropout', type=float, default=0.0,
                        help='Dropout probability')
    parser.add_argument('--activation', type=str, default='ReLU',
                        choices=['ReLU', 'LeakyReLU', 'PReLU', 'tanh', 'SELU', 'ELU'],
                        help='Activation function')
    parser.add_argument('--ffn_hidden_size', type=int, default=None,
                        help='Hidden dim for higher-capacity FFN (defaults to hidden_size)')
    parser.add_argument('--ffn_num_layers', type=int, default=2,
                        help='Number of layers in FFN after MPN encoding')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='weight_decay')
    parser.add_argument('--select_by_loss', action='store_true', default=False,
                        help='Use validation loss as reference standard to select best model to predict')

    parser.add_argument("--embedding_output_type", default="atom", choices=["atom", "bond", "both"],
                        help="This the model parameters for pretrain model. The current finetuning task only use the "
                             "embeddings from atom branch. ")

    # Self-attentive readout.
    parser.add_argument('--self_attention', action='store_true', default=False,
                        help='Use self attention layer. '
                             'Otherwise use mean aggregation '
                             'layer.')
    parser.add_argument('--attn_hidden', type=int, default=4, nargs='?',
                        help='Self attention layer '
                             'hidden layer size.')
    parser.add_argument('--attn_out', type=int, default=128, nargs='?',
                        help='Self attention layer '
                             'output feature size.')

    parser.add_argument('--dist_coff', type=float, default=0.1,
                        help='The dist coefficient for output of two branches.')

    parser.add_argument('--bond_drop_rate', type=float, default=0, help='Drop out bond in molecular.')
    parser.add_argument('--distinct_init', action='store_true', default=False,
                        help='Using distinct weight init for model ensemble')
    parser.add_argument('--fine_tune_coff', type=float, default=1,
                        help='Enable distinct fine tune learning rate for fc and other layer')

    # For multi-gpu finetune.
    parser.add_argument('--enbl_multi_gpu', dest='enbl_multi_gpu',
                        action='store_true', default=False,
                        help='enable multi-GPU training')
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def _str2bool(value) -> bool:
    """Converts a command-line token such as 'true' / 'False' / '0' into a bool."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, exactly as bool('') was.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    # argparse turns a ValueError raised by a type callable into a usage error.
    raise ValueError(f'Cannot interpret "{value}" as a boolean.')


def add_pretrain_args(parser: ArgumentParser):
    """
    Adds pre-training arguments to an ArgumentParser.

    :param parser: An ArgumentParser.
    """
    # type=bool would treat every non-empty string (including "False") as True,
    # so the value is parsed explicitly.
    parser.add_argument('--cuda', type=_str2bool, default=True,
                        help='Enable gpu traning or not.')
    parser.add_argument('--enable_multi_gpu', dest='enable_multi_gpu',
                        action='store_true', default=False,
                        help='enable multi-GPU training')

    # Data arguments
    parser.add_argument('--data_path', type=str,
                        help='Path to data CSV file')
    parser.add_argument('--fg_label_path', type=str, nargs='*',
                        help='Path to the label of fg task.')
    parser.add_argument('--atom_vocab_path', type=str, help="Path to the vocabulary.")
    parser.add_argument('--bond_vocab_path', type=str,
                        help="Path to the bond vocabulary.")

    # Model arguments
    parser.add_argument('--embedding_output_type', type=str, default='both', nargs='?',
                        choices=("atom", "bond", "both"),
                        help="Type of output embeddings. Options: atom, bond, both")

    parser.add_argument('--save_dir', type=str, default=None,
                        help='Directory where model checkpoints will be saved')
    parser.add_argument('--save_interval', type=int, default=9999999999, help='The model saving interval.')
    parser.add_argument('--hidden_size', type=float, default=3,
                        help='Dimensionality of hidden layers. The actual dimension is hidden_size * 100.')
    parser.add_argument('--bias', action='store_true', default=False,
                        help='Whether to add bias to linear layers')
    parser.add_argument('--depth', type=int, default=3,
                        help='Number of message passing steps')
    parser.add_argument('--dropout', type=float, default=0.0,
                        help='Dropout probability')
    parser.add_argument('--activation', type=str, default='PReLU',
                        choices=['ReLU', 'LeakyReLU', 'PReLU', 'tanh', 'SELU', 'ELU'],
                        help='Activation function')
    parser.add_argument('--undirected', action='store_true', default=False,
                        help='Undirected edges (always sum the two relevant bond vectors)')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='weight_decay')
    parser.add_argument('--num_attn_head', type=int, default=4, help='The attention head in MTBlock.')
    parser.add_argument('--num_mt_block', type=int, default=1, help="The number of MTBlock.")
    parser.add_argument('--dist_coff', type=float, default=0.1,
                        help='The disagreement coefficient for '
                             'the atom and bond branch.')

    # Training arguments
    parser.add_argument("--backbone", default="gtrans", choices=["gtrans"])
    parser.add_argument('--epochs', type=int, default=30,
                        help='Number of epochs to run')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size')
    # Each fragment ends with a space so the concatenated help text keeps its word breaks.
    parser.add_argument('--warmup_epochs', type=float, default=2.0,
                        help='Number of epochs during which learning rate increases linearly from '
                             'init_lr to max_lr. Afterwards, learning rate decreases exponentially '
                             'from max_lr to final_lr.')
    parser.add_argument('--init_lr', type=float, default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('--max_lr', type=float, default=1e-3,
                        help='Maximum learning rate')
    parser.add_argument('--final_lr', type=float, default=1e-4,
                        help='Final learning rate')
    parser.add_argument('--bond_drop_rate', type=float, default=0, help='Drop out bond in molecular')
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def update_checkpoint_args(args: Namespace):
    """
    Walks the checkpoint directory to find all checkpoints, updating args.checkpoint_paths and args.ensemble_size.

    Nothing is done if args.checkpoint_paths is already set. Missing checkpoint_path and
    checkpoint_dir attributes are created as None. Without a checkpoint_dir,
    checkpoint_paths becomes [checkpoint_path], or None if that is unset too.

    :param args: Arguments.
    :raises ValueError: If both checkpoint_dir and checkpoint_path are given, or if no
                        '.pt' file is found under checkpoint_dir.
    """
    if getattr(args, 'checkpoint_paths', None) is not None:
        return

    if not hasattr(args, 'checkpoint_path'):
        args.checkpoint_path = None

    if not hasattr(args, 'checkpoint_dir'):
        args.checkpoint_dir = None

    if args.checkpoint_dir is not None and args.checkpoint_path is not None:
        raise ValueError('Only one of checkpoint_dir and checkpoint_path can be specified.')

    if args.checkpoint_dir is None:
        args.checkpoint_paths = [args.checkpoint_path] if args.checkpoint_path is not None else None
        return

    args.checkpoint_paths = [
        os.path.join(root, fname)
        for root, _, files in os.walk(args.checkpoint_dir)
        for fname in files
        if fname.endswith('.pt')
    ]

    # Report an empty directory first, so the eval consistency check below cannot
    # mask it with a bare AssertionError.
    if not args.checkpoint_paths:
        raise ValueError(f'Failed to find any model checkpoints in directory "{args.checkpoint_dir}"')

    # parser_name is only present on namespaces produced by the sub-command parser;
    # fall back to None for hand-built namespaces.
    if getattr(args, 'parser_name', None) == "eval":
        assert args.ensemble_size * args.num_folds == len(args.checkpoint_paths), \
            'ensemble_size * num_folds must equal the number of checkpoints found'

    args.ensemble_size = len(args.checkpoint_paths)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def modify_predict_args(args: Namespace):
    """
    Modifies and validates predicting args in place.

    Resolves the checkpoint list, replaces no_cuda by cuda, makes sure the output
    location exists and sets fingerprint to False.

    :param args: Arguments.
    """
    assert args.data_path
    assert args.output_path
    # Not every parser defines all three checkpoint attributes, so read them with
    # getattr; a missing attribute must count as "not given" instead of raising
    # AttributeError.
    checkpoint_sources = (getattr(args, 'checkpoint_dir', None),
                          getattr(args, 'checkpoint_path', None),
                          getattr(args, 'checkpoint_paths', None))
    assert any(source is not None for source in checkpoint_sources), \
        'One of checkpoint_dir, checkpoint_path or checkpoint_paths must be specified.'

    update_checkpoint_args(args)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    del args.no_cuda

    # Create directory for preds path
    makedirs(args.output_path, isfile=True)
    setattr(args, 'fingerprint', False)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def modify_fingerprint_args(args):
    """
    Modifies and validates fingerprint-generation args in place.

    Resolves the checkpoint list, replaces no_cuda by cuda, makes sure the output
    location exists and sets fingerprint to True.

    :param args: Arguments.
    """
    assert args.data_path
    assert args.output_path
    # checkpoint_paths is not a command-line option, so it may be absent from the
    # namespace; treat a missing attribute as "not given".
    has_single = getattr(args, 'checkpoint_path', None) is not None
    has_many = getattr(args, 'checkpoint_paths', None) is not None
    assert has_single or has_many, 'One of checkpoint_path or checkpoint_paths must be specified.'

    update_checkpoint_args(args)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    del args.no_cuda
    makedirs(args.output_path, isfile=True)
    setattr(args, 'fingerprint', True)
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def get_newest_train_args():
    """
    Builds the default fine-tuning arguments, for backward compatibility.

    The defaults come from parsing an empty command line with the fine-tune
    options, after which data_path is set to '' and the usual training-argument
    post-processing is applied.

    :return: A Namespace containing the newest training arguments
    """
    defaults_parser = ArgumentParser()
    add_finetune_args(defaults_parser)
    # An explicit empty list keeps sys.argv out of the picture.
    newest = defaults_parser.parse_args(args=[])
    newest.data_path = ''
    modify_train_args(newest)
    return newest
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def modify_train_args(args: Namespace):
    """
    Post-processes and validates the training arguments in place.

    Resolves the save directory, turns the no_cuda / no_features_scaling flags into
    cuda / features_scaling, picks a default metric, resolves the checkpoints and
    checks that the split-related options are consistent.

    :param args: Arguments.
    :raises ValueError: If the metric does not fit the dataset type.
    """
    global TEMP_DIR  # Prevents the temporary directory from being deleted upon function return

    assert args.data_path is not None
    assert args.dataset_type is not None

    if args.save_dir is None:
        # No directory requested: fall back to a temporary one.
        TEMP_DIR = TemporaryDirectory()
        args.save_dir = TEMP_DIR.name
    else:
        makedirs(args.save_dir)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    del args.no_cuda

    args.features_scaling = not args.no_features_scaling
    del args.no_features_scaling

    if args.metric is None:
        args.metric = 'auc' if args.dataset_type == 'classification' else 'rmse'

    classification_ok = (args.dataset_type == 'classification'
                         and args.metric in ['auc', 'prc-auc', 'accuracy'])
    regression_ok = (args.dataset_type == 'regression'
                     and args.metric in ['rmse', 'mae', 'r2'])
    if not classification_ok and not regression_ok:
        raise ValueError(f'Metric "{args.metric}" invalid for dataset type "{args.dataset_type}".')

    # Error-style metrics are the ones to be minimized.
    args.minimize_score = args.metric in ['rmse', 'mae']

    update_checkpoint_args(args)

    if args.features_only:
        assert args.features_generator or args.features_path

    args.use_input_features = args.features_generator or args.features_path

    if args.features_generator is not None and 'rdkit_2d_normalized' in args.features_generator:
        assert not args.features_scaling

    args.num_lrs = 1

    # Split-related options must be given together with the matching split type.
    assert (args.split_type == 'predetermined') == (args.folds_file is not None) == (args.test_fold_index is not None)
    is_crossval = args.split_type == 'crossval'
    assert is_crossval == (args.crossval_index_dir is not None)
    index_based = args.split_type in ['crossval', 'index_predetermined']
    assert index_based == (args.crossval_index_file is not None)
    if index_based:
        # NOTE(review): pickle.load executes arbitrary code from the file; only use
        # index files from a trusted source.
        with open(args.crossval_index_file, 'rb') as rf:
            args.crossval_index_sets = pickle.load(rf)
        args.num_folds = len(args.crossval_index_sets)
        args.seed = 0

    if args.bond_drop_rate > 0:
        args.no_cache = True

    setattr(args, 'fingerprint', False)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def modify_pretrain_args(args: Namespace):
    """
    Post-processes the pre-training arguments in place.

    Sets dense to False, fine_tune_coff to 1 and no_cache to True, then converts
    hidden_size to an int.

    :param args: Arguments.
    :return: None
    """
    fixed_settings = (('dense', False), ('fine_tune_coff', 1), ('no_cache', True))
    for name, value in fixed_settings:
        setattr(args, name, value)
    # hidden_size is parsed as a float on the command line; int() truncates it.
    args.hidden_size = int(args.hidden_size)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def parse_args() -> Namespace:
    """
    Parses the command line for the selected sub-command and post-processes the result.

    Each sub-command has its own option set and its own modify/validate step. If no
    sub-command is given, the parsed namespace is returned unmodified.

    :return: A Namespace containing the parsed, modified, and validated args.
    """
    parser = ArgumentParser()
    subparser = parser.add_subparsers(title="subcommands",
                                      dest="parser_name",
                                      help="Subcommands for fintune, prediction, and fingerprint.")

    # (name, help text, option registrar, post-processing step), in registration order.
    commands = (
        ('finetune', "Fine tune the pre-trained model.", add_finetune_args, modify_train_args),
        ('eval', "Evaluate the results of the pre-trained model.", add_finetune_args, modify_train_args),
        ('predict', "Predict results from fine tuned model.", add_predict_args, modify_predict_args),
        ('fingerprint', "Get the fingerprints of SMILES.", add_fingerprint_args, modify_fingerprint_args),
        ('pretrain', "Pretrain with unlabelled SMILES.", add_pretrain_args, modify_pretrain_args),
    )
    post_processors = {}
    for name, help_text, register_options, post_process in commands:
        register_options(subparser.add_parser(name, help=help_text))
        post_processors[name] = post_process

    args = parser.parse_args()

    post_process = post_processors.get(args.parser_name)
    if post_process is not None:
        post_process(args)

    return args
|
grover/util/scheduler.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The learning rate scheduler.
|
| 3 |
+
This implementation is adapted from
|
| 4 |
+
https://github.com/chemprop/chemprop/blob/master/chemprop/nn_utils.py
|
| 5 |
+
"""
|
| 6 |
+
from typing import List, Union
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class NoamLR(_LRScheduler):
    """
    Noam learning rate scheduler with piecewise linear increase and exponential decay.

    The learning rate increases linearly from init_lr to max_lr over the course of
    the first warmup_steps (where warmup_steps = warmup_epochs * steps_per_epoch).
    Then the learning rate decreases exponentially from max_lr to final_lr over the
    course of the remaining total_steps - warmup_steps (where total_steps =
    total_epochs * steps_per_epoch). This is roughly based on the learning rate
    schedule from Attention is All You Need, section 5.3 (https://arxiv.org/abs/1706.03762).
    """
    def __init__(self,
                 optimizer,
                 warmup_epochs: List[Union[float, int]],
                 total_epochs: List[int],
                 steps_per_epoch: int,
                 init_lr: List[float],
                 max_lr: List[float],
                 final_lr: List[float],
                 fine_tune_coff: float = 1.0,
                 fine_tune_param_idx: int = 0):
        """
        Initializes the learning rate scheduler.

        Each schedule value (warmup_epochs, total_epochs, init_lr, max_lr, final_lr) is
        replicated once per optimizer param group.
        NOTE(review): the annotations say ``List[...]`` but the replication below suggests
        scalars are expected -- confirm against the callers.

        :param optimizer: A PyTorch optimizer.
        :param warmup_epochs: The number of epochs during which to linearly increase the learning rate.
        :param total_epochs: The total number of epochs.
        :param steps_per_epoch: The number of steps (batches) per epoch.
        :param init_lr: The initial learning rate.
        :param max_lr: The maximum learning rate (achieved after warmup_epochs).
        :param final_lr: The final learning rate (achieved after total_epochs).
        :param fine_tune_coff: The fine tune coefficient for the target param group. The true learning rate for the
        target param group would be lr*fine_tune_coff.
        :param fine_tune_param_idx: The index of target param group. Default is index 0.
        """
        self.num_lrs = len(optimizer.param_groups)

        self.optimizer = optimizer
        self.warmup_epochs = np.array([warmup_epochs] * self.num_lrs)
        self.total_epochs = np.array([total_epochs] * self.num_lrs)
        self.steps_per_epoch = steps_per_epoch
        self.init_lr = np.array([init_lr] * self.num_lrs)
        self.max_lr = np.array([max_lr] * self.num_lrs)
        self.final_lr = np.array([final_lr] * self.num_lrs)
        # Must be a float array: an integer array would silently truncate a
        # fractional fine_tune_coff (e.g. 0.5 -> 0) and zero out the learning rate.
        self.lr_coff = np.ones(self.num_lrs, dtype=float)
        self.fine_tune_param_idx = fine_tune_param_idx
        self.lr_coff[self.fine_tune_param_idx] = fine_tune_coff

        self.current_step = 0
        self.lr = [init_lr] * self.num_lrs
        self.warmup_steps = (self.warmup_epochs * self.steps_per_epoch).astype(int)
        self.total_steps = self.total_epochs * self.steps_per_epoch
        # Guard against warmup_steps == 0: the increment is only ever used while
        # current_step <= warmup_steps, so clamping the divisor to 1 changes nothing
        # for a real warmup and avoids an inf/nan increment when there is none.
        self.linear_increment = (self.max_lr - self.init_lr) / np.maximum(self.warmup_steps, 1)

        self.exponential_gamma = (self.final_lr / self.max_lr) ** (1 / (self.total_steps - self.warmup_steps))
        # The base-class constructor is called last; the attributes above are
        # initialised first so they exist if it invokes step().
        super(NoamLR, self).__init__(optimizer)

    def get_lr(self) -> List[float]:
        """Gets a list of the current learning rates."""
        return list(self.lr)

    def step(self, current_step: int = None):
        """
        Updates the learning rate by taking a step.

        :param current_step: Optionally specify what step to set the learning rate to.
        If None, current_step = self.current_step + 1.
        """
        if current_step is not None:
            self.current_step = current_step
        else:
            self.current_step += 1
        for i in range(self.num_lrs):
            if self.current_step <= self.warmup_steps[i]:
                # Linear warmup phase.
                self.lr[i] = self.init_lr[i] + self.current_step * self.linear_increment[i]
            elif self.current_step <= self.total_steps[i]:
                # Exponential decay phase.
                self.lr[i] = self.max_lr[i] * (self.exponential_gamma[i] ** (self.current_step - self.warmup_steps[i]))
            else:  # theoretically this case should never be reached since training should stop at total_steps
                self.lr[i] = self.final_lr[i]
            # Scale by the per-group coefficient (1.0 for all but the fine-tune group).
            self.lr[i] *= self.lr_coff[i]
            self.optimizer.param_groups[i]['lr'] = self.lr[i]
|
grover/util/utils.py
ADDED
|
@@ -0,0 +1,797 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The general utility functions.
|
| 3 |
+
"""
|
| 4 |
+
import csv
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import pickle
|
| 8 |
+
import random
|
| 9 |
+
from argparse import Namespace
|
| 10 |
+
from collections import defaultdict
|
| 11 |
+
from logging import Logger
|
| 12 |
+
from typing import List, Set, Tuple, Union, Dict
|
| 13 |
+
|
| 14 |
+
import numpy as np
|
| 15 |
+
import torch
|
| 16 |
+
from rdkit import Chem
|
| 17 |
+
from rdkit.Chem.Scaffolds import MurckoScaffold
|
| 18 |
+
from torch import nn as nn
|
| 19 |
+
from tqdm import tqdm as core_tqdm
|
| 20 |
+
|
| 21 |
+
from grover.data import MoleculeDatapoint, MoleculeDataset, StandardScaler
|
| 22 |
+
from grover.model.models import GroverFpGeneration, GroverFinetuneTask
|
| 23 |
+
from grover.util.nn_utils import initialize_weights
|
| 24 |
+
from grover.util.scheduler import NoamLR
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Compatibility shim: re-create the ``np.float`` alias as the builtin ``float``.
# NumPy 1.24 removed this alias; code in this module (e.g. ``dtype=np.float``) still refers to it.
# NOTE(review): this mutates the global numpy module for every importer -- presumably intentional; confirm.
np.float = float
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_model_args():
    """
    List the names of the model-structure related arguments.

    :return: a new list of argument names, in a fixed order
    """
    names = """
        model_type ensemble_size input_layer hidden_size bias depth
        dropout activation undirected ffn_hidden_size ffn_num_layers
        atom_message weight_decay select_by_loss skip_epoch backbone
        embedding_output_type self_attention attn_hidden attn_out dense
        bond_drop_rate distinct_init aug_rate fine_tune_coff nencoders
        dist_coff no_attach_fea coord num_attn_head num_mt_block
    """
    # Whitespace-splitting keeps the declaration compact while preserving order.
    return names.split()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def save_features(path: str, features: List[np.ndarray]):
    """
    Write molecule features to a compressed .npz archive under the array name "features".

    :param path: Path to the .npz file to write.
    :param features: A list of 1D numpy arrays containing the features for molecules.
    """
    named_arrays = {"features": features}
    np.savez_compressed(path, **named_arrays)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def load_features(path: str) -> np.ndarray:
    """
    Loads features saved in a variety of formats.

    Supported formats:
    - .npz compressed (assumes features are saved with name "features")

    All formats assume that the SMILES strings loaded elsewhere in the code are in the same
    order as the features loaded here.

    :param path: Path to a file containing features.
    :return: A 2D numpy array of size (num_molecules, features_size) containing the features.
    :raises ValueError: If the file extension is not supported.
    """
    extension = os.path.splitext(path)[1]

    if extension != '.npz':
        raise ValueError(f'Features path extension {extension} not supported.')

    # np.load returns a lazily-read archive that holds the file open; use it as a
    # context manager so the handle is released once the array has been read.
    with np.load(path) as archive:
        features = archive['features']

    return features
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class tqdm(core_tqdm):
    """Thin wrapper over the imported tqdm class that sets ``ascii=True`` unless the caller passes ``ascii``."""

    def __init__(self, *args, **kwargs):
        # Respect an explicit ``ascii`` argument; only fill in the default.
        if "ascii" not in kwargs:
            kwargs["ascii"] = True
        super().__init__(*args, **kwargs)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def get_task_names(path: str, use_compound_names: bool = False) -> List[str]:
    """
    Read the task names from the header of a data CSV file.

    The leading header column(s) are skipped: one column normally, two when the
    file also carries compound names.

    :param path: Path to a CSV file.
    :param use_compound_names: Whether file has compound names in addition to smiles strings.
    :return: A list of task names.
    """
    header = get_header(path)
    first_task_column = 1
    if use_compound_names:
        first_task_column = 2
    return header[first_task_column:]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_header(path: str) -> List[str]:
    """
    Read the first row of a data CSV file.

    :param path: Path to a CSV file.
    :return: A list of strings containing the strings in the comma-separated header.
    """
    with open(path) as csv_file:
        rows = csv.reader(csv_file)
        first_row = next(rows)
    return first_row
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def get_num_tasks(path: str) -> int:
    """
    Count the tasks in a data CSV file: every header column except the first.

    :param path: Path to a CSV file.
    :return: The number of tasks.
    """
    columns = get_header(path)
    return len(columns) - 1
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def filter_invalid_smiles(data: MoleculeDataset) -> MoleculeDataset:
    """
    Filters out invalid SMILES.

    A datapoint is dropped when its SMILES string is empty, cannot be parsed by RDKit,
    or parses to a molecule with no heavy atoms. Each dropped datapoint is reported on stdout.

    :param data: A MoleculeDataset.
    :return: A MoleculeDataset with only valid molecules.
    """
    datapoint_list = []
    for idx, datapoint in enumerate(data):
        if datapoint.smiles == '':
            print(f'invalid smiles {idx}: {datapoint.smiles}')
            continue
        mol = Chem.MolFromSmiles(datapoint.smiles)
        # MolFromSmiles returns None (rather than raising) when parsing fails;
        # without this check the next line would crash instead of filtering.
        if mol is None:
            print(f'invalid smiles {idx}: {datapoint.smiles}')
            continue
        if mol.GetNumHeavyAtoms() == 0:
            print(f'invalid heavy {idx}')
            continue
        datapoint_list.append(datapoint)
    return MoleculeDataset(datapoint_list)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def get_data(path: str,
             skip_invalid_smiles: bool = True,
             args: Namespace = None,
             features_path: List[str] = None,
             max_data_size: int = None,
             use_compound_names: bool = None,
             logger: Logger = None) -> MoleculeDataset:
    """
    Gets smiles string and target values (and optionally compound names if provided) from a CSV file.

    Side effect: when args is provided, args.features_dim is set to the width of the loaded
    features (0 when no features are loaded).

    :param path: Path to a CSV file.
    :param skip_invalid_smiles: Whether to skip and filter out invalid smiles.
    :param args: Arguments.
    :param features_path: A list of paths to files containing features. If provided, it is used
    in place of args.features_path.
    :param max_data_size: The maximum number of data points to load.
    :param use_compound_names: Whether file has compound names in addition to smiles strings.
    :param logger: Logger.
    :return: A MoleculeDataset containing smiles strings and target values along
    with other info such as additional features and compound names when desired.
    """
    debug = logger.debug if logger is not None else print

    if args is not None:
        # Prefer explicit function arguments but default to args if not provided
        features_path = features_path if features_path is not None else args.features_path
        max_data_size = max_data_size if max_data_size is not None else args.max_data_size
        use_compound_names = use_compound_names if use_compound_names is not None else args.use_compound_names
    elif use_compound_names is None:
        # Only fall back to False when the caller did not state a preference.
        use_compound_names = False

    max_data_size = max_data_size or float('inf')

    # Load features
    if features_path is not None:
        # each loaded array is num_data x num_features; stack them column-wise
        features_data = np.concatenate([load_features(feat_path) for feat_path in features_path], axis=1)
        features_dim = len(features_data[0])
    else:
        features_data = None
        features_dim = 0
    # args is optional, so only record the feature width when there is somewhere to put it.
    if args is not None:
        args.features_dim = features_dim

    # Load data
    with open(path) as f:
        reader = csv.reader(f)
        next(reader)  # skip header

        lines = []
        for line in reader:
            lines.append(line)

            if len(lines) >= max_data_size:
                break

        data = MoleculeDataset([
            MoleculeDatapoint(
                line=line,
                args=args,
                features=features_data[i] if features_data is not None else None,
                use_compound_names=use_compound_names
            ) for i, line in enumerate(lines)
        ])

    # Filter out invalid SMILES
    if skip_invalid_smiles:
        original_data_len = len(data)
        data = filter_invalid_smiles(data)

        if len(data) < original_data_len:
            debug(f'Warning: {original_data_len - len(data)} SMILES are invalid.')

    return data
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def get_data_from_smiles(smiles: List[str], skip_invalid_smiles: bool = True, logger: Logger = None,
                         args: Namespace = None) -> MoleculeDataset:
    """
    Build a MoleculeDataset from a list of SMILES strings.

    :param smiles: A list of SMILES strings.
    :param skip_invalid_smiles: Whether to skip and filter out invalid smiles.
    :param logger: Logger; when None, messages are printed instead.
    :param args: Arguments forwarded to each MoleculeDatapoint.
    :return: A MoleculeDataset built from the provided SMILES (minus invalid ones when filtering is on).
    """
    debug = print if logger is None else logger.debug

    datapoints = []
    for smile in smiles:
        datapoints.append(MoleculeDatapoint(line=[smile], args=args))
    data = MoleculeDataset(datapoints)

    if not skip_invalid_smiles:
        return data

    # Drop invalid SMILES and report how many were removed.
    original_data_len = len(data)
    data = filter_invalid_smiles(data)
    if len(data) < original_data_len:
        debug(f'Warning: {original_data_len - len(data)} SMILES are invalid.')

    return data
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def split_data(data: MoleculeDataset,
               split_type: str = 'random',
               sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1),
               seed: int = 0,
               args: Namespace = None,
               logger: Logger = None) -> Tuple[MoleculeDataset,
                                               MoleculeDataset,
                                               MoleculeDataset]:
    """
    Splits data into training, validation, and test splits.

    :param data: A MoleculeDataset.
    :param split_type: Split type. One of 'crossval', 'index_predetermined', 'predetermined',
    'scaffold_balanced' or 'random'.
    :param sizes: A length-3 tuple with the proportions of data in the
    train, validation, and test sets.
    :param seed: The random seed to use before shuffling data.
    :param args: Namespace of arguments. Required for the 'crossval' and
    'index_predetermined' split types.
    :param logger: A logger.
    :return: A tuple containing the train, validation, and test splits of the data.
    :raises ValueError: If split_type is not supported, or needs args and none was given.
    """
    # Compare with a tolerance: e.g. sum((0.7, 0.2, 0.1)) is 0.9999999999999999, not 1.
    assert len(sizes) == 3 and np.isclose(sum(sizes), 1)

    if args is not None:
        folds_file, val_fold_index, test_fold_index = \
            args.folds_file, args.val_fold_index, args.test_fold_index
    else:
        folds_file = val_fold_index = test_fold_index = None

    if split_type in ('crossval', 'index_predetermined') and args is None:
        raise ValueError(f'split_type "{split_type}" requires args.')

    if split_type == 'crossval':
        index_set = args.crossval_index_sets[args.seed]
        data_split = []
        for split in range(3):
            split_indices = []
            for index in index_set[split]:
                # NOTE: pickle files are trusted input here; never point this at untrusted data.
                with open(os.path.join(args.crossval_index_dir, f'{index}.pkl'), 'rb') as rf:
                    split_indices.extend(pickle.load(rf))
            data_split.append([data[i] for i in split_indices])
        train, val, test = tuple(data_split)
        return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)

    elif split_type == 'index_predetermined':
        split_indices = args.crossval_index_sets[args.seed]
        assert len(split_indices) == 3
        data_split = []
        for split in range(3):
            data_split.append([data[i] for i in split_indices[split]])
        train, val, test = tuple(data_split)
        return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)

    elif split_type == 'predetermined':
        # Use an explicit None check: fold index 0 is a valid validation fold.
        if val_fold_index is None:
            assert sizes[2] == 0  # test set is created separately so use all of the other data for train and val
        assert folds_file is not None
        assert test_fold_index is not None

        # NOTE: pickle files are trusted input here; never point this at untrusted data.
        try:
            with open(folds_file, 'rb') as f:
                all_fold_indices = pickle.load(f)
        except UnicodeDecodeError:
            with open(folds_file, 'rb') as f:
                all_fold_indices = pickle.load(f, encoding='latin1')  # in case we're loading indices from python2

        log_scaffold_stats(data, all_fold_indices, logger=logger)

        folds = [[data[i] for i in fold_indices] for fold_indices in all_fold_indices]

        test = folds[test_fold_index]
        if val_fold_index is not None:
            val = folds[val_fold_index]

        # Everything that is neither the test fold nor the validation fold.
        train_val = []
        for i, fold in enumerate(folds):
            if i != test_fold_index and (val_fold_index is None or i != val_fold_index):
                train_val.extend(fold)

        if val_fold_index is not None:
            train = train_val
        else:
            # No dedicated validation fold: carve one out of the remaining data.
            random.seed(seed)
            random.shuffle(train_val)
            train_size = int(sizes[0] * len(train_val))
            train = train_val[:train_size]
            val = train_val[train_size:]

        return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)

    elif split_type == 'scaffold_balanced':
        return scaffold_split(data, sizes=sizes, balanced=True, seed=seed, logger=logger)

    elif split_type == 'random':
        data.shuffle(seed=seed)

        train_size = int(sizes[0] * len(data))
        train_val_size = int((sizes[0] + sizes[1]) * len(data))

        train = data[:train_size]
        val = data[train_size:train_val_size]
        test = data[train_val_size:]

        return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)

    else:
        raise ValueError(f'split_type "{split_type}" not supported.')
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def get_class_sizes(data: MoleculeDataset) -> List[List[float]]:
    """
    Compute, per task, the fraction of zero and one labels in a binary classification dataset.

    :param data: A classification dataset
    :return: A list with one ``[fraction_of_zeros, fraction_of_ones]`` pair per task.
    A task with no non-None targets yields ``[nan, nan]`` and prints a warning.
    """
    targets = data.targets()

    # Collect the non-None targets column-wise, one bucket per task.
    per_task = [[] for _ in range(data.num_tasks())]
    for row in targets:
        for task_num, value in enumerate(row):
            if value is None:
                continue
            per_task[task_num].append(value)

    proportions = []
    for task_values in per_task:
        # Make sure we're dealing with a binary classification task
        assert set(np.unique(task_values)) <= {0, 1}

        try:
            positive = np.count_nonzero(task_values) / len(task_values)
        except ZeroDivisionError:
            positive = float('nan')
            print('Warning: class has no targets')
        proportions.append([1 - positive, positive])

    return proportions
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def generate_scaffold(mol: Union[str, Chem.Mol], include_chirality: bool = False) -> str:
    """
    Compute the Bemis-Murcko scaffold for a SMILES string.

    :param mol: A smiles string or an RDKit molecule.
    :param include_chirality: Whether to include chirality.
    :return: The scaffold as a SMILES string.
    """
    # isinstance (rather than an exact type comparison) also accepts str subclasses
    # such as numpy.str_, which appear when SMILES come out of a numpy array.
    if isinstance(mol, str):
        mol = Chem.MolFromSmiles(mol)
    scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=include_chirality)

    return scaffold
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def scaffold_to_smiles(mols: Union[List[str], List[Chem.Mol]],
                       use_indices: bool = False) -> Dict[str, Union[Set[str], Set[int]]]:
    """
    Group molecules by their scaffold.

    :param mols: A list of smiles strings or RDKit molecules.
    :param use_indices: Whether to store each molecule's position in ``mols`` rather than the
    molecule itself. This is necessary if there are duplicate smiles.
    :return: A dictionary mapping each unique scaffold to the set of molecules (or indices) that have it.
    """
    groups = defaultdict(set)
    for position, molecule in tqdm(enumerate(mols), total=len(mols)):
        key = generate_scaffold(molecule)
        member = position if use_indices else molecule
        groups[key].add(member)

    return groups
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def scaffold_split(data: MoleculeDataset,
                   sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1),
                   balanced: bool = False,
                   seed: int = 0,
                   logger: logging.Logger = None) -> Tuple[MoleculeDataset,
                                                           MoleculeDataset,
                                                           MoleculeDataset]:
    """
    Split a dataset by scaffold so that all molecules sharing a scaffold end up in the same split.

    Scaffold groups are assigned greedily: a group goes to train if it still fits, otherwise
    to validation if it fits there, otherwise to test.

    :param data: A MoleculeDataset.
    :param sizes: A length-3 tuple with the proportions of data in the
    train, validation, and test sets.
    :param balanced: Try to balance sizes of scaffolds in each set, rather than just putting smallest in test set.
    :param seed: Seed for shuffling when doing balanced splitting.
    :param logger: A logger.
    :return: A tuple containing the train, validation, and test splits of the data.
    """
    # Compare with a tolerance: e.g. sum((0.7, 0.2, 0.1)) is 0.9999999999999999, not 1.
    assert np.isclose(sum(sizes), 1)

    # Split
    train_size, val_size, test_size = sizes[0] * len(data), sizes[1] * len(data), sizes[2] * len(data)
    train, val, test = [], [], []
    train_scaffold_count, val_scaffold_count, test_scaffold_count = 0, 0, 0

    # Map from scaffold to index in the data
    scaffold_to_indices = scaffold_to_smiles(data.smiles(), use_indices=True)

    if balanced:  # Put stuff that's bigger than half the val/test size into train, rest just order randomly
        index_sets = list(scaffold_to_indices.values())
        big_index_sets = []
        small_index_sets = []
        for index_set in index_sets:
            if len(index_set) > val_size / 2 or len(index_set) > test_size / 2:
                big_index_sets.append(index_set)
            else:
                small_index_sets.append(index_set)
        random.seed(seed)
        random.shuffle(big_index_sets)
        random.shuffle(small_index_sets)
        # Big groups come first so the greedy pass below tries them against train first.
        index_sets = big_index_sets + small_index_sets
    else:  # Sort from largest to smallest scaffold sets
        index_sets = sorted(scaffold_to_indices.values(), key=len, reverse=True)

    for index_set in index_sets:
        if len(train) + len(index_set) <= train_size:
            train += index_set
            train_scaffold_count += 1
        elif len(val) + len(index_set) <= val_size:
            val += index_set
            val_scaffold_count += 1
        else:
            test += index_set
            test_scaffold_count += 1

    if logger is not None:
        logger.debug(f'Total scaffolds = {len(scaffold_to_indices):,} | '
                     f'train scaffolds = {train_scaffold_count:,} | '
                     f'val scaffolds = {val_scaffold_count:,} | '
                     f'test scaffolds = {test_scaffold_count:,}')

    log_scaffold_stats(data, index_sets, logger=logger)

    # Map from indices to data
    train = [data[i] for i in train]
    val = [data[i] for i in val]
    test = [data[i] for i in test]

    return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def log_scaffold_stats(data: MoleculeDataset,
                       index_sets: List[Set[int]],
                       num_scaffolds: int = 10,
                       num_labels: int = 20,
                       logger: logging.Logger = None) -> List[Tuple[List[float], List[int]]]:
    """
    Logs and returns statistics about counts and average target values in molecular scaffolds.

    :param data: A MoleculeDataset.
    :param index_sets: A list of sets of indices representing splits of the data.
    :param num_scaffolds: The number of scaffolds about which to display statistics.
    :param num_labels: The number of labels about which to display statistics.
    :param logger: A Logger.
    :return: A list of tuples, one per index set for the first num_scaffolds index sets (in the
    order given). Each tuple holds the average target values (ignoring missing targets) and the
    number of non-missing targets, both limited to the first num_labels labels.
    """
    target_avgs = []
    counts = []
    for index_set in index_sets:
        data_set = [data[i] for i in index_set]
        targets = [d.targets for d in data_set]
        # Builtin float instead of the removed np.float alias; None targets become NaN.
        targets = np.array(targets, dtype=float)
        target_avgs.append(np.nanmean(targets, axis=0))
        counts.append(np.count_nonzero(~np.isnan(targets), axis=0))
    stats = [(target_avgs[i][:num_labels], counts[i][:num_labels]) for i in range(min(num_scaffolds, len(target_avgs)))]

    if logger is not None:
        logger.debug('Label averages per scaffold, in decreasing order of scaffold frequency,'
                     f'capped at {num_scaffolds} scaffolds and {num_labels} labels: {stats}')

    return stats
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def makedirs(path: str, isfile: bool = False):
    """
    Ensure a directory exists for the given directory or file path.

    For a directory path the directory itself is created. For a file path
    (i.e. isfile == True) the file's parent directory is created instead.
    Nothing is done when the resulting directory name is empty.

    :param path: Path to a directory or file.
    :param isfile: Whether the provided path is a file (True) or a directory (False).
    """
    directory = os.path.dirname(path) if isfile else path
    if directory == '':
        return
    os.makedirs(directory, exist_ok=True)
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def load_args(path: str) -> Namespace:
    """
    Reads back the arguments stored inside a saved model checkpoint.

    :param path: Path where model checkpoint is saved.
    :return: The object stored under the checkpoint's ``'args'`` key, i.e. the
             arguments Namespace that the model was trained with.
    """
    def _keep_storage(storage, loc):
        # Hand every storage back unchanged instead of remapping it to another device.
        return storage

    checkpoint = torch.load(path, map_location=_keep_storage)
    return checkpoint['args']
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def get_ffn_layer_id(model: nn.Module):
    """
    Get the ids of the ffn layer parameters for GroverFinetune Task. (Adhoc!)

    A parameter is selected when its qualified name contains "ffn" but does not
    contain "grover".

    :param model: The model whose parameters are inspected.
    :return: A list with ``id(param)`` for every selected parameter object, suitable
             for membership tests against ``id(p)`` of ``model.parameters()``.
    """
    # Walk named_parameters() instead of state_dict(): iterating a state dict yields
    # its key *strings*, so taking id() of those can never match id() of a parameter.
    return [id(param) for name, param in model.named_parameters()
            if "grover" not in name and "ffn" in name]
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
def build_optimizer(model: nn.Module, args: Namespace):
    """
    Builds an Optimizer.

    A GroverFinetuneTask gets an Adam optimizer with two parameter groups:
    group 0 holds the base parameters with learning rate
    ``args.init_lr * args.fine_tune_coff``; group 1 holds the parameters whose ids
    are returned by ``get_ffn_layer_id`` with learning rate ``args.init_lr``.
    When ``args.fine_tune_coff`` is 0 the base parameters are additionally frozen
    (``requires_grad = False``). Any other model gets a plain Adam optimizer over
    all of its parameters.

    :param model: The model to optimize.
    :param args: Arguments. Reads ``init_lr``, ``weight_decay`` and ``fine_tune_coff``.
    :return: An initialized Optimizer.
    """
    # Only adjust the learning rate for the GroverFinetuneTask.
    if not isinstance(model, GroverFinetuneTask):
        # if not, init adam optimizer normally.
        return torch.optim.Adam(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)

    # Materialise both groups eagerly under distinct names. The previous lazy
    # filter()/lambda pair closed over a variable that was rebound to the filter
    # object itself before either filter ran, so the id list was never consulted:
    # every parameter landed in the base group and fine_tune_coff == 0 froze the
    # whole model.
    ffn_param_ids = set(get_ffn_layer_id(model))
    base_params = []
    ffn_params = []
    for param in model.parameters():
        if id(param) in ffn_param_ids:
            ffn_params.append(param)
        else:
            base_params.append(param)

    if args.fine_tune_coff == 0:
        for param in base_params:
            param.requires_grad = False

    # Group order matters to callers that address groups by position: base first, ffn second.
    optimizer = torch.optim.Adam([
        {'params': base_params, 'lr': args.init_lr * args.fine_tune_coff},
        {'params': ffn_params, 'lr': args.init_lr}
    ], lr=args.init_lr, weight_decay=args.weight_decay)

    return optimizer
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def build_lr_scheduler(optimizer, args: Namespace, total_epochs: List[int] = None):
    """
    Creates the NoamLR learning rate scheduler for an optimizer.

    :param optimizer: The Optimizer whose learning rate will be scheduled.
    :param args: Arguments. Supplies ``warmup_epochs``, ``epochs``, ``train_data_size``,
                 ``batch_size``, ``init_lr``, ``max_lr``, ``final_lr`` and ``fine_tune_coff``.
    :param total_epochs: Not used by this function; the epoch count is taken
                         from ``args.epochs``.
    :return: An initialized learning rate scheduler.
    """
    # Gather the scheduler settings first, then construct the scheduler in one call.
    scheduler_settings = {
        'optimizer': optimizer,
        'warmup_epochs': args.warmup_epochs,
        'total_epochs': args.epochs,
        # Number of whole batches per epoch (integer division drops a partial batch).
        'steps_per_epoch': args.train_data_size // args.batch_size,
        'init_lr': args.init_lr,
        'max_lr': args.max_lr,
        'final_lr': args.final_lr,
        'fine_tune_coff': args.fine_tune_coff,
    }
    return NoamLR(**scheduler_settings)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
def create_logger(name: str, save_dir: str = None, quiet: bool = False) -> logging.Logger:
    """
    Creates a logger with a stream handler and, optionally, two file handlers.

    The stream handler prints to the screen depending on the value of `quiet`.
    When `save_dir` is given, one file handler (verbose.log) saves all logs, the
    other (quiet.log) only saves important info.

    Calling this function again with the same `name` reconfigures the existing
    logger: handlers attached by an earlier call are removed and closed first, so
    messages are not emitted several times and old log files are not kept open.

    :param name: The name of the logger.
    :param save_dir: The directory in which to save the logs. No log files are
                     written when it is None.
    :param quiet: Whether the stream handler should be quiet (i.e. print only important info).
    :return: The logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    # logging.getLogger returns the same object for the same name, so without this
    # reset every repeated call would stack another set of handlers on it.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    # Set logger depending on desired verbosity
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO if quiet else logging.DEBUG)
    logger.addHandler(ch)

    if save_dir is not None:
        makedirs(save_dir)
        fh_v = logging.FileHandler(os.path.join(save_dir, 'verbose.log'))
        fh_v.setLevel(logging.DEBUG)
        fh_q = logging.FileHandler(os.path.join(save_dir, 'quiet.log'))
        fh_q.setLevel(logging.INFO)

        logger.addHandler(fh_v)
        logger.addHandler(fh_q)

    return logger
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
def load_checkpoint(path: str,
                    current_args: Namespace = None,
                    cuda: bool = None,
                    logger: logging.Logger = None):
    """
    Restores a model from a saved checkpoint.

    The model is built with ``build_model`` and then receives every checkpoint
    parameter whose name exists in the new model with an identical shape; all other
    checkpoint parameters are reported and skipped.

    :param path: Path where checkpoint is saved.
    :param current_args: The current arguments. If provided, the model-related
                         arguments stored in the checkpoint are copied onto it and it
                         is used to build the model; otherwise the checkpoint's own
                         arguments are used.
    :param cuda: Whether to move model to cuda.
    :param logger: A logger. Messages are printed when it is None.
    :return: The loaded model.
    """
    debug = print if logger is None else logger.debug

    # Load model and args
    state = torch.load(path, map_location=lambda storage, loc: storage)
    saved_args = state['args']
    loaded_state_dict = state['state_dict']
    model_related_args = get_model_args()

    if current_args is None:
        current_args = saved_args
    else:
        # Architecture settings must match the checkpoint, so they win over the caller's values.
        for key, value in vars(saved_args).items():
            if key in model_related_args:
                setattr(current_args, key, value)

    # Build model
    model = build_model(current_args)
    model_state_dict = model.state_dict()

    # Skip missing parameters and parameters of mismatched size
    pretrained_state_dict = {}
    for param_name, loaded_param in loaded_state_dict.items():
        if param_name not in model_state_dict:
            debug(f'Pretrained parameter "{param_name}" cannot be found in model parameters.')
            continue

        model_param = model_state_dict[param_name]
        if model_param.shape != loaded_param.shape:
            debug(f'Pretrained parameter "{param_name}" '
                  f'of shape {loaded_param.shape} does not match corresponding '
                  f'model parameter of shape {model_param.shape}.')
            continue

        debug(f'Loading pretrained parameter "{param_name}".')
        pretrained_state_dict[param_name] = loaded_param

    # Load pretrained weights
    model_state_dict.update(pretrained_state_dict)
    model.load_state_dict(model_state_dict)

    if cuda:
        debug('Moving model to cuda')
        model = model.cuda()

    return model
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
def get_loss_func(args: Namespace, model=None):
    """
    Selects the loss function for a given dataset type.

    If ``model`` has a ``get_loss_func`` attribute, the choice is delegated to
    ``model.get_loss_func(args)``. Otherwise the loss is picked from
    ``args.dataset_type`` and built with ``reduction='none'``.

    :param args: Namespace containing the dataset type ("classification" or "regression").
    :param model: Optional model that may provide its own ``get_loss_func``.
    :return: A PyTorch loss function.
    :raises ValueError: If the dataset type is neither "classification" nor "regression".
    """
    if hasattr(model, "get_loss_func"):
        return model.get_loss_func(args)

    supported_losses = (
        ('classification', nn.BCEWithLogitsLoss),
        ('regression', nn.MSELoss),
    )
    for dataset_type, loss_cls in supported_losses:
        if args.dataset_type == dataset_type:
            return loss_cls(reduction='none')

    raise ValueError(f'Dataset type "{args.dataset_type}" not supported.')
|
| 730 |
+
|
| 731 |
+
|
| 732 |
+
def load_scalars(path: str):
    """
    Rebuilds the scalers stored inside a saved model checkpoint.

    :param path: Path where model checkpoint is saved.
    :return: A tuple with the data scaler and the features scaler. An entry is None
             when the checkpoint stores None for that scaler.
    """
    state = torch.load(path, map_location=lambda storage, loc: storage)

    data_stats = state['data_scaler']
    if data_stats is None:
        scaler = None
    else:
        scaler = StandardScaler(data_stats['means'], data_stats['stds'])

    feature_stats = state['features_scaler']
    if feature_stats is None:
        features_scaler = None
    else:
        features_scaler = StandardScaler(feature_stats['means'],
                                         feature_stats['stds'],
                                         replace_nan_token=0)

    return scaler, features_scaler
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def save_checkpoint(path: str,
                    model,
                    scaler,
                    features_scaler,
                    args: Namespace = None):
    """
    Writes a model checkpoint to disk.

    The checkpoint is a dict with the keys ``'args'``, ``'state_dict'``,
    ``'data_scaler'`` and ``'features_scaler'``; each scaler is stored as a dict of
    its ``means`` and ``stds``, or as None when the scaler is None.

    :param path: Path where checkpoint will be saved.
    :param model: The model whose ``state_dict()`` is saved.
    :param scaler: A StandardScaler fitted on the data.
    :param features_scaler: A StandardScaler fitted on the features.
    :param args: Arguments namespace.
    """
    def _scaler_stats(fitted_scaler):
        # Keep only the fitted statistics of a scaler.
        if fitted_scaler is None:
            return None
        return {'means': fitted_scaler.means, 'stds': fitted_scaler.stds}

    checkpoint = {
        'args': args,
        'state_dict': model.state_dict(),
        'data_scaler': _scaler_stats(scaler),
        'features_scaler': _scaler_stats(features_scaler),
    }
    torch.save(checkpoint, path)
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
def build_model(args: Namespace, model_idx=0):
    """
    Constructs the model selected by the arguments and initializes its weights.

    As a side effect ``args.output_size`` is set to ``args.num_tasks`` when that
    attribute exists and to 1 otherwise.

    :param args: Arguments. ``args.parser_name == "fingerprint"`` selects
                 GroverFpGeneration; any other value selects GroverFinetuneTask.
    :param model_idx: Forwarded to ``initialize_weights``.
    :return: The constructed model, after ``initialize_weights`` has been applied to it.
    """
    args.output_size = args.num_tasks if hasattr(args, 'num_tasks') else 1

    # Everything except fingerprint generation is the finetune and evaluation case.
    model_cls = GroverFpGeneration if args.parser_name == "fingerprint" else GroverFinetuneTask
    model = model_cls(args)

    initialize_weights(model=model, model_idx=model_idx)
    return model
|
prepare_data.py
CHANGED
|
@@ -17,4 +17,5 @@ val_path = "./tox21/tox21_validation.csv"
|
|
| 17 |
train_path_clean = "./tox21/tox21_train_clean.csv"
|
| 18 |
val_path_clean = "./tox21/tox21_validation_clean.csv"
|
| 19 |
|
| 20 |
-
prepare_data(
|
|
|
|
|
|
| 17 |
train_path_clean = "./tox21/tox21_train_clean.csv"
|
| 18 |
val_path_clean = "./tox21/tox21_validation_clean.csv"
|
| 19 |
|
| 20 |
+
prepare_data(train_path, train_path_clean, "./tox21/valid_mask_train.npy")
|
| 21 |
+
prepare_data(val_path, val_path_clean, "./tox21/valid_mask_val.npy")
|
requirements.txt
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py @ file:///home/conda/feedstock_root/build_artifacts/absl-py_1751547525079/work
|
| 2 |
+
aiohappyeyeballs==2.6.1
|
| 3 |
+
aiohttp==3.13.2
|
| 4 |
+
aiosignal==1.4.0
|
| 5 |
+
anyio==4.12.0
|
| 6 |
+
async-timeout==5.0.1
|
| 7 |
+
attrs==25.4.0
|
| 8 |
+
Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1749229842835/work
|
| 9 |
+
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1754231422783/work/certifi
|
| 10 |
+
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1725571112467/work
|
| 11 |
+
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1754767332901/work
|
| 12 |
+
click==8.1.8
|
| 13 |
+
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work
|
| 14 |
+
datasets==4.4.1
|
| 15 |
+
descriptastorus==2.8.0
|
| 16 |
+
dill==0.4.0
|
| 17 |
+
exceptiongroup==1.3.1
|
| 18 |
+
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1755216263872/work
|
| 19 |
+
frozenlist==1.8.0
|
| 20 |
+
fsspec==2025.10.0
|
| 21 |
+
git-filter-repo @ file:///home/conda/feedstock_root/build_artifacts/git-filter-repo_1735551402582/work
|
| 22 |
+
gmpy2 @ file:///home/conda/feedstock_root/build_artifacts/gmpy2_1745509363867/work
|
| 23 |
+
grpcio @ file:///home/conda/feedstock_root/build_artifacts/grpc-split_1754634529307/work
|
| 24 |
+
h11==0.16.0
|
| 25 |
+
h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1738578511449/work
|
| 26 |
+
hf-xet==1.2.0
|
| 27 |
+
hpack @ file:///home/conda/feedstock_root/build_artifacts/hpack_1737618293087/work
|
| 28 |
+
httpcore==1.0.9
|
| 29 |
+
httpx==0.28.1
|
| 30 |
+
huggingface_hub==1.1.7
|
| 31 |
+
hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1737618333194/work
|
| 32 |
+
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1733211830134/work
|
| 33 |
+
importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_importlib-metadata_1747934053/work
|
| 34 |
+
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1741263328855/work
|
| 35 |
+
joblib @ file:///home/conda/feedstock_root/build_artifacts/joblib_1748019130050/work
|
| 36 |
+
Markdown @ file:///home/conda/feedstock_root/build_artifacts/markdown_1750360292101/work
|
| 37 |
+
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1733219680183/work
|
| 38 |
+
mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1733302684489/work
|
| 39 |
+
multidict==6.7.0
|
| 40 |
+
multiprocess==0.70.18
|
| 41 |
+
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1698504735452/work
|
| 42 |
+
numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1707225342954/work/dist/numpy-1.26.4-cp39-cp39-linux_x86_64.whl#sha256=c799942b5898f6e6c60264d1663a6469a475290e758c654aeeb78e2596463abd
|
| 43 |
+
packaging @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_packaging_1745345660/work
|
| 44 |
+
pandas @ file:///home/conda/feedstock_root/build_artifacts/pandas_1752081702369/work
|
| 45 |
+
pandas_flavor==0.7.0
|
| 46 |
+
pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1751482006338/work
|
| 47 |
+
propcache==0.4.1
|
| 48 |
+
protobuf @ file:///home/conda/feedstock_root/build_artifacts/protobuf_1751668301193/work/bazel-bin/python/dist/protobuf-6.31.1-cp39-abi3-linux_x86_64.whl#sha256=91a4a00a210b50fbca2de99b20633990d9f00a443829a9badc867ec313e0fecc
|
| 49 |
+
pyarrow==21.0.0
|
| 50 |
+
pycairo==1.28.0
|
| 51 |
+
pycparser @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycparser_1733195786/work
|
| 52 |
+
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1733217236728/work
|
| 53 |
+
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dateutil_1751104122/work
|
| 54 |
+
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1742920838005/work
|
| 55 |
+
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1737454647378/work
|
| 56 |
+
rdkit==2025.9.1
|
| 57 |
+
rdkit-pypi==2022.9.5
|
| 58 |
+
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1755614211359/work
|
| 59 |
+
scikit-learn @ file:///home/conda/feedstock_root/build_artifacts/scikit-learn_1736496755362/work/dist/scikit_learn-1.6.1-cp39-cp39-linux_x86_64.whl#sha256=e8f978e37bb47e04e1337a63f75697b723d6d25f58e477734555faed033884ba
|
| 60 |
+
scipy==1.10.1
|
| 61 |
+
shellingham==1.5.4
|
| 62 |
+
six @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_six_1753199211/work
|
| 63 |
+
sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1745946051654/work
|
| 64 |
+
tensorboard @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_tensorboard_1752825441/work/tensorboard-2.20.0-py3-none-any.whl#sha256=9dc9f978cb84c0723acf9a345d96c184f0293d18f166bb8d59ee098e6cfaaba6
|
| 65 |
+
tensorboard_data_server @ file:///home/conda/feedstock_root/build_artifacts/tensorboard-data-server_1728639721704/work/tensorboard_data_server-0.7.0-py3-none-manylinux2014_x86_64.whl#sha256=3b7dc7cf17b685028f955453a839cc9b2871818de53e7911eae158fe6b3a80cf
|
| 66 |
+
threadpoolctl @ file:///home/conda/feedstock_root/build_artifacts/threadpoolctl_1741878222898/work
|
| 67 |
+
torch==2.4.0
|
| 68 |
+
torchvision==0.19.0
|
| 69 |
+
tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1735661334605/work
|
| 70 |
+
triton==3.0.0
|
| 71 |
+
typer-slim==0.20.0
|
| 72 |
+
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1751643513/work
|
| 73 |
+
tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1742745135198/work
|
| 74 |
+
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1750271362675/work
|
| 75 |
+
Werkzeug @ file:///home/conda/feedstock_root/build_artifacts/werkzeug_1733160440960/work
|
| 76 |
+
xarray==2024.7.0
|
| 77 |
+
xxhash==3.6.0
|
| 78 |
+
yarl==1.22.0
|
| 79 |
+
zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1749421620841/work
|
| 80 |
+
zstandard==0.23.0
|
| 81 |
+
fastapi
|
| 82 |
+
uvicorn[standard]
|
scripts/__init__.py
ADDED
|
File without changes
|
scripts/build_vocab.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The vocabulary building scripts.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from grover.data.torchvocab import MolVocab
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def build():
    """
    Command line entry point: builds and saves the atom and the bond vocabulary.

    Parses the command line, then for each vocabulary type ('atom' and 'bond')
    creates a MolVocab from ``--data_path`` and saves it as
    ``[<dataset_name>_]<type>_vocab.pkl`` inside ``--vocab_save_folder``.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', default="../../dataset/grover_new_dataset/druglike_merged_refine2.csv", type=str)
    parser.add_argument('--vocab_save_folder', default="../../dataset/grover_new_dataset", type=str)
    # The two help literals are concatenated, so the first one needs its trailing space.
    parser.add_argument('--dataset_name', type=str, default=None,
                        help="Will be the first part of the vocab file name. If it is None, "
                             "the vocab files will be: atom_vocab.pkl and bond_vocab.pkl")
    parser.add_argument('--vocab_max_size', type=int, default=None)
    parser.add_argument('--vocab_min_freq', type=int, default=1)
    args = parser.parse_args()

    for vocab_type in ['atom', 'bond']:
        vocab_file = f"{vocab_type}_vocab.pkl"
        if args.dataset_name is not None:
            vocab_file = args.dataset_name + '_' + vocab_file
        vocab_save_path = os.path.join(args.vocab_save_folder, vocab_file)

        # os.makedirs('') raises, so only create a directory when the save path has
        # one (an empty --vocab_save_folder means the current working directory).
        vocab_save_dir = os.path.dirname(vocab_save_path)
        if vocab_save_dir:
            os.makedirs(vocab_save_dir, exist_ok=True)

        vocab = MolVocab(file_path=args.data_path,
                         max_size=args.vocab_max_size,
                         min_freq=args.vocab_min_freq,
                         num_workers=100,
                         vocab_type=vocab_type)
        print(f"{vocab_type} vocab size", len(vocab))
        vocab.save_vocab(vocab_save_path)


if __name__ == '__main__':
    build()
|
scripts/save_features.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Computes and saves molecular features for a dataset.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import shutil
|
| 6 |
+
import sys
|
| 7 |
+
from argparse import ArgumentParser, Namespace
|
| 8 |
+
from multiprocessing import Pool
|
| 9 |
+
from typing import List, Tuple
|
| 10 |
+
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
| 14 |
+
|
| 15 |
+
from grover.util.utils import get_data, makedirs, load_features, save_features
|
| 16 |
+
from grover.data.molfeaturegenerator import get_available_features_generators, \
|
| 17 |
+
get_features_generator
|
| 18 |
+
from grover.data.task_labels import rdkit_functional_group_label_features_generator
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def load_temp(temp_dir: str) -> Tuple[List[List[float]], int]:
    """
    Collects the features from all temporary .npz files found in temp_dir.

    Files are read in the order 0.npz, 1.npz, ... and reading stops at the first
    index for which no file exists.

    :param temp_dir: Directory in which temporary .npz files containing features are stored.
    :return: A tuple with a list of molecule features, where each molecule's features is a list of floats,
             and the number of temporary files.
    """
    features = []
    temp_num = 0

    while True:
        chunk_path = os.path.join(temp_dir, f'{temp_num}.npz')
        if not os.path.exists(chunk_path):
            break
        features.extend(load_features(chunk_path))
        temp_num += 1

    return features, temp_num
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def generate_and_save_features(args: Namespace):
    """
    Computes and saves features for a dataset of molecules as a 2D array in a .npz file.

    Progress is checkpointed while running: the features computed since the last
    checkpoint are written as numbered .npz files into ``args.save_path + '_temp'``.
    Unless ``args.restart`` is set, an existing temporary directory is loaded and
    only the molecules without features yet are processed. When the final file has
    been written, the temporary directory is removed.

    :param args: Arguments. Reads ``save_path``, ``data_path``, ``features_generator``,
                 ``restart``, ``sequential`` and ``save_frequency``.
    :raises ValueError: If ``args.save_path`` already exists and ``args.restart`` is False.
    """
    # Create directory for save_path
    makedirs(args.save_path, isfile=True)

    # Get data and features function
    # NOTE(review): the --max_data_size command line option is not forwarded here — confirm intent.
    data = get_data(path=args.data_path, max_data_size=None)
    features_generator = get_features_generator(args.features_generator)
    temp_save_dir = args.save_path + '_temp'

    # Load partially complete data
    if args.restart:
        if os.path.exists(args.save_path):
            os.remove(args.save_path)
        if os.path.exists(temp_save_dir):
            shutil.rmtree(temp_save_dir)
    else:
        if os.path.exists(args.save_path):
            raise ValueError(f'"{args.save_path}" already exists and args.restart is False.')

        if os.path.exists(temp_save_dir):
            features, temp_num = load_temp(temp_save_dir)

    if not os.path.exists(temp_save_dir):
        makedirs(temp_save_dir)
        features, temp_num = [], 0

    # Build features map function
    data = data[len(features):]  # restrict to data for which features have not been computed yet
    mols = (d.smiles for d in data)

    # Keep a handle on the worker pool so that its processes are always shut down,
    # also when featurization fails part-way through.
    pool = None if args.sequential else Pool(30)
    try:
        if pool is None:
            features_map = map(features_generator, mols)
        else:
            features_map = pool.imap(features_generator, mols)

        # Get features
        temp_features = []
        for i, feats in tqdm(enumerate(features_map), total=len(data)):
            temp_features.append(feats)

            # Save temporary features every save_frequency
            if (i > 0 and (i + 1) % args.save_frequency == 0) or i == len(data) - 1:
                save_features(os.path.join(temp_save_dir, f'{temp_num}.npz'), temp_features)
                features.extend(temp_features)
                temp_features = []
                temp_num += 1
    finally:
        if pool is not None:
            # All results have been consumed (or an error is propagating), so the
            # workers have nothing left to do and can be stopped right away.
            pool.terminate()
            pool.join()

    try:
        # Save all features
        save_features(args.save_path, features)

        # Remove temporary features
        shutil.rmtree(temp_save_dir)
    except OverflowError:
        print('Features array is too large to save as a single file. Instead keeping features as a directory of files.')
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
if __name__ == '__main__':
    # Command line entry point: parse the options and run the featurization.
    parser = ArgumentParser()
    parser.add_argument('--data_path', type=str, required=True,
                        help='Path to data CSV')
    parser.add_argument('--features_generator', type=str, required=True,
                        choices=get_available_features_generators(),
                        help='Type of features to generate')
    parser.add_argument('--save_path', type=str, default=None,
                        help='Path to .npz file where features will be saved as a compressed numpy archive')
    parser.add_argument('--save_frequency', type=int, default=10000,
                        help='Frequency with which to save the features')
    parser.add_argument('--restart', action='store_true', default=False,
                        help='Whether to not load partially complete featurization and instead start from scratch')
    parser.add_argument('--max_data_size', type=int,
                        help='Maximum number of data points to load')
    parser.add_argument('--sequential', action='store_true', default=False,
                        help='Whether to task sequentially rather than in parallel')
    args = parser.parse_args()
    if args.save_path is None:
        # Default to the data file's name with its extension replaced by .npz.
        # Only the extension is swapped: splitting on the text 'csv' would cut the
        # path at the first 'csv' found anywhere in it (e.g. in a directory name).
        args.save_path = os.path.splitext(args.data_path)[0] + '.npz'
    generate_and_save_features(args)
|
scripts/split_data.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The data splitting script for pretraining.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
from argparse import ArgumentParser
|
| 6 |
+
import csv
|
| 7 |
+
import shutil
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
import grover.util.utils as fea_utils
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Command line options of the splitting script; every option has a default and a
# help text so that `--help` documents the expected inputs.
parser = ArgumentParser()
parser.add_argument("--data_path", default="../drug_data/grover_data/delaneyfreesolvlipo.csv",
                    help="CSV file to split; its first row is treated as the header.")
parser.add_argument("--features_path", default="../drug_data/grover_data/delaneyfreesolvlipo_molbert.npz",
                    help="Features file with one row per data row of the CSV file.")
parser.add_argument("--sample_per_file", type=int, default=1000,
                    help="Number of samples written to each output file.")
parser.add_argument("--output_path", default="../drug_data/grover_data/delaneyfreesolvlipo",
                    help="Output directory. An existing directory at this path is deleted first.")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def load_smiles(data_path):
    """
    Reads a CSV file and separates its header row from its data rows.

    :param data_path: Path to the CSV file.
    :return: A tuple ``(rows, header)`` where ``rows`` is the list of all rows after
             the first one and ``header`` is the first row; every row is a list of strings.
    """
    with open(data_path) as csv_file:
        row_reader = csv.reader(csv_file)
        header = next(row_reader)
        rows = list(row_reader)
    return rows, header
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_features(data_path):
    """
    Loads the features stored at data_path.

    Thin wrapper that delegates to ``grover.util.utils.load_features``.

    :param data_path: Path to the saved features file.
    :return: Whatever ``grover.util.utils.load_features`` returns for that path.
    """
    return fea_utils.load_features(data_path)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def save_smiles(data_path, index, data, header):
    """
    Writes one chunk of rows as ``<index>.csv`` inside data_path.

    :param data_path: Directory in which the CSV file is created.
    :param index: Number of the chunk; used as the file name.
    :param data: The rows to write, each an iterable of field values.
    :param header: The header row, written before the data rows.
    """
    fn = os.path.join(data_path, str(index) + ".csv")
    # newline='' hands line-ending control to the csv writer; without it, platforms
    # that translate '\n' on write would turn the writer's '\r\n' into '\r\r\n'.
    with open(fn, "w", newline="") as f:
        fw = csv.writer(f)
        fw.writerow(header)
        fw.writerows(data)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def save_features(data_path, index, data):
    """
    Stores one chunk of features as a compressed ``<index>.npz`` inside data_path.

    :param data_path: Directory in which the .npz file is created.
    :param index: Number of the chunk; used as the file name.
    :param data: The features to store; saved in the archive under the key ``features``.
    """
    archive_path = os.path.join(data_path, f"{index}.npz")
    np.savez_compressed(archive_path, features=data)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def run():
    """
    Command line entry point: shuffles a dataset and splits it into numbered chunks.

    Reads the CSV file and the features file named on the command line, draws one
    random permutation of the samples and writes chunks of ``--sample_per_file``
    samples each as ``graph/<i>.csv`` and ``feature/<i>.npz`` below ``--output_path``.
    A ``summary.txt`` with the number of files, the number of samples and the chunk
    size is written next to them. An existing output directory is deleted first.
    """
    args = parser.parse_args()
    res, header = load_smiles(data_path=args.data_path)
    fea = load_features(data_path=args.features_path)
    assert len(res) == fea.shape[0], \
        f"Number of data rows ({len(res)}) and feature rows ({fea.shape[0]}) differ."

    n_graphs = len(res)
    perm = np.random.permutation(n_graphs)

    # NOTE(review): this yields one trailing empty file when n_graphs is an exact
    # multiple of sample_per_file. Left unchanged on purpose because readers of
    # summary.txt may rely on this file count — confirm before changing it.
    nfold = int(n_graphs / args.sample_per_file + 1)
    print("Number of files: %d" % nfold)
    if os.path.exists(args.output_path):
        shutil.rmtree(args.output_path)
    os.makedirs(args.output_path, exist_ok=True)
    graph_path = os.path.join(args.output_path, "graph")
    fea_path = os.path.join(args.output_path, "feature")
    os.makedirs(graph_path, exist_ok=True)
    os.makedirs(fea_path, exist_ok=True)

    for i in range(nfold):
        sidx = i * args.sample_per_file
        eidx = min((i + 1) * args.sample_per_file, n_graphs)
        indexes = perm[sidx:eidx]
        sres = [res[j] for j in indexes]
        sfea = fea[indexes]
        save_smiles(graph_path, i, sres, header)
        save_features(fea_path, i, sfea)

    summary_path = os.path.join(args.output_path, "summary.txt")
    # The with-block closes the summary file even if one of the writes fails.
    with open(summary_path, 'w') as summary_fout:
        summary_fout.write("n_files:%d\n" % nfold)
        summary_fout.write("n_samples:%d\n" % n_graphs)
        summary_fout.write("sample_per_file:%d\n" % args.sample_per_file)


if __name__ == "__main__":
    run()
|
src/commands.py
CHANGED
|
@@ -16,6 +16,16 @@ def generate_features(data_path, save_path):
|
|
| 16 |
f"--restart"
|
| 17 |
)
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def finetune(train_path, val_path, train_features_path, val_features_path,
|
| 21 |
save_dir, checkpoint_path, args
|
|
|
|
| 16 |
f"--restart"
|
| 17 |
)
|
| 18 |
|
| 19 |
+
def predict_from_csv(data_path, features_path, checkpoint_dir, output_path):
    """
    Build the shell command that runs ``main.py predict`` on a CSV file.

    The arguments are inserted into the command as given, without quoting.

    :param data_path: Value for ``--data_path``.
    :param features_path: Value for ``--features_path``.
    :param checkpoint_dir: Value for ``--checkpoint_dir``.
    :param output_path: Value for ``--output``.
    :return: The complete command line as a single string.
    """
    predict_cmd = (
        f"python main.py predict "
        f"--data_path {data_path} "
        f"--features_path {features_path} "
        f"--checkpoint_dir {checkpoint_dir} "
        f"--no_features_scaling "
        f"--output {output_path}"
    )
    # Hand the command back to the caller; previously it was built and then discarded.
    return predict_cmd
|
| 28 |
+
|
| 29 |
|
| 30 |
def finetune(train_path, val_path, train_features_path, val_features_path,
|
| 31 |
save_dir, checkpoint_path, args
|
task/__init__.py
ADDED
|
File without changes
|
task/cross_validate.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The cross validation function for finetuning.
|
| 3 |
+
This implementation is adapted from
|
| 4 |
+
https://github.com/chemprop/chemprop/blob/master/chemprop/train/cross_validate.py
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
from argparse import Namespace
|
| 9 |
+
from logging import Logger
|
| 10 |
+
from typing import Tuple
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from grover.util.utils import get_task_names
|
| 15 |
+
from grover.util.utils import makedirs
|
| 16 |
+
from task.run_evaluation import run_evaluation
|
| 17 |
+
from task.train import run_training
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def cross_validate(args: Namespace, logger: Logger = None) -> Tuple[float, float]:
    """
    Run k-fold cross validation and report the per-fold and overall scores.

    Fold ``k`` runs with seed ``args.seed + k`` and writes into
    ``<args.save_dir>/fold_<k>``. Note that ``args.seed`` and ``args.save_dir``
    are overwritten in place for every fold and are not restored afterwards.

    :param args: Arguments. ``args.parser_name == "finetune"`` selects training,
                 any other value selects evaluation only.
    :param logger: Logger used for reporting; ``print`` is used when it is None.
    :return: A tuple of mean_score and std_score.
    """
    report = print if logger is None else logger.info

    # Remember the initial values, since args is mutated for each fold.
    base_seed = args.seed
    root_dir = args.save_dir
    task_names = get_task_names(args.data_path)

    # A single timestamp is shared by all folds of this run.
    time_start = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())

    fold_scores = []
    for fold_num in range(args.num_folds):
        report(f'Fold {fold_num}')
        args.seed = base_seed + fold_num
        args.save_dir = os.path.join(root_dir, f'fold_{fold_num}')
        makedirs(args.save_dir)
        if args.parser_name != "finetune":
            fold_scores.append(run_evaluation(args, logger))
        else:
            fold_scores.append(run_training(args, time_start, logger))
    all_scores = np.array(fold_scores)

    # Per-fold report.
    report(f'{args.num_folds}-fold cross validation')
    for fold_num, scores in enumerate(all_scores):
        seed = base_seed + fold_num
        report(f'Seed {seed} ==> test {args.metric} = {np.nanmean(scores):.6f}')
        if not args.show_individual_scores:
            continue
        for task_name, score in zip(task_names, scores):
            report(f'Seed {seed} ==> test {task_name} {args.metric} = {score:.6f}')

    # Overall report: first average over tasks within a fold, then over folds.
    per_fold_mean = np.nanmean(all_scores, axis=1)
    mean_score = np.nanmean(per_fold_mean)
    std_score = np.nanstd(per_fold_mean)
    report(f'overall_{args.split_type}_test_{args.metric}={mean_score:.6f}')
    report(f'std={std_score:.6f}')

    if args.show_individual_scores:
        for task_num, task_name in enumerate(task_names):
            task_column = all_scores[:, task_num]
            report(f'Overall test {task_name} {args.metric} = '
                   f'{np.nanmean(task_column):.6f} +/- {np.nanstd(task_column):.6f}')

    return mean_score, std_score
|
task/fingerprint.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The fingerprint generation function.
|
| 3 |
+
"""
|
| 4 |
+
from argparse import Namespace
|
| 5 |
+
from logging import Logger
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
|
| 12 |
+
from grover.data import MolCollator
|
| 13 |
+
from grover.data import MoleculeDataset
|
| 14 |
+
from grover.util.utils import get_data, create_logger, load_checkpoint
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def do_generate(model: nn.Module,
                data: MoleculeDataset,
                args: Namespace,
                ) -> List[List[float]]:
    """
    Compute fingerprints for a dataset with a pre-trained model.

    The model is put in eval mode and ``args.bond_drop_rate`` is set to 0
    (in place) before batching. The data is read in its original order,
    32 molecules per batch, using 4 loader workers.

    :param model: A model.
    :param data: A MoleculeDataset.
    :param args: Arguments handed to the MolCollator; ``bond_drop_rate`` is overwritten.
    :return: A list of fingerprints, one entry per row of the model output.
    """
    model.eval()
    args.bond_drop_rate = 0

    loader = DataLoader(data,
                        batch_size=32,
                        shuffle=False,
                        num_workers=4,
                        collate_fn=MolCollator(args=args, shared_dict={}))

    fingerprints = []
    # Each collated item is a 5-tuple; only the graph batch and the
    # additional features are needed here.
    for _, batch, features_batch, _, _ in loader:
        with torch.no_grad():
            output = model(batch, features_batch)
        fingerprints.extend(output.data.cpu().numpy())
    return fingerprints
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def generate_fingerprints(args: Namespace, logger: Logger = None) -> List[List[float]]:
    """
    Load the data and the first checkpoint, then generate the fingerprints.

    :param args: Arguments; ``checkpoint_paths``, ``data_path`` and ``cuda`` are read here.
    :param logger: Logger; a 'fingerprints' logger is created when it is None.
    :return: The fingerprints produced by ``do_generate`` for the loaded data.
    """
    # Only the first checkpoint is used, even if several are given.
    checkpoint_path = args.checkpoint_paths[0]
    if logger is None:
        logger = create_logger('fingerprints', quiet=False)

    print('Loading data')
    raw_data = get_data(path=args.data_path,
                        args=args,
                        use_compound_names=False,
                        max_data_size=float("inf"),
                        skip_invalid_smiles=False)
    dataset = MoleculeDataset(raw_data)

    logger.info(f'Total size = {len(dataset):,}')
    logger.info('Generating...')

    model = load_checkpoint(checkpoint_path, cuda=args.cuda, current_args=args, logger=logger)
    return do_generate(model=model, data=dataset, args=args)
|
task/grovertrainer.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The GROVER trainer.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
from logging import Logger
|
| 7 |
+
from typing import List, Tuple
|
| 8 |
+
from collections.abc import Callable
|
| 9 |
+
import torch
|
| 10 |
+
from torch.nn import Module
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
|
| 13 |
+
from grover.model.models import GroverTask
|
| 14 |
+
from grover.util.multi_gpu_wrapper import MultiGpuWrapper as mgw
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class GROVERTrainer:
    """
    Trainer for GROVER pre-training.

    Wraps the embedding model in a ``GroverTask``, owns the optimizer and the
    scheduler, runs training / validation passes over the data loaders and
    saves / restores checkpoints.
    """

    def __init__(self,
                 args,
                 embedding_model: Module,
                 atom_vocab_size: int,  # atom vocab size
                 bond_vocab_size: int,
                 fg_szie: int,
                 train_dataloader: DataLoader,
                 test_dataloader: DataLoader,
                 optimizer_builder: Callable,
                 scheduler_builder: Callable,
                 logger: Logger = None,
                 with_cuda: bool = False,
                 enable_multi_gpu: bool = False):
        """
        The init function of GROVERTrainer
        :param args: the input arguments.
        :param embedding_model: the model to generate atom/bond embeddings.
        :param atom_vocab_size: the vocabulary size of atoms.
        :param bond_vocab_size: the vocabulary size of bonds.
        :param fg_szie: the size of semantic motifs (functional groups). The
                        misspelled name is kept because callers may pass it by keyword.
        :param train_dataloader: the data loader of train data.
        :param test_dataloader: the data loader of validation data.
        :param optimizer_builder: the function of building the optimizer,
                                  called as ``optimizer_builder(model, args)``.
        :param scheduler_builder: the function of building the scheduler,
                                  called as ``scheduler_builder(optimizer, args)``.
        :param logger: the logger; ``print`` is used for debug output when None.
        :param with_cuda: enable gpu training.
        :param enable_multi_gpu: enable multi_gpu training.
        """
        self.args = args
        self.with_cuda = with_cuda
        self.grover = embedding_model
        self.model = GroverTask(args, embedding_model, atom_vocab_size, bond_vocab_size, fg_szie)
        self.loss_func = self.model.get_loss_func(args)
        self.enable_multi_gpu = enable_multi_gpu

        self.atom_vocab_size = atom_vocab_size
        self.bond_vocab_size = bond_vocab_size
        self.debug = logger.debug if logger is not None else print

        if self.with_cuda:
            self.model = self.model.cuda()

        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # The optimizer is built after the model has been moved to its device.
        self.optimizer = optimizer_builder(self.model, self.args)
        self.scheduler = scheduler_builder(self.optimizer, self.args)
        if self.enable_multi_gpu:
            self.optimizer = mgw.DistributedOptimizer(self.optimizer,
                                                      named_parameters=self.model.named_parameters())
        self.n_iter = 0

    @staticmethod
    def _as_float(value) -> float:
        """Return a loss term as a Python float, whether it is a tensor or a plain number."""
        return value.item() if hasattr(value, "item") else float(value)

    def broadcast_parameters(self) -> None:
        """
        Broadcast parameters before training.
        Does nothing unless multi-gpu training is enabled.
        :return: no return.
        """
        if self.enable_multi_gpu:
            # broadcast parameters & optimizer state.
            mgw.broadcast_parameters(self.model.state_dict(), root_rank=0)
            mgw.broadcast_optimizer_state(self.optimizer, root_rank=0)

    def train(self, epoch: int) -> Tuple[int, float, Tuple[float, ...]]:
        """
        The training iteration
        :param epoch: the current epoch number.
        :return: the loss terms of current epoch, see ``iter``.
        """
        return self.iter(epoch, self.train_data, train=True)

    def test(self, epoch: int) -> Tuple[int, float, Tuple[float, ...]]:
        """
        The test/validation iteration
        :param epoch: the current epoch number.
        :return: the loss terms of current epoch, see ``iter``.
        """
        return self.iter(epoch, self.test_data, train=False)

    def mock_iter(self, epoch: int, data_loader: DataLoader, train: bool = True) -> Tuple[int, float, Tuple[int, ...]]:
        """
        Perform a mock iteration. For test only.

        Walks the loader, stepping the scheduler and advancing ``n_iter`` once
        per batch, without running the model.
        :param epoch: the current epoch number.
        :param data_loader: the data loader.
        :param train: True: train model, False: validation model.
        :return: ``(n_iter, 0.0, (0, 0, 0, 0, 0, 0))``.
        """
        # Initialised before the loop so that an empty loader does not leave it unbound.
        cum_loss_sum = 0.0
        for _ in data_loader:
            self.scheduler.step()
            self.n_iter += self.args.batch_size
        return self.n_iter, cum_loss_sum, (0, 0, 0, 0, 0, 0)

    def iter(self, epoch, data_loader, train=True) -> Tuple[int, float, Tuple[float, ...]]:
        """
        Perform a training / validation iteration.

        In training mode the reported loss is the total loss returned by the
        loss function; in validation mode it is the sum of the av, bv and fg
        task losses only. All reported values are averaged over the number of
        batches; an empty loader yields zeros instead of a division by zero.
        :param epoch: the current epoch number.
        :param data_loader: the data loader.
        :param train: True: train model, False: validation model.
        :return: ``(n_iter, loss, (av_loss, bv_loss, fg_loss, av_dist_loss,
                 bv_dist_loss, fg_dist_loss))``.
        """
        if train:
            self.model.train()
        else:
            self.model.eval()

        cum_loss_sum, cum_iter_count = 0.0, 0
        av_loss_sum, bv_loss_sum, fg_loss_sum = 0.0, 0.0, 0.0
        av_dist_loss_sum, bv_dist_loss_sum, fg_dist_loss_sum = 0.0, 0.0, 0.0

        for item in data_loader:
            batch_graph = item["graph_input"]
            targets = item["targets"]

            if next(self.model.parameters()).is_cuda:
                for task in ("av_task", "bv_task", "fg_task"):
                    targets[task] = targets[task].cuda()

            # No autograd graph is needed for validation; skipping it saves memory.
            with torch.set_grad_enabled(train):
                preds = self.model(batch_graph)
                (loss, av_loss, bv_loss, fg_loss,
                 av_dist_loss, bv_dist_loss, fg_dist_loss) = self.loss_func(preds, targets)

            if train:
                cum_loss_sum += loss.item()
                # Run model
                self.model.zero_grad()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
            else:
                # For eval model, only consider the loss of three task.
                cum_loss_sum += av_loss.item()
                cum_loss_sum += bv_loss.item()
                cum_loss_sum += fg_loss.item()

            av_loss_sum += av_loss.item()
            bv_loss_sum += bv_loss.item()
            fg_loss_sum += fg_loss.item()
            # The dist losses may be plain floats rather than tensors.
            av_dist_loss_sum += self._as_float(av_dist_loss)
            bv_dist_loss_sum += self._as_float(bv_dist_loss)
            fg_dist_loss_sum += self._as_float(fg_dist_loss)

            cum_iter_count += 1
            self.n_iter += self.args.batch_size

        # Guard against an empty loader: all sums are 0 in that case.
        batches = max(cum_iter_count, 1)
        return self.n_iter, cum_loss_sum / batches, (av_loss_sum / batches,
                                                     bv_loss_sum / batches,
                                                     fg_loss_sum / batches,
                                                     av_dist_loss_sum / batches,
                                                     bv_dist_loss_sum / batches,
                                                     fg_dist_loss_sum / batches)

    def save(self, epoch, file_path, name=None) -> str:
        """
        Save the intermediate models during training.
        :param epoch: the epoch number.
        :param file_path: the file_path to save the model.
        :param name: suffix appended to ``file_path``; a ``_YYYY_MM_DD_HH_MM_SS``
                     timestamp is used when None, to distinguish saved models.
        :return: the output path, ``file_path + name + ".ep<epoch>"``.
        """
        if name is None:
            now = time.localtime()
            name = "_%04d_%02d_%02d_%02d_%02d_%02d" % (
                now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec)
        output_path = file_path + name + ".ep%d" % epoch
        state = {
            'args': self.args,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler_step': self.scheduler.current_step,
            "epoch": epoch,
            # No scalers are used here; the keys are stored as None, as before.
            'data_scaler': None,
            'features_scaler': None
        }
        torch.save(state, output_path)

        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path

    def save_tmp(self, epoch, file_path, rank=0):
        """
        Save the models for auto-restore during training.
        The model is stored in the file_path/tmp folder as ``model.<rank>`` and
        is replaced on each call.
        :param epoch: the epoch number.
        :param file_path: the file_path to store the model.
        :param rank: the rank used in the checkpoint file name.
        :return: no return.
        """
        store_path = os.path.join(file_path, "tmp")
        os.makedirs(store_path, exist_ok=True)
        store_path = os.path.join(store_path, "model.%d" % rank)
        state = {
            'args': self.args,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler_step': self.scheduler.current_step,
            "epoch": epoch
        }
        torch.save(state, store_path)

    def restore(self, file_path, rank=0) -> Tuple[int, int]:
        """
        Restore the training state saved by save_tmp.
        :param file_path: the file_path to store the model.
        :param rank: the rank used in the checkpoint file name.
        :return: the restored epoch number and the scheduler_step in scheduler;
                 ``(0, 0)`` when no checkpoint exists.
        """
        cpt_path = os.path.join(file_path, "tmp", "model.%d" % rank)
        if not os.path.exists(cpt_path):
            print("No checkpoint found %s" % cpt_path)
            return 0, 0
        # Load onto the CPU first so that a checkpoint written on a GPU can be
        # restored on any device; load_state_dict copies the values onto the
        # devices of the existing parameters.
        cpt = torch.load(cpt_path, map_location="cpu")
        self.model.load_state_dict(cpt["state_dict"])
        self.optimizer.load_state_dict(cpt["optimizer"])
        epoch = cpt["epoch"]
        scheduler_step = cpt["scheduler_step"]
        self.scheduler.current_step = scheduler_step
        print("Restore checkpoint, current epoch: %d" % (epoch))
        return epoch, scheduler_step
|
task/predict.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The predict function using the finetuned model to make the prediction.
|
| 3 |
+
"""
|
| 4 |
+
from argparse import Namespace
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from torch.utils.data import DataLoader
|
| 12 |
+
|
| 13 |
+
from grover.data import MolCollator
|
| 14 |
+
from grover.data import MoleculeDataset
|
| 15 |
+
from grover.data import StandardScaler
|
| 16 |
+
from grover.util.utils import get_data, get_data_from_smiles, create_logger, load_args, get_task_names, tqdm, \
|
| 17 |
+
load_checkpoint, load_scalars
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def predict(model: nn.Module,
            data: MoleculeDataset,
            args: Namespace,
            batch_size: int,
            loss_func,
            logger,
            shared_dict,
            scaler: StandardScaler = None
            ) -> tuple:
    """
    Makes predictions on a dataset with a single model.

    The model is put in eval mode and ``args.bond_drop_rate`` is set to 0
    (in place). When ``args.fingerprint`` is set, the raw model outputs are
    collected and neither the loss nor the scaler is applied.

    :param model: A model.
    :param data: A MoleculeDataset.
    :param args: Arguments handed to the MolCollator; ``fingerprint`` is read here.
    :param batch_size: Batch size.
    :param loss_func: Element-wise loss function, or None to skip the loss computation.
    :param logger: Logger (currently unused).
    :param shared_dict: Dictionary handed to the MolCollator.
    :param scaler: A StandardScaler object fit on the training targets.
    :return: A tuple ``(preds, loss_avg)``. ``preds`` is a list of lists of
             predictions, where the outer list is examples while the inner
             list is tasks. ``loss_avg`` is the loss averaged over batches;
             it is 0 when no loss was computed or when the data is empty.
    """
    model.eval()
    args.bond_drop_rate = 0
    preds = []

    loss_sum, iter_count = 0, 0

    mol_collator = MolCollator(args=args, shared_dict=shared_dict)

    num_workers = 4
    mol_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                            collate_fn=mol_collator)
    for item in mol_loader:
        _, batch, features_batch, mask, targets = item
        class_weights = torch.ones(targets.shape)
        if next(model.parameters()).is_cuda:
            targets = targets.cuda()
            mask = mask.cuda()
            class_weights = class_weights.cuda()
        with torch.no_grad():
            batch_preds = model(batch, features_batch)
            iter_count += 1
            if args.fingerprint:
                preds.extend(batch_preds.data.cpu().numpy())
                continue

            if loss_func is not None:
                # Masked mean of the element-wise loss over the valid targets.
                loss = loss_func(batch_preds, targets) * class_weights * mask
                loss = loss.sum() / mask.sum()
                loss_sum += loss.item()
        # Collect vectors
        batch_preds = batch_preds.data.cpu().numpy().tolist()
        if scaler is not None:
            batch_preds = scaler.inverse_transform(batch_preds)
        preds.extend(batch_preds)

    # An empty dataset produces no batches; avoid dividing by zero.
    loss_avg = loss_sum / iter_count if iter_count else 0.0
    return preds, loss_avg
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def make_predictions(args: Namespace, newest_train_args=None, smiles: List[str] = None):
    """
    Makes predictions. If smiles is provided, makes predictions on smiles.
    Otherwise makes predictions on ``args.data_path``.

    ``args`` is updated in place: attributes missing from it are filled from
    the training arguments stored in the first checkpoint and then from
    ``newest_train_args``; ``debug``, ``task_names``, ``num_tasks``,
    ``features_size`` and ``valid_indices`` are set as well.

    NOTE(review): the shape of the return value depends on the path taken.
    A list is returned for empty data, the raw output of the first model when
    ``args.fingerprint`` is set, and an ``(avg_preds, test_smiles)`` tuple
    otherwise. Callers that unpack two values should be checked.

    :param args: Arguments.
    :param newest_train_args: Optional extra arguments used to fill attributes
                              still missing from ``args``.
    :param smiles: Smiles to make predictions on.
    :return: See the note above; normally the predictions averaged over all
             checkpoints together with the input smiles.
    """
    def fill_missing(source):
        # Copy only the attributes that args does not define yet.
        for key, value in vars(source).items():
            if not hasattr(args, key):
                setattr(args, key, value)

    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    print('Loading training args')

    path = args.checkpoint_paths[0]
    scaler, features_scaler = load_scalars(path)
    train_args = load_args(path)

    # Update args with training arguments saved in checkpoint,
    # then with the newest training args.
    fill_missing(train_args)
    if newest_train_args is not None:
        fill_missing(newest_train_args)

    # deal with multiprocess problem
    args.debug = True

    logger = create_logger('predict', quiet=False)
    print('Loading data')
    args.task_names = get_task_names(args.data_path)
    if smiles is None:
        test_data = get_data(path=args.data_path, args=args,
                             use_compound_names=args.use_compound_names, skip_invalid_smiles=False)
    else:
        test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False)

    args.num_tasks = test_data.num_tasks()
    args.features_size = test_data.features_size()

    print('Validating SMILES')
    # Every entry is kept; the indices are recorded for write_prediction.
    valid_indices = list(range(len(test_data)))
    full_data = test_data
    test_data = MoleculeDataset([full_data[index] for index in valid_indices])

    # Edge case if empty list of smiles is provided
    if len(test_data) == 0:
        return [None] * len(full_data)

    print(f'Test size = {len(test_data):,}')

    # Normalize features
    if getattr(train_args, 'features_scaling', False):
        test_data.normalize_features(features_scaler)

    # Predict with each model individually and sum predictions
    sum_preds = np.zeros((len(test_data), args.num_tasks))
    print('Predicting...')
    shared_dict = {}
    checkpoint_paths = args.checkpoint_paths
    for checkpoint_path in tqdm(checkpoint_paths, total=len(checkpoint_paths)):
        # Load model
        model = load_checkpoint(checkpoint_path, cuda=args.cuda, current_args=args, logger=logger)
        model_preds, _ = predict(
            model=model,
            data=test_data,
            args=args,
            batch_size=args.batch_size,
            loss_func=None,
            logger=logger,
            shared_dict=shared_dict,
            scaler=scaler
        )

        # Fingerprints are taken from the first checkpoint only.
        if args.fingerprint:
            return model_preds

        sum_preds += np.array(model_preds, dtype=float)

    # Ensemble predictions
    avg_preds = sum_preds / len(args.checkpoint_paths)

    assert len(test_data) == len(avg_preds)

    args.valid_indices = valid_indices
    avg_preds = np.array(avg_preds)
    test_smiles = full_data.smiles()
    return avg_preds, test_smiles
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def write_prediction(avg_preds, test_smiles, args):
    """
    Write the predictions to ``args.output_path`` as a CSV file.

    The rows are indexed by the input smiles and the columns are named after
    ``args.task_names``. For a 'multiclass' dataset the arg-max over the last
    axis is written instead of the raw values.

    :param avg_preds: prediction value
    :param test_smiles: input smiles
    :param args: Arguments; ``dataset_type``, ``valid_indices``, ``task_names``
                 and ``output_path`` are read.
    """
    values = np.argmax(avg_preds, -1) if args.dataset_type == 'multiclass' else avg_preds

    # Rows that are not listed in valid_indices keep a [None] placeholder.
    rows = [[None] for _ in test_smiles]
    for position, smiles_index in enumerate(args.valid_indices):
        rows[smiles_index] = values[position]

    frame = pd.DataFrame(data=rows, index=test_smiles, columns=args.task_names)
    frame.to_csv(args.output_path)
    print(f'Saving predictions to {args.output_path}')
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def evaluate_predictions(preds: List[List[float]],
                         targets: List[List[float]],
                         num_tasks: int,
                         metric_func,
                         dataset_type: str,
                         logger = None) -> List[float]:
    """
    Evaluates predictions using a metric function and filtering out invalid targets.

    For every dataset type except 'multiclass' the result has exactly one
    entry per task. A task gets NaN when it cannot be scored: it has no valid
    targets, or (classification only) its targets or its predictions are all
    0 or all 1. Previously a regression task without valid targets was
    skipped, which left the result shorter than ``num_tasks`` and misaligned
    with the task names.

    :param preds: A list of lists of shape (data_size, num_tasks) with model predictions.
    :param targets: A list of lists of shape (data_size, num_tasks) with targets.
                    A target of None marks a missing value.
    :param num_tasks: Number of tasks.
    :param metric_func: Metric function which takes in a list of targets and a list of predictions.
    :param dataset_type: Dataset type.
    :param logger: Logger (currently unused).
    :return: A list with the score for each task based on `metric_func`. For
             'multiclass' a single-element list with the metric of the
             arg-max predictions against the first target of each example.
    """
    if dataset_type == 'multiclass':
        results = metric_func(np.argmax(preds, -1), [i[0] for i in targets])
        return [results]

    if len(preds) == 0:
        return [float('nan')] * num_tasks

    def constant_binary(values):
        # True when all values are 0 or all values are 1 (also for an empty list).
        return all(value == 0 for value in values) or all(value == 1 for value in values)

    results = []
    for task in range(num_tasks):
        # Keep only the examples that have a target for this task.
        task_preds, task_targets = [], []
        for row, pred_row in enumerate(preds):
            target = targets[row][task]
            if target is not None:
                task_preds.append(pred_row[task])
                task_targets.append(target)

        # Classification metrics crash when all targets or all preds are identical.
        if dataset_type == 'classification' and (constant_binary(task_targets)
                                                 or constant_binary(task_preds)):
            results.append(float('nan'))
            continue

        # Nothing to score: keep the slot so that results stay aligned with the tasks.
        if not task_targets:
            results.append(float('nan'))
            continue

        results.append(metric_func(task_targets, task_preds))

    return results
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def evaluate(model: nn.Module,
             data: MoleculeDataset,
             num_tasks: int,
             metric_func,
             loss_func,
             batch_size: int,
             dataset_type: str,
             args: Namespace,
             shared_dict,
             scaler: StandardScaler = None,
             logger=None) -> tuple:
    """
    Evaluates a model on a dataset.

    Predictions and the average loss are obtained from `predict`; the targets of `data` are then
    scored against those predictions with `evaluate_predictions`.

    :param model: A model.
    :param data: A MoleculeDataset.
    :param num_tasks: Number of tasks.
    :param metric_func: Metric function which takes in a list of targets and a list of predictions.
    :param loss_func: Loss function, forwarded to `predict`.
    :param batch_size: Batch size.
    :param dataset_type: Dataset type.
    :param args: Arguments, forwarded to `predict`.
    :param shared_dict: Shared dictionary, forwarded to `predict`.
    :param scaler: A StandardScaler object fit on the training targets. When given, it is forwarded to
        `predict` and the targets of `data` are passed through `scaler.inverse_transform` before scoring.
    :param logger: Logger.
    :return: A tuple `(results, loss_avg)`: the list with the score for each task based on `metric_func`,
        and the average loss returned by `predict`.
    """
    preds, loss_avg = predict(
        model=model,
        data=data,
        loss_func=loss_func,
        batch_size=batch_size,
        scaler=scaler,
        shared_dict=shared_dict,
        logger=logger,
        args=args
    )

    targets = data.targets()
    if scaler is not None:
        # The targets stored in `data` are inverse-transformed so that they are scored
        # against the predictions in the same space.
        targets = scaler.inverse_transform(targets)

    results = evaluate_predictions(
        preds=preds,
        targets=targets,
        num_tasks=num_tasks,
        metric_func=metric_func,
        dataset_type=dataset_type,
        logger=logger
    )

    return results, loss_avg
|
task/pretrain.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The GROVER pretrain function.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
from argparse import Namespace
|
| 7 |
+
from logging import Logger
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
|
| 12 |
+
from grover.data.dist_sampler import DistributedSampler
|
| 13 |
+
from grover.data.groverdataset import get_data, split_data, GroverCollator, BatchMolDataset
|
| 14 |
+
from grover.data.torchvocab import MolVocab
|
| 15 |
+
from grover.model.models import GROVEREmbedding
|
| 16 |
+
from grover.util.multi_gpu_wrapper import MultiGpuWrapper as mgw
|
| 17 |
+
from grover.util.nn_utils import param_count
|
| 18 |
+
from grover.util.utils import build_optimizer, build_lr_scheduler
|
| 19 |
+
from task.grovertrainer import GROVERTrainer
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def pretrain_model(args: Namespace, logger: Logger = None):
    """
    The entry point of pretraining.

    Delegates to `run_training` and prints the elapsed wall-clock time once it returns.
    :param args: the arguments, forwarded to run_training.
    :param logger: the logger, forwarded to run_training.
    :return: None.
    """

    # Keep an explicit reference to MolVocab so that IDE import optimizers
    # (e.g. PyCharm) do not remove the import.
    _ = MolVocab
    started_at = time.time()
    run_training(args=args, logger=logger)
    finished_at = time.time()
    print("Total Time: %.3f" % (finished_at - started_at))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def pre_load_data(dataset: BatchMolDataset, rank: int, num_replicas: int, sample_per_file: int = None, epoch: int = 0):
    """
    Pre-load the data of one worker at the beginning of an epoch.

    A throw-away, non-shuffling DistributedSampler is built only to obtain the indices for
    (rank, num_replicas, epoch); every such index is then passed to `dataset.load_data`.
    :param dataset: the dataset to load from.
    :param rank: the rank of the current worker.
    :param num_replicas: the number of replicas.
    :param sample_per_file: the number of the data points in each file. When sample_per_file is None, all data
    will be loaded. It implies the testing phase. (TODO: bad design here.)
    :param epoch: the epoch number handed to the sampler.
    :return: None.
    """
    index_sampler = DistributedSampler(dataset,
                                       num_replicas=num_replicas,
                                       rank=rank,
                                       shuffle=False,
                                       sample_per_file=sample_per_file)
    index_sampler.set_epoch(epoch)
    for data_idx in index_sampler.get_indices():
        dataset.load_data(data_idx)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def run_training(args, logger):
    """
    Run the pretrain task.

    In order: optional multi-GPU initialisation through `mgw`, data loading and a fixed
    (0.9, 0.1, 0.0) split with seed 0, vocabulary loading, sampler and dataloader construction,
    model and trainer construction, restoring from `<args.save_dir>/model`, then the epoch loop
    (pre-load, train, validate, report, checkpoint) and a final save by the master worker.

    NOTE(review): the nesting of this function was reconstructed from a listing without indentation.
    Please confirm in particular that `trainer.broadcast_parameters()` is called outside the
    multi-GPU branch, and that `trainer.save` / `trainer.save_tmp` inside the epoch loop are only
    called by the master worker.

    :param args: the arguments. `args.train_data_size` is set here.
    :param logger: the logger. When None, debug messages are printed.
    :return: None.
    """

    # Initialise the logging callable.
    debug = logger.debug if logger is not None else print

    # Initialise the horovod library.
    if args.enable_multi_gpu:
        mgw.init()

    # Bind training to GPUs.
    master_worker = (mgw.rank() == 0) if args.enable_multi_gpu else True
    # Pin the GPU to the local rank. By default, gpu:0 is used for training.
    local_gpu_idx = mgw.local_rank() if args.enable_multi_gpu else 0
    with_cuda = args.cuda
    if with_cuda:
        torch.cuda.set_device(local_gpu_idx)

    # Rank of this worker and total number of workers.
    rank = mgw.rank() if args.enable_multi_gpu else 0
    num_replicas = mgw.size() if args.enable_multi_gpu else 1

    # Load the file paths of the data.
    if master_worker:
        print(args)
        if args.enable_multi_gpu:
            debug("Total workers: %d" % (mgw.size()))
        debug('Loading data')
    data, sample_per_file = get_data(data_path=args.data_path)

    # Data splitting.
    if master_worker:
        debug('Splitting data with seed 0.')
    train_data, test_data, _ = split_data(data=data, sizes=(0.9, 0.1, 0.0), seed=0, logger=logger)

    # The true train data size is the size of train_data divided by the number of GPUs.
    if args.enable_multi_gpu:
        args.train_data_size = len(train_data) // mgw.size()
    else:
        args.train_data_size = len(train_data)
    if master_worker:
        debug(f'Total size = {len(data):,} | '
              f'train size = {len(train_data):,} | val size = {len(test_data):,}')

    # Load the atom and bond vocabularies.
    atom_vocab = MolVocab.load_vocab(args.atom_vocab_path)
    bond_vocab = MolVocab.load_vocab(args.bond_vocab_path)
    atom_vocab_size = len(atom_vocab)
    bond_vocab_size = len(bond_vocab)

    # Hard coded, since no data has been loaded yet.
    fg_size = 85
    shared_dict = {}
    mol_collator = GroverCollator(shared_dict=shared_dict, atom_vocab=atom_vocab, bond_vocab=bond_vocab, args=args)
    if master_worker:
        debug("atom vocab size: %d, bond vocab size: %d, Number of FG tasks: %d" % (atom_vocab_size,
                                                                                    bond_vocab_size, fg_size))

    # Define the distributed samplers. On a single card both samplers stay None.
    train_sampler = None
    test_sampler = None
    shuffle = True
    if args.enable_multi_gpu:
        # Without shuffling the performance may decay.
        train_sampler = DistributedSampler(
            train_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=True, sample_per_file=sample_per_file)
        # sample_per_file is left at None for the test sampler, indicating that the test sampler would not
        # divide the test samples by rank. (TODO: bad design here.)
        test_sampler = DistributedSampler(
            test_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=False)
        train_sampler.set_epoch(args.epochs)
        test_sampler.set_epoch(1)
        # With multi-GPU training the samplers are used, so DataLoader shuffling is disabled.
        shuffle = False

    # Pre-load data. (Maybe unnecessary.)
    pre_load_data(train_data, rank, num_replicas, sample_per_file)
    pre_load_data(test_data, rank, num_replicas)
    if master_worker:
        print("Pre-loaded test data: %d" % test_data.count_loaded_datapoints())

    # Build the dataloaders.
    train_data_dl = DataLoader(train_data,
                               batch_size=args.batch_size,
                               shuffle=shuffle,
                               num_workers=12,
                               sampler=train_sampler,
                               collate_fn=mol_collator)
    test_data_dl = DataLoader(test_data,
                              batch_size=args.batch_size,
                              shuffle=shuffle,
                              num_workers=10,
                              sampler=test_sampler,
                              collate_fn=mol_collator)

    # Build the embedding model.
    grover_model = GROVEREmbedding(args)

    # Build the trainer. (`fg_szie` is the keyword spelling used by this call in the original code.)
    trainer = GROVERTrainer(args=args,
                            embedding_model=grover_model,
                            atom_vocab_size=atom_vocab_size,
                            bond_vocab_size=bond_vocab_size,
                            fg_szie=fg_size,
                            train_dataloader=train_data_dl,
                            test_dataloader=test_data_dl,
                            optimizer_builder=build_optimizer,
                            scheduler_builder=build_lr_scheduler,
                            logger=logger,
                            with_cuda=with_cuda,
                            enable_multi_gpu=args.enable_multi_gpu)

    # Restore an interrupted training. Only the master worker reads the checkpoint;
    # the restored counters are broadcast to the other workers below.
    model_dir = os.path.join(args.save_dir, "model")
    resume_from_epoch = 0
    resume_scheduler_step = 0
    if master_worker:
        resume_from_epoch, resume_scheduler_step = trainer.restore(model_dir)
    if args.enable_multi_gpu:
        resume_from_epoch = mgw.broadcast(torch.tensor(resume_from_epoch),
                                          root_rank=0, name="resume_from_epoch").item()
        resume_scheduler_step = mgw.broadcast(torch.tensor(resume_scheduler_step),
                                              root_rank=0, name="resume_scheduler_step").item()
        trainer.scheduler.current_step = resume_scheduler_step
        print("Restored epoch: %d Restored scheduler step: %d" % (resume_from_epoch, trainer.scheduler.current_step))
    trainer.broadcast_parameters()

    # Print model details.
    if master_worker:
        print(grover_model)
        print("Total parameters: %d" % param_count(trainer.grover))

    # Perform training.
    for epoch in range(resume_from_epoch + 1, args.epochs):
        s_time = time.time()

        # Data pre-loading for this epoch.
        if args.enable_multi_gpu:
            train_sampler.set_epoch(epoch)
            train_data.clean_cache()
            for data_idx in train_sampler.get_indices():
                train_data.load_data(data_idx)
        d_time = time.time() - s_time

        # Training.
        s_time = time.time()
        _, train_loss, _ = trainer.train(epoch)
        t_time = time.time() - s_time

        # Validation.
        s_time = time.time()
        _, val_loss, detailed_loss_val = trainer.test(epoch)
        val_av_loss, val_bv_loss, val_fg_loss, _, _, _ = detailed_loss_val
        v_time = time.time() - s_time

        # Report and checkpoint.
        if master_worker:
            print('Epoch: {:04d}'.format(epoch),
                  'loss_train: {:.6f}'.format(train_loss),
                  'loss_val: {:.6f}'.format(val_loss),
                  'loss_val_av: {:.6f}'.format(val_av_loss),
                  'loss_val_bv: {:.6f}'.format(val_bv_loss),
                  'loss_val_fg: {:.6f}'.format(val_fg_loss),
                  'cur_lr: {:.5f}'.format(trainer.scheduler.get_lr()[0]),
                  't_time: {:.4f}s'.format(t_time),
                  'v_time: {:.4f}s'.format(v_time),
                  'd_time: {:.4f}s'.format(d_time), flush=True)

            if epoch % args.save_interval == 0:
                trainer.save(epoch, model_dir)

            trainer.save_tmp(epoch, model_dir, rank)

    # Only save the final version.
    if master_worker:
        trainer.save(args.epochs, model_dir, "")
|
task/run_evaluation.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The evaluation function.
|
| 3 |
+
"""
|
| 4 |
+
from argparse import Namespace
|
| 5 |
+
from logging import Logger
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.utils.data.distributed
|
| 11 |
+
|
| 12 |
+
from grover.data.scaler import StandardScaler
|
| 13 |
+
from grover.util.utils import get_class_sizes, get_data, split_data, get_task_names, get_loss_func
|
| 14 |
+
from grover.util.utils import load_checkpoint
|
| 15 |
+
from task.predict import evaluate_predictions
|
| 16 |
+
from grover.util.metrics import get_metric_func
|
| 17 |
+
from grover.util.nn_utils import param_count
|
| 18 |
+
from task.predict import predict
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def run_evaluation(args: Namespace, logger: Logger = None) -> List[float]:
    """
    Evaluates a saved model checkpoint on the test split and returns the test scores.

    The data is loaded and split (0.8 / 0.1 / 0.1) with `args.seed`, the checkpoint whose path contains
    `fold_<args.seed>` is loaded, and its predictions on the test split are scored with `args.metric`.

    :param args: Arguments. `task_names`, `num_tasks`, `features_size` and `train_data_size` are set here.
    :param logger: Logger. When None, messages are printed.
    :return: A list of ensemble scores for each task.
    :raises ValueError: if `args.checkpoint_paths` is None, or if none of the paths contains
        `fold_<args.seed>`.
    """
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # Only select a CUDA device when CUDA is requested; an unconditional call fails on CPU-only hosts.
    if args.cuda:
        torch.cuda.set_device(0)

    # Get data
    debug('Loading data')
    args.task_names = get_task_names(args.data_path)
    data = get_data(path=args.data_path, args=args, logger=logger)
    args.num_tasks = data.num_tasks()
    args.features_size = data.features_size()
    debug(f'Number of tasks = {args.num_tasks}')

    # Split data
    debug(f'Splitting data with seed {args.seed}')

    train_data, val_data, test_data = split_data(data=data,
                                                 split_type=args.split_type,
                                                 sizes=[0.8, 0.1, 0.1],
                                                 seed=args.seed,
                                                 args=args,
                                                 logger=logger)

    if args.dataset_type == 'classification':
        class_sizes = get_class_sizes(data)
        debug('Class sizes')
        for i, task_class_sizes in enumerate(class_sizes):
            debug(f'{args.task_names[i]} '
                  f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}')

    if args.features_scaling:
        # The scaler is fit on the training features and applied to the other two splits.
        features_scaler = train_data.normalize_features(replace_nan_token=0)
        val_data.normalize_features(features_scaler)
        test_data.normalize_features(features_scaler)

    args.train_data_size = len(train_data)

    debug(f'Total size = {len(data):,} | '
          f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}')

    # Initialize scaler (regression only)
    scaler = None
    if args.dataset_type == 'regression':
        debug('Fitting scaler')
        _, train_targets = train_data.smiles(), train_data.targets()
        scaler = StandardScaler().fit(train_targets)
        scaled_targets = scaler.transform(train_targets).tolist()
        train_data.set_targets(scaled_targets)

        val_targets = val_data.targets()
        scaled_val_targets = scaler.transform(val_targets).tolist()
        val_data.set_targets(scaled_val_targets)

    metric_func = get_metric_func(metric=args.metric)

    # Set up test set evaluation
    test_smiles, test_targets = test_data.smiles(), test_data.targets()
    sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))

    # Load model. Without a checkpoint there is nothing to evaluate, so fail with a clear message
    # instead of hitting an undefined `model` further down.
    if args.checkpoint_paths is None:
        raise ValueError('run_evaluation requires args.checkpoint_paths, but it is None.')

    # The checkpoint is selected by the fold tag derived from the seed; the last matching path wins.
    fold_tag = "fold_%d" % args.seed
    target_path = None
    for path in args.checkpoint_paths:
        if fold_tag in path:
            target_path = path
    if target_path is None:
        raise ValueError(f'No checkpoint path containing "{fold_tag}" found in args.checkpoint_paths.')

    debug(f'Loading model {args.seed} from {target_path}')
    model = load_checkpoint(target_path, current_args=args, cuda=args.cuda, logger=logger)

    # Get loss and metric functions
    loss_func = get_loss_func(args, model)

    debug(f'Number of parameters = {param_count(model):,}')

    test_preds, _ = predict(
        model=model,
        data=test_data,
        batch_size=args.batch_size,
        loss_func=loss_func,
        logger=logger,
        shared_dict={},
        scaler=scaler,
        args=args
    )

    test_scores = evaluate_predictions(
        preds=test_preds,
        targets=test_targets,
        num_tasks=args.num_tasks,
        metric_func=metric_func,
        dataset_type=args.dataset_type,
        logger=logger
    )

    if len(test_preds) != 0:
        sum_test_preds += np.array(test_preds, dtype=float)

    # Average test score
    avg_test_score = np.nanmean(test_scores)
    info(f'Model test {args.metric} = {avg_test_score:.6f}')

    if args.show_individual_scores:
        # Individual test scores
        for task_name, test_score in zip(args.task_names, test_scores):
            info(f'Model test {task_name} {args.metric} = {test_score:.6f}')

    # Evaluate ensemble on test set
    avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()

    ensemble_scores = evaluate_predictions(
        preds=avg_test_preds,
        targets=test_targets,
        num_tasks=args.num_tasks,
        metric_func=metric_func,
        dataset_type=args.dataset_type,
        logger=logger
    )

    # If you want to save the prediction result, uncomment these lines.
    # ind = [['preds'] * args.num_tasks + ['targets'] * args.num_tasks, args.task_names * 2]
    # ind = pd.MultiIndex.from_tuples(list(zip(*ind)))
    # data = np.concatenate([np.array(avg_test_preds), np.array(test_targets)], 1)
    # test_result = pd.DataFrame(data, index=test_smiles, columns=ind)
    # test_result.to_csv(os.path.join(args.save_dir, 'test_result.csv'))

    return ensemble_scores
|
task/train.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
The training function used in the finetuning task.
|
| 3 |
+
"""
|
| 4 |
+
import csv
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import pickle
|
| 8 |
+
import time
|
| 9 |
+
from argparse import Namespace
|
| 10 |
+
from logging import Logger
|
| 11 |
+
from typing import List
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pandas as pd
|
| 15 |
+
import torch
|
| 16 |
+
from torch.optim.lr_scheduler import ExponentialLR
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
+
|
| 19 |
+
from grover.data import MolCollator
|
| 20 |
+
from grover.data import StandardScaler
|
| 21 |
+
from grover.util.metrics import get_metric_func
|
| 22 |
+
from grover.util.nn_utils import initialize_weights, param_count
|
| 23 |
+
from grover.util.scheduler import NoamLR
|
| 24 |
+
from grover.util.utils import build_optimizer, build_lr_scheduler, makedirs, load_checkpoint, get_loss_func, \
|
| 25 |
+
save_checkpoint, build_model
|
| 26 |
+
from grover.util.utils import get_class_sizes, get_data, split_data, get_task_names
|
| 27 |
+
from task.predict import predict, evaluate, evaluate_predictions
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def train(epoch, model, data, loss_func, optimizer, scheduler,
          shared_dict, args: Namespace, n_iter: int = 0,
          logger: logging.Logger = None):
    """
    Trains a model for an epoch.

    :param epoch: The epoch number. It is not used inside this function.
    :param model: Model.
    :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe), or an already built
        DataLoader, which is then used as is.
    :param loss_func: Loss function.
    :param optimizer: An Optimizer.
    :param scheduler: A learning rate scheduler. It is only stepped here (once per batch) when it is a NoamLR.
    :param shared_dict: Shared dictionary handed to the MolCollator.
    :param args: Arguments.
    :param n_iter: The number of iterations (training examples) trained on so far.
    :param logger: A logger for printing intermediate results. It is currently unused.
    :return: A tuple `(n_iter, loss_avg)`: the total number of iterations (training examples) trained on
        so far, and the loss averaged over the batches of this epoch (nan when there was no batch).
    """
    model.train()

    mol_collator = MolCollator(shared_dict=shared_dict, args=args)

    num_workers = 4
    if isinstance(data, DataLoader):
        mol_loader = data
    else:
        mol_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
                                num_workers=num_workers, collate_fn=mol_collator)

    cum_loss_sum, cum_iter_count = 0, 0

    for item in mol_loader:
        _, batch, features_batch, mask, targets = item
        if next(model.parameters()).is_cuda:
            mask, targets = mask.cuda(), targets.cuda()
        # Create the weights directly on the device of the targets, so that weights, mask and
        # targets can never end up on different devices.
        class_weights = torch.ones(targets.shape, device=targets.device)

        # Run model
        model.zero_grad()
        preds = model(batch, features_batch)
        # Element-wise loss; entries with mask == 0 do not contribute.
        loss = loss_func(preds, targets) * class_weights * mask
        loss = loss.sum() / mask.sum()

        cum_loss_sum += loss.item()
        cum_iter_count += 1

        loss.backward()
        optimizer.step()

        if isinstance(scheduler, NoamLR):
            scheduler.step()

        n_iter += args.batch_size

    if cum_iter_count == 0:
        # An empty loader yields no batch; report nan instead of dividing by zero.
        return n_iter, float('nan')

    return n_iter, cum_loss_sum / cum_iter_count
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def run_training(args: Namespace, time_start, logger: Logger = None) -> List[float]:
    """
    Trains an ensemble of models and returns the ensemble's test scores.

    For every ensemble member the checkpoint with the best validation result (lowest validation
    loss when ``args.select_by_loss`` is set, otherwise the best average ``args.metric``) is
    reloaded and used to predict the test set. The test predictions averaged over the ensemble
    are written, together with the test targets, to ``<args.save_dir>/test_result.csv``.

    :param args: Arguments.
    :param time_start: Not used by this function; kept so existing callers keep working.
    :param logger: Logger. When None, messages are printed instead.
    :return: A list of ensemble scores for each task.
    """
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    # pin GPU to local rank.
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)

    features_scaler, scaler, shared_dict, test_data, train_data, val_data = load_data(args, debug, logger)

    metric_func = get_metric_func(metric=args.metric)

    # Set up test set evaluation
    test_smiles, test_targets = test_data.smiles(), test_data.targets()
    sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))

    # Train ensemble of models
    for model_idx in range(args.ensemble_size):
        save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
        makedirs(save_dir)

        # Tensorboard writer; created on first use so tensorboard is only needed when enabled.
        writer = None

        # Load/build model
        if args.checkpoint_paths is not None:
            # A single checkpoint is shared by every ensemble member.
            cur_model = 0 if len(args.checkpoint_paths) == 1 else model_idx
            debug(f'Loading model {cur_model} from {args.checkpoint_paths[cur_model]}')
            model = load_checkpoint(args.checkpoint_paths[cur_model], current_args=args, logger=logger)
        else:
            debug(f'Building model {model_idx}')
            model = build_model(model_idx=model_idx, args=args)

        if args.fine_tune_coff != 1 and args.checkpoint_paths is not None:
            debug("Fine tune fc layer with different lr")
            initialize_weights(model_idx=model_idx, model=model.ffn, distinct_init=args.distinct_init)

        ############### FREEZE BLOCK ###########
        # for name, param in model.named_parameters():
        #     if name.startswith("grover."):
        #         param.requires_grad = False
        #
        #     # Train prediction layers (readout + two FFNs)
        #     else:
        #         param.requires_grad = True
        #
        # print("TRAINABLE PARAMETERS:")
        # for name, p in model.named_parameters():
        #     if p.requires_grad:
        #         print("   ", name)
        ############### FREEZE BLOCK ###########

        # Get loss and metric functions
        loss_func = get_loss_func(args, model)

        optimizer = build_optimizer(model, args)

        debug(model)
        debug(f'Number of parameters = {param_count(model):,}')
        if args.cuda:
            debug('Moving model to cuda')
            model = model.cuda()

        # Ensure that model is saved in correct location for evaluation if 0 epochs
        save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)

        # Learning rate schedulers
        scheduler = build_lr_scheduler(optimizer, args)

        # Build the data loader under its own name: rebinding ``train_data`` here would make the
        # next ensemble member wrap a DataLoader inside another DataLoader.
        mol_collator = MolCollator(shared_dict={}, args=args)
        train_loader = DataLoader(train_data,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=10,
                                  collate_fn=mol_collator)

        # Run training
        best_score = float('inf') if args.minimize_score else -float('inf')
        best_epoch, n_iter = 0, 0
        min_val_loss = float('inf')
        for epoch in range(args.epochs):
            s_time = time.time()
            n_iter, train_loss = train(
                epoch=epoch,
                model=model,
                data=train_loader,
                loss_func=loss_func,
                optimizer=optimizer,
                scheduler=scheduler,
                args=args,
                n_iter=n_iter,
                shared_dict=shared_dict,
                logger=logger
            )
            t_time = time.time() - s_time
            s_time = time.time()
            val_scores, val_loss = evaluate(
                model=model,
                data=val_data,
                loss_func=loss_func,
                num_tasks=args.num_tasks,
                metric_func=metric_func,
                batch_size=args.batch_size,
                dataset_type=args.dataset_type,
                scaler=scaler,
                shared_dict=shared_dict,
                logger=logger,
                args=args
            )
            v_time = time.time() - s_time
            # Average validation score
            avg_val_score = np.nanmean(val_scores)
            # Logged after lr step
            if isinstance(scheduler, ExponentialLR):
                scheduler.step()

            if args.show_individual_scores:
                # Individual validation scores
                for task_name, val_score in zip(args.task_names, val_scores):
                    debug(f'Validation {task_name} {args.metric} = {val_score:.6f}')
            print('Epoch: {:04d}'.format(epoch),
                  'loss_train: {:.6f}'.format(train_loss),
                  'loss_val: {:.6f}'.format(val_loss),
                  f'{args.metric}_val: {avg_val_score:.4f}',
                  # 'auc_val: {:.4f}'.format(avg_val_score),
                  'cur_lr: {:.5f}'.format(scheduler.get_lr()[-1]),
                  't_time: {:.4f}s'.format(t_time),
                  'v_time: {:.4f}s'.format(v_time))

            if args.tensorboard:
                if writer is None:
                    # Local import: the tensorboard package is only required when logging is on.
                    from torch.utils.tensorboard import SummaryWriter
                    writer = SummaryWriter(log_dir=save_dir)
                writer.add_scalar('loss/train', train_loss, epoch)
                writer.add_scalar('loss/val', val_loss, epoch)
                writer.add_scalar(f'{args.metric}_val', avg_val_score, epoch)

            # Save model checkpoint if improved validation score
            if args.select_by_loss:
                if val_loss < min_val_loss:
                    min_val_loss, best_epoch = val_loss, epoch
                    save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)
            else:
                if (args.minimize_score and avg_val_score < best_score) or \
                        (not args.minimize_score and avg_val_score > best_score):
                    best_score, best_epoch = avg_val_score, epoch
                    save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)

            # Early stopping: give up once the best epoch is too far behind.
            if epoch - best_epoch > args.early_stop_epoch:
                break

        if writer is not None:
            writer.close()

        # Evaluate on test set using model with best validation score
        if args.select_by_loss:
            info(f'Model {model_idx} best val loss = {min_val_loss:.6f} on epoch {best_epoch}')
        else:
            info(f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}')
        model = load_checkpoint(os.path.join(save_dir, 'model.pt'), cuda=args.cuda, logger=logger)

        test_preds, _ = predict(
            model=model,
            data=test_data,
            loss_func=loss_func,
            batch_size=args.batch_size,
            logger=logger,
            shared_dict=shared_dict,
            scaler=scaler,
            args=args
        )

        test_scores = evaluate_predictions(
            preds=test_preds,
            targets=test_targets,
            num_tasks=args.num_tasks,
            metric_func=metric_func,
            dataset_type=args.dataset_type,
            logger=logger
        )

        if len(test_preds) != 0:
            sum_test_preds += np.array(test_preds, dtype=float)

        # Average test score
        avg_test_score = np.nanmean(test_scores)
        info(f'Model {model_idx} test {args.metric} = {avg_test_score:.6f}')

        if args.show_individual_scores:
            # Individual test scores
            for task_name, test_score in zip(args.task_names, test_scores):
                info(f'Model {model_idx} test {task_name} {args.metric} = {test_score:.6f}')

    # Evaluate ensemble on test set
    avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()

    ensemble_scores = evaluate_predictions(
        preds=avg_test_preds,
        targets=test_targets,
        num_tasks=args.num_tasks,
        metric_func=metric_func,
        dataset_type=args.dataset_type,
        logger=logger
    )

    # Store predictions and targets side by side, one ('preds'|'targets', task) column each.
    ind = [['preds'] * args.num_tasks + ['targets'] * args.num_tasks, args.task_names * 2]
    ind = pd.MultiIndex.from_tuples(list(zip(*ind)))
    data = np.concatenate([np.array(avg_test_preds), np.array(test_targets)], 1)
    test_result = pd.DataFrame(data, index=test_smiles, columns=ind)
    test_result.to_csv(os.path.join(args.save_dir, 'test_result.csv'))

    # Average ensemble score
    avg_ensemble_test_score = np.nanmean(ensemble_scores)
    info(f'Ensemble test {args.metric} = {avg_ensemble_test_score:.6f}')

    # Individual ensemble scores
    if args.show_individual_scores:
        for task_name, ensemble_score in zip(args.task_names, ensemble_scores):
            info(f'Ensemble test {task_name} {args.metric} = {ensemble_score:.6f}')

    return ensemble_scores
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def load_data(args, debug, logger):
    """
    Load the dataset, split it into train/val/test and fit the scalers.

    Besides returning the splits, this fills in several fields on ``args``:
    ``task_names``, ``features_dim``, ``num_tasks``, ``features_size`` and ``train_data_size``.

    :param args: Arguments. Read for the data paths, split settings, dataset type and
                 scaling options; updated in place as described above.
    :param debug: Callable used to emit debug messages (e.g. ``logger.debug`` or ``print``).
    :param logger: Logger forwarded to the data loading and splitting helpers.
    :return: A tuple ``(features_scaler, scaler, shared_dict, test_data, train_data, val_data)``.
             ``features_scaler`` is None unless ``args.features_scaling`` is set, ``scaler`` is
             None unless ``args.dataset_type`` is 'regression', and ``shared_dict`` is a new
             empty dict.
    """
    # Get data
    debug('Loading data')
    args.task_names = get_task_names(args.data_path)
    data = get_data(path=args.data_path, args=args, logger=logger)
    first_features = data.data[0].features
    args.features_dim = len(first_features) if first_features is not None else 0
    shared_dict = {}
    args.num_tasks = data.num_tasks()
    args.features_size = data.features_size()
    debug(f'Number of tasks = {args.num_tasks}')

    # Split data
    debug(f'Splitting data with seed {args.seed}')
    if args.separate_test_path:
        test_data = get_data(path=args.separate_test_path, args=args,
                             features_path=args.separate_test_features_path, logger=logger)
    if args.separate_val_path:
        val_data = get_data(path=args.separate_val_path, args=args,
                            features_path=args.separate_val_features_path, logger=logger)

    if args.separate_val_path and args.separate_test_path:
        train_data = data
    elif args.separate_val_path:
        # Validation data comes from its own file, so the held-out 20% must go to the test
        # split (third entry). With (0.8, 0.2, 0.0) the test set would receive the 0.0 share.
        # NOTE(review): assumes split_data interprets sizes as (train, val, test) fractions.
        train_data, _, test_data = split_data(data=data, split_type=args.split_type,
                                              sizes=(0.8, 0.0, 0.2), seed=args.seed, args=args, logger=logger)
    elif args.separate_test_path:
        train_data, val_data, _ = split_data(data=data, split_type=args.split_type,
                                             sizes=(0.8, 0.2, 0.0), seed=args.seed, args=args, logger=logger)
    else:
        train_data, val_data, test_data = split_data(data=data, split_type=args.split_type,
                                                     sizes=args.split_sizes, seed=args.seed, args=args, logger=logger)

    if args.dataset_type == 'classification':
        class_sizes = get_class_sizes(data)
        debug('Class sizes')
        for i, task_class_sizes in enumerate(class_sizes):
            debug(f'{args.task_names[i]} '
                  f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}')

    #if args.save_smiles_splits:
    #    save_splits(args, test_data, train_data, val_data)

    if args.features_scaling:
        # Fit the feature scaler on the training split only and reuse it for val/test.
        features_scaler = train_data.normalize_features(replace_nan_token=0)
        val_data.normalize_features(features_scaler)
        test_data.normalize_features(features_scaler)
    else:
        features_scaler = None
    args.train_data_size = len(train_data)
    debug(f'Total size = {len(data):,} | '
          f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}')

    # Initialize scaler and scale training targets by subtracting mean and dividing standard deviation (regression only)
    if args.dataset_type == 'regression':
        debug('Fitting scaler')
        train_targets = train_data.targets()
        scaler = StandardScaler().fit(train_targets)
        train_data.set_targets(scaler.transform(train_targets).tolist())

        # Validation targets are scaled with the training statistics; test targets are left as is.
        val_targets = val_data.targets()
        val_data.set_targets(scaler.transform(val_targets).tolist())
    else:
        scaler = None
    return features_scaler, scaler, shared_dict, test_data, train_data, val_data
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def save_splits(args, test_data, train_data, val_data):
    """
    Save the train/val/test splits next to the model outputs.

    For each split two CSV files are written into ``args.save_dir``: ``<name>_smiles.csv``
    (a single 'smiles' column) and ``<name>_full.csv`` (the matching rows of the original data
    file, including its header). The sorted row indices of all three splits are pickled to
    ``split_indices.pckl`` in the order train, val, test.

    :param args: Arguments; ``data_path`` and ``save_dir`` are read.
    :param test_data: Test split; must provide ``smiles()``.
    :param train_data: Training split; must provide ``smiles()``.
    :param val_data: Validation split; must provide ``smiles()``.
    :return: The last ``csv.writer`` that was created (its file is already closed); returned
             only for backward compatibility.
    """
    # The csv module expects files opened with newline='' so it controls line endings itself.
    with open(args.data_path, 'r', newline='') as f:
        reader = csv.reader(f)
        header = next(reader)

        # Map each SMILES (first column) to its full row and its 0-based row index.
        lines_by_smiles = {}
        indices_by_smiles = {}
        for i, line in enumerate(reader):
            smiles = line[0]
            lines_by_smiles[smiles] = line
            indices_by_smiles[smiles] = i

    all_split_indices = []
    for dataset, name in [(train_data, 'train'), (val_data, 'val'), (test_data, 'test')]:
        split_smiles = list(dataset.smiles())

        with open(os.path.join(args.save_dir, name + '_smiles.csv'), 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['smiles'])
            writer.writerows([smiles] for smiles in split_smiles)

        with open(os.path.join(args.save_dir, name + '_full.csv'), 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerows(lines_by_smiles[smiles] for smiles in split_smiles)

        all_split_indices.append(sorted(indices_by_smiles[smiles] for smiles in split_smiles))

    with open(os.path.join(args.save_dir, 'split_indices.pckl'), 'wb') as f:
        pickle.dump(all_split_indices, f)
    return writer
|