hiitsmeme committed on
Commit f986893 · 1 Parent(s): b25d2b6

added grover code, hf api files

Dockerfile ADDED
@@ -0,0 +1,16 @@
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+ # you will also find guides on how best to write your Dockerfile
+
+ FROM python:3.11.4
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV PATH="/home/user/.local/bin:$PATH"
+
+ WORKDIR /app
+
+ COPY --chown=user ./requirements.txt requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+ COPY --chown=user . /app
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
generate_features.py CHANGED
@@ -6,4 +6,5 @@ from src.commands import generate_features
  TRAIN_CSV = "./tox21/tox21_train_clean.csv"
  VAL_CSV = "./tox21/tox21_validation_clean.csv"
 
- generate_features(EXAMPLES_CSV, EXAMPLES_CSV.replace('.csv', '.npz'))
+ generate_features(TRAIN_CSV, TRAIN_CSV.replace('.csv', '.npz'))
+ generate_features(VAL_CSV, VAL_CSV.replace('.csv', '.npz'))
grover DELETED
@@ -1 +0,0 @@
- Subproject commit 3f280d7d3419a781d303b1500c7039e37a1d87a2
grover/data/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from grover.data.molfeaturegenerator import get_available_features_generators, get_features_generator
+ from grover.data.molgraph import BatchMolGraph, get_atom_fdim, get_bond_fdim, mol2graph
+ from grover.data.molgraph import MolGraph, BatchMolGraph, MolCollator
+ from grover.data.moldataset import MoleculeDataset, MoleculeDatapoint
+ from grover.data.scaler import StandardScaler
+
+ # from .utils import load_features, save_features
grover/data/dist_sampler.py ADDED
@@ -0,0 +1,137 @@
+ """
+ The re-implemented distributed sampler for the distributed training of GROVER.
+ """
+ import math
+ import time
+ import torch
+ from torch.utils.data.sampler import Sampler
+ import torch.distributed as dist
+
+
+ class DistributedSampler(Sampler):
+     """Sampler that restricts data loading to a subset of the dataset.
+
+     It is especially useful in conjunction with
+     :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
+     process can pass a DistributedSampler instance as a DataLoader sampler,
+     and load a subset of the original dataset that is exclusive to it.
+
+     .. note::
+         Dataset is assumed to be of constant size.
+
+     Arguments:
+         dataset: Dataset used for sampling.
+         num_replicas (optional): Number of processes participating in
+             distributed training.
+         rank (optional): Rank of the current process within num_replicas.
+     """
+
+     def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, sample_per_file=None):
+         if num_replicas is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             num_replicas = dist.get_world_size()
+         if rank is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             rank = dist.get_rank()
+         self.dataset = dataset
+         self.num_replicas = num_replicas
+         self.rank = rank
+         self.epoch = 0
+         self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+         self.total_size = self.num_samples * self.num_replicas
+         self.sample_per_file = sample_per_file
+         self.shuffle = shuffle
+
+     def get_indices(self):
+
+         indices = list(range(len(self.dataset)))
+
+         if self.sample_per_file is not None:
+             indices = self.sub_indices_of_rank(indices)
+         else:
+             # add extra samples to make it evenly divisible
+             indices += indices[:(self.total_size - len(indices))]
+             assert len(indices) == self.total_size
+             # subsample
+             s = self.rank * self.num_samples
+             e = min((self.rank + 1) * self.num_samples, len(indices))
+
+             # indices = indices[self.rank:self.total_size:self.num_replicas]
+             indices = indices[s:e]
+
+         if self.shuffle:
+             g = torch.Generator()
+             # the seed must be an int; it depends on both the epoch and the rank
+             g.manual_seed(int((self.epoch + 1) * (self.rank + 1) * time.time()))
+             idx = torch.randperm(len(indices), generator=g).tolist()
+             indices = [indices[i] for i in idx]
+
+         # disabled since sub_indices_of_rank may change the per-rank sample count.
+         # assert len(indices) == self.num_samples
+
+         return indices
+
+     def sub_indices_of_rank(self, indices):
+
+         # fix the generator for each epoch
+         g = torch.Generator()
+         # All data should be loaded in each epoch.
+         g.manual_seed((self.epoch + 1) * 2 + 3)
+
+         # the fake file indices to cache
+         f_indices = list(range(int(math.ceil(len(indices) * 1.0 / self.sample_per_file))))
+         idx = torch.randperm(len(f_indices), generator=g).tolist()
+         f_indices = [f_indices[i] for i in idx]
+
+         file_per_rank = int(math.ceil(len(f_indices) * 1.0 / self.num_replicas))
+         # add extra fake files to make it evenly divisible
+         f_indices += f_indices[:(file_per_rank * self.num_replicas - len(f_indices))]
+
+         # divide the file indices by rank
+         rank_s = self.rank * file_per_rank
+         rank_e = min((self.rank + 1) * file_per_rank, len(f_indices))
+
+         # get the file indices for this rank
+         f_indices = f_indices[rank_s:rank_e]
+         # print("f_indices")
+         # print(f_indices)
+         res_indices = []
+         for fi in f_indices:
+             # get the real sample indices for this rank
+             si = fi * self.sample_per_file
+             ei = min((fi + 1) * self.sample_per_file, len(indices))
+             cur_idx = [indices[i] for i in range(si, ei)]
+             res_indices += cur_idx
+
+         self.num_samples = len(res_indices)
+         return res_indices
+
+     def __iter__(self):
+         return iter(self.get_indices())
+
+     def __len__(self):
+         return self.num_samples
+
+     def set_epoch(self, epoch):
+         self.epoch = epoch
+
+
+ if __name__ == "__main__":
+     # dataset = [1] * 9
+     # ds = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
+     # print(ds.get_indices())
+     # ds = DistributedSampler(dataset, num_replicas=2, rank=1, shuffle=True)
+     # print(ds.get_indices())
+
+     dataset = [1] * 190001
+     res = []
+     ds = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, sample_per_file=777)
+     res.extend(ds.get_indices())
+     print(len(ds.get_indices()))
+     ds = DistributedSampler(dataset, num_replicas=2, rank=1, shuffle=True, sample_per_file=777)
+     res.extend(ds.get_indices())
+     print(len(ds.get_indices()))
+     print(len(set(res)))
+     print("hello")
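A minimal usage sketch (not part of this commit): wiring the sampler into a single-rank DataLoader over a toy dataset; the dataset, batch size, and world-size/rank values are placeholders.

# Illustration only: a toy dataset of 1000 integers split across 2 ranks.
from torch.utils.data import DataLoader

dataset = list(range(1000))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
for epoch in range(2):
    sampler.set_epoch(epoch)  # vary the shuffling seed between epochs
    for batch in loader:
        pass  # each rank iterates over its own 500-sample shard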
grover/data/groverdataset.py ADDED
@@ -0,0 +1,247 @@
+ """
+ The dataset used in training GROVER.
+ """
+ import math
+ import os
+ import csv
+ from typing import Union, List
+ import numpy as np
+ import torch
+ from torch.utils.data.dataset import Dataset
+ from rdkit import Chem
+
+ import grover.util.utils as feautils
+ from grover.data import mol2graph
+ from grover.data.moldataset import MoleculeDatapoint
+ from grover.data.task_labels import atom_to_vocab, bond_to_vocab
+
+
+ def get_data(data_path, logger=None):
+     """
+     Load data from the data_path.
+     :param data_path: the data_path.
+     :param logger: the logger.
+     :return:
+     """
+     debug = logger.debug if logger is not None else print
+     summary_path = os.path.join(data_path, "summary.txt")
+     smiles_path = os.path.join(data_path, "graph")
+     feature_path = os.path.join(data_path, "feature")
+
+     fin = open(summary_path)
+     n_files = int(fin.readline().strip().split(":")[-1])
+     n_samples = int(fin.readline().strip().split(":")[-1])
+     sample_per_file = int(fin.readline().strip().split(":")[-1])
+     debug("Loading data:")
+     debug("Number of files: %d" % n_files)
+     debug("Number of samples: %d" % n_samples)
+     debug("Samples/file: %d" % sample_per_file)
+
+     datapoints = []
+     for i in range(n_files):
+         smiles_path_i = os.path.join(smiles_path, str(i) + ".csv")
+         feature_path_i = os.path.join(feature_path, str(i) + ".npz")
+         n_samples_i = sample_per_file if i != (n_files - 1) else n_samples % sample_per_file
+         datapoints.append(BatchDatapoint(smiles_path_i, feature_path_i, n_samples_i))
+     return BatchMolDataset(datapoints), sample_per_file
+
+
+ def split_data(data,
+                split_type='random',
+                sizes=(0.8, 0.1, 0.1),
+                seed=0,
+                logger=None):
+     """
+     Split data with the given train/validation/test ratio.
+     :param data:
+     :param split_type:
+     :param sizes:
+     :param seed:
+     :param logger:
+     :return:
+     """
+     assert len(sizes) == 3 and sum(sizes) == 1
+
+     if split_type == "random":
+         data.shuffle(seed=seed)
+         data = data.data
+
+         train_size = int(sizes[0] * len(data))
+         train_val_size = int((sizes[0] + sizes[1]) * len(data))
+
+         train = data[:train_size]
+         val = data[train_size:train_val_size]
+         test = data[train_val_size:]
+
+         return BatchMolDataset(train), BatchMolDataset(val), BatchMolDataset(test)
+     else:
+         raise NotImplementedError("Do not support %s splits" % split_type)
+
+
+ class BatchDatapoint:
+     def __init__(self,
+                  smiles_file,
+                  feature_file,
+                  n_samples,
+                  ):
+         self.smiles_file = smiles_file
+         self.feature_file = feature_file
+         # deal with the graph count of the last batch.
+         self.n_samples = n_samples
+         self.datapoints = None
+
+     def load_datapoints(self):
+         features = self.load_feature()
+         self.datapoints = []
+
+         with open(self.smiles_file) as f:
+             reader = csv.reader(f)
+             next(reader)
+             for i, line in enumerate(reader):
+                 # line = line[0]
+                 d = MoleculeDatapoint(line=line,
+                                       features=features[i])
+                 self.datapoints.append(d)
+
+         assert len(self.datapoints) == self.n_samples
+
+     def load_feature(self):
+         return feautils.load_features(self.feature_file)
+
+     def shuffle(self):
+         pass
+
+     def clean_cache(self):
+         del self.datapoints
+         self.datapoints = None
+
+     def __len__(self):
+         return self.n_samples
+
+     def __getitem__(self, idx):
+         assert self.datapoints is not None
+         return self.datapoints[idx]
+
+     def is_loaded(self):
+         return self.datapoints is not None
+
+
+ class BatchMolDataset(Dataset):
+     def __init__(self, data: List[BatchDatapoint],
+                  graph_per_file=None):
+         self.data = data
+
+         self.len = 0
+         for d in self.data:
+             self.len += len(d)
+         if graph_per_file is not None:
+             self.sample_per_file = graph_per_file
+         else:
+             self.sample_per_file = len(self.data[0]) if len(self.data) != 0 else None
+
+     def shuffle(self, seed: int = None):
+         pass
+
+     def clean_cache(self):
+         for d in self.data:
+             d.clean_cache()
+
+     def __len__(self) -> int:
+         return self.len
+
+     def __getitem__(self, idx) -> Union[MoleculeDatapoint, List[MoleculeDatapoint]]:
+         # print(idx)
+         dp_idx = int(idx / self.sample_per_file)
+         real_idx = idx % self.sample_per_file
+         return self.data[dp_idx][real_idx]
+
+     def load_data(self, idx):
+         dp_idx = int(idx / self.sample_per_file)
+         if not self.data[dp_idx].is_loaded():
+             self.data[dp_idx].load_datapoints()
+
+     def count_loaded_datapoints(self):
+         res = 0
+         for d in self.data:
+             if d.is_loaded():
+                 res += 1
+         return res
+
+
+ class GroverCollator(object):
+     def __init__(self, shared_dict, atom_vocab, bond_vocab, args):
+         self.args = args
+         self.shared_dict = shared_dict
+         self.atom_vocab = atom_vocab
+         self.bond_vocab = bond_vocab
+
+     def atom_random_mask(self, smiles_batch):
+         """
+         Perform the random mask operation on atoms.
+         :param smiles_batch:
+         :return: The corresponding atom labels.
+         """
+         # There is a zero padding.
+         vocab_label = [0]
+         percent = 0.15
+         for smi in smiles_batch:
+             mol = Chem.MolFromSmiles(smi)
+             mlabel = [0] * mol.GetNumAtoms()
+             n_mask = math.ceil(mol.GetNumAtoms() * percent)
+             perm = np.random.permutation(mol.GetNumAtoms())[:n_mask]
+             for p in perm:
+                 atom = mol.GetAtomWithIdx(int(p))
+                 mlabel[p] = self.atom_vocab.stoi.get(atom_to_vocab(mol, atom), self.atom_vocab.other_index)
+
+             vocab_label.extend(mlabel)
+         return vocab_label
+
+     def bond_random_mask(self, smiles_batch):
+         """
+         Perform the random mask operation on bonds.
+         :param smiles_batch:
+         :return: The corresponding bond labels.
+         """
+         # There is a zero padding.
+         vocab_label = [0]
+         percent = 0.15
+         for smi in smiles_batch:
+             mol = Chem.MolFromSmiles(smi)
+             nm_atoms = mol.GetNumAtoms()
+             nm_bonds = mol.GetNumBonds()
+             mlabel = []
+             n_mask = math.ceil(nm_bonds * percent)
+             perm = np.random.permutation(nm_bonds)[:n_mask]
+             virtual_bond_id = 0
+             for a1 in range(nm_atoms):
+                 for a2 in range(a1 + 1, nm_atoms):
+                     bond = mol.GetBondBetweenAtoms(a1, a2)
+
+                     if bond is None:
+                         continue
+                     if virtual_bond_id in perm:
+                         label = self.bond_vocab.stoi.get(bond_to_vocab(mol, bond), self.bond_vocab.other_index)
+                         mlabel.extend([label])
+                     else:
+                         mlabel.extend([0])
+
+                     virtual_bond_id += 1
+             # todo: might need to consider bond_drop_rate
+             # todo: double check reverse bond
+             vocab_label.extend(mlabel)
+         return vocab_label
+
+     def __call__(self, batch):
+         smiles_batch = [d.smiles for d in batch]
+         batchgraph = mol2graph(smiles_batch, self.shared_dict, self.args).get_components()
+
+         atom_vocab_label = torch.Tensor(self.atom_random_mask(smiles_batch)).long()
+         bond_vocab_label = torch.Tensor(self.bond_random_mask(smiles_batch)).long()
+         fgroup_label = torch.Tensor([d.features for d in batch]).float()
+         # there may be some masking here
+         res = {"graph_input": batchgraph,
+                "targets": {"av_task": atom_vocab_label,
+                            "bv_task": bond_vocab_label,
+                            "fg_task": fgroup_label}
+                }
+         return res
grover/data/moldataset.py ADDED
@@ -0,0 +1,245 @@
+ """
+ The molecule dataset for finetuning.
+ This implementation is adapted from
+ https://github.com/chemprop/chemprop/blob/master/chemprop/data/data.py
+ """
+ import random
+ from argparse import Namespace
+ from typing import Callable, List, Union
+
+ import numpy as np
+ from rdkit import Chem
+ from torch.utils.data.dataset import Dataset
+
+ from grover.data.molfeaturegenerator import get_features_generator
+ from grover.data.scaler import StandardScaler
+
+
+ class MoleculeDatapoint:
+     """A MoleculeDatapoint contains a single molecule and its associated features and targets."""
+
+     def __init__(self,
+                  line: List[str],
+                  args: Namespace = None,
+                  features: np.ndarray = None,
+                  use_compound_names: bool = False):
+         """
+         Initializes a MoleculeDatapoint, which contains a single molecule.
+
+         :param line: A list of strings generated by separating a line in a data CSV file by comma.
+         :param args: Arguments.
+         :param features: A numpy array containing additional features (e.g. a Morgan fingerprint).
+         :param use_compound_names: Whether the data CSV includes the compound name on each line.
+         """
+         self.features_generator = None
+         self.args = None
+         if args is not None:
+             if hasattr(args, "features_generator"):
+                 self.features_generator = args.features_generator
+             self.args = args
+
+         if features is not None and self.features_generator is not None:
+             raise ValueError('Currently cannot provide both loaded features and a features generator.')
+
+         self.features = features
+
+         if use_compound_names:
+             self.compound_name = line[0]  # str
+             line = line[1:]
+         else:
+             self.compound_name = None
+
+         self.smiles = line[0]  # str
+
+         # Generate additional features if given a generator
+         if self.features_generator is not None:
+             self.features = []
+             mol = Chem.MolFromSmiles(self.smiles)
+             for fg in self.features_generator:
+                 features_generator = get_features_generator(fg)
+                 if mol is not None and mol.GetNumHeavyAtoms() > 0:
+                     if fg in ['morgan', 'morgan_count']:
+                         self.features.extend(features_generator(mol, num_bits=args.num_bits))
+                     else:
+                         self.features.extend(features_generator(mol))
+
+             self.features = np.array(self.features)
+
+         # Fix nans in features
+         if self.features is not None:
+             replace_token = 0
+             self.features = np.where(np.isnan(self.features), replace_token, self.features)
+
+         # Create targets
+         self.targets = [float(x) if x != '' else None for x in line[1:]]
+
+     def set_features(self, features: np.ndarray):
+         """
+         Sets the features of the molecule.
+
+         :param features: A 1-D numpy array of features for the molecule.
+         """
+         self.features = features
+
+     def num_tasks(self) -> int:
+         """
+         Returns the number of prediction tasks.
+
+         :return: The number of tasks.
+         """
+         return len(self.targets)
+
+     def set_targets(self, targets: List[float]):
+         """
+         Sets the targets of a molecule.
+
+         :param targets: A list of floats containing the targets.
+         """
+         self.targets = targets
+
+
+ class MoleculeDataset(Dataset):
+     """A MoleculeDataset contains a list of molecules and their associated features and targets."""
+
+     def __init__(self, data: List[MoleculeDatapoint]):
+         """
+         Initializes a MoleculeDataset, which contains a list of MoleculeDatapoints (i.e. a list of molecules).
+
+         :param data: A list of MoleculeDatapoints.
+         """
+         self.data = data
+         self.args = self.data[0].args if len(self.data) > 0 else None
+         self.scaler = None
+
+     def compound_names(self) -> List[str]:
+         """
+         Returns the compound names associated with the molecules (if they exist).
+
+         :return: A list of compound names or None if the dataset does not contain compound names.
+         """
+         if len(self.data) == 0 or self.data[0].compound_name is None:
+             return None
+
+         return [d.compound_name for d in self.data]
+
+     def smiles(self) -> List[str]:
+         """
+         Returns the SMILES strings associated with the molecules.
+
+         :return: A list of SMILES strings.
+         """
+         return [d.smiles for d in self.data]
+
+     def features(self) -> List[np.ndarray]:
+         """
+         Returns the features associated with each molecule (if they exist).
+
+         :return: A list of 1D numpy arrays containing the features for each molecule or None if there are no features.
+         """
+         if len(self.data) == 0 or self.data[0].features is None:
+             return None
+
+         return [d.features for d in self.data]
+
+     def targets(self) -> List[List[float]]:
+         """
+         Returns the targets associated with each molecule.
+
+         :return: A list of lists of floats containing the targets.
+         """
+         return [d.targets for d in self.data]
+
+     def num_tasks(self) -> int:
+         """
+         Returns the number of prediction tasks.
+
+         :return: The number of tasks.
+         """
+         if self.args.dataset_type == 'multiclass':
+             return int(max([i[0] for i in self.targets()])) + 1
+         else:
+             return self.data[0].num_tasks() if len(self.data) > 0 else None
+
+     def features_size(self) -> int:
+         """
+         Returns the size of the features array associated with each molecule.
+
+         :return: The size of the features.
+         """
+         return len(self.data[0].features) if len(self.data) > 0 and self.data[0].features is not None else None
+
+     def shuffle(self, seed: int = None):
+         """
+         Shuffles the dataset.
+
+         :param seed: Optional random seed.
+         """
+         if seed is not None:
+             random.seed(seed)
+         random.shuffle(self.data)
+
+     def normalize_features(self, scaler: StandardScaler = None, replace_nan_token: int = 0) -> StandardScaler:
+         """
+         Normalizes the features of the dataset using a StandardScaler (subtract mean, divide by standard deviation).
+
+         If a scaler is provided, uses that scaler to perform the normalization. Otherwise fits a scaler to the
+         features in the dataset and then performs the normalization.
+
+         :param scaler: A fitted StandardScaler. Used if provided. Otherwise a StandardScaler is fit on
+                        this dataset and is then used.
+         :param replace_nan_token: What to replace nans with.
+         :return: A fitted StandardScaler. If a scaler is provided, this is the same scaler. Otherwise, this is
+                  a scaler fit on this dataset.
+         """
+         if len(self.data) == 0 or self.data[0].features is None:
+             return None
+
+         if scaler is not None:
+             self.scaler = scaler
+
+         elif self.scaler is None:
+             features = np.vstack([d.features for d in self.data])
+             self.scaler = StandardScaler(replace_nan_token=replace_nan_token)
+             self.scaler.fit(features)
+
+         for d in self.data:
+             d.set_features(self.scaler.transform(d.features.reshape(1, -1))[0])
+
+         return self.scaler
+
+     def set_targets(self, targets: List[List[float]]):
+         """
+         Sets the targets for each molecule in the dataset. Assumes the targets are aligned with the datapoints.
+
+         :param targets: A list of lists of floats containing targets for each molecule. This must be the
+                         same length as the underlying dataset.
+         """
+         assert len(self.data) == len(targets)
+         for i in range(len(self.data)):
+             self.data[i].set_targets(targets[i])
+
+     def sort(self, key: Callable):
+         """
+         Sorts the dataset using the provided key.
+
+         :param key: A function on a MoleculeDatapoint to determine the sorting order.
+         """
+         self.data.sort(key=key)
+
+     def __len__(self) -> int:
+         """
+         Returns the length of the dataset (i.e. the number of molecules).
+
+         :return: The length of the dataset.
+         """
+         return len(self.data)
+
+     def __getitem__(self, idx) -> Union[MoleculeDatapoint, List[MoleculeDatapoint]]:
+         """
+         Gets one or more MoleculeDatapoints via an index or slice.
+
+         :param idx: An index (int) or a slice object.
+         :return: A MoleculeDatapoint if an int is provided or a list of MoleculeDatapoints if a slice is provided.
+         """
+         return self.data[idx]
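A minimal sketch (not part of this commit) of how these classes fit together; the CSV-style rows below are made-up examples in the (smiles, target, ...) layout the class expects.

# Hypothetical illustration with two made-up rows.
rows = [["CCO", "0", "1"], ["c1ccccc1", "1", ""]]
datapoints = [MoleculeDatapoint(line=row) for row in rows]
dataset = MoleculeDataset(datapoints)
print(dataset.smiles())        # ['CCO', 'c1ccccc1']
print(dataset.targets())       # [[0.0, 1.0], [1.0, None]]  (empty target cells become None)
print(dataset[0].num_tasks())  # 2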
grover/data/molfeaturegenerator.py ADDED
@@ -0,0 +1,146 @@
+ """
+ The registered feature generator for molecules.
+ This implementation is adapted from
+ https://github.com/chemprop/chemprop/blob/master/chemprop/features/features_generators.py
+ """
+
+ from typing import Callable, List, Union
+
+ import numpy as np
+ from rdkit import Chem, DataStructs
+ from rdkit.Chem import AllChem
+
+ Molecule = Union[str, Chem.Mol]
+ FeaturesGenerator = Callable[[Molecule], np.ndarray]
+ FEATURES_GENERATOR_REGISTRY = {}
+
+
+ def register_features_generator(features_generator_name: str) -> Callable[[FeaturesGenerator], FeaturesGenerator]:
+     """
+     Registers a features generator.
+
+     :param features_generator_name: The name to call the FeaturesGenerator.
+     :return: A decorator which will add a FeaturesGenerator to the registry using the specified name.
+     """
+     def decorator(features_generator: FeaturesGenerator) -> FeaturesGenerator:
+         FEATURES_GENERATOR_REGISTRY[features_generator_name] = features_generator
+         return features_generator
+
+     return decorator
+
+
+ def get_features_generator(features_generator_name: str) -> FeaturesGenerator:
+     """
+     Gets a registered FeaturesGenerator by name.
+
+     :param features_generator_name: The name of the FeaturesGenerator.
+     :return: The desired FeaturesGenerator.
+     """
+     if features_generator_name not in FEATURES_GENERATOR_REGISTRY:
+         raise ValueError(f'Features generator "{features_generator_name}" could not be found. '
+                          f'If this generator relies on rdkit features, you may need to install descriptastorus.')
+
+     return FEATURES_GENERATOR_REGISTRY[features_generator_name]
+
+
+ def get_available_features_generators() -> List[str]:
+     """Returns the names of available features generators."""
+     return list(FEATURES_GENERATOR_REGISTRY.keys())
+
+
+ MORGAN_RADIUS = 2
+ MORGAN_NUM_BITS = 2048
+
+
+ @register_features_generator('morgan')
+ def morgan_binary_features_generator(mol: Molecule,
+                                      radius: int = MORGAN_RADIUS,
+                                      num_bits: int = MORGAN_NUM_BITS) -> np.ndarray:
+     """
+     Generates a binary Morgan fingerprint for a molecule.
+
+     :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule).
+     :param radius: Morgan fingerprint radius.
+     :param num_bits: Number of bits in Morgan fingerprint.
+     :return: A 1-D numpy array containing the binary Morgan fingerprint.
+     """
+     mol = Chem.MolFromSmiles(mol) if type(mol) == str else mol
+     features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=num_bits)
+     features = np.zeros((1,))
+     DataStructs.ConvertToNumpyArray(features_vec, features)
+
+     return features
+
+
+ @register_features_generator('morgan_count')
+ def morgan_counts_features_generator(mol: Molecule,
+                                      radius: int = MORGAN_RADIUS,
+                                      num_bits: int = MORGAN_NUM_BITS) -> np.ndarray:
+     """
+     Generates a counts-based Morgan fingerprint for a molecule.
+
+     :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule).
+     :param radius: Morgan fingerprint radius.
+     :param num_bits: Number of bits in Morgan fingerprint.
+     :return: A 1D numpy array containing the counts-based Morgan fingerprint.
+     """
+     mol = Chem.MolFromSmiles(mol) if type(mol) == str else mol
+     features_vec = AllChem.GetHashedMorganFingerprint(mol, radius, nBits=num_bits)
+     features = np.zeros((1,))
+     DataStructs.ConvertToNumpyArray(features_vec, features)
+
+     return features
+
+
+ try:
+     from descriptastorus.descriptors import rdDescriptors, rdNormalizedDescriptors
+
+     @register_features_generator('rdkit_2d')
+     def rdkit_2d_features_generator(mol: Molecule) -> np.ndarray:
+         """
+         Generates RDKit 2D features for a molecule.
+
+         :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule).
+         :return: A 1D numpy array containing the RDKit 2D features.
+         """
+         smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol
+         generator = rdDescriptors.RDKit2D()
+         features = generator.process(smiles)[1:]
+
+         return features
+
+     @register_features_generator('rdkit_2d_normalized')
+     def rdkit_2d_features_normalized_generator(mol: Molecule) -> np.ndarray:
+         """
+         Generates RDKit 2D normalized features for a molecule.
+
+         :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule).
+         :return: A 1D numpy array containing the RDKit 2D normalized features.
+         """
+         smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol
+         generator = rdNormalizedDescriptors.RDKit2DNormalized()
+         features = generator.process(smiles)[1:]
+         return features
+ except ImportError:
+     pass
+
+ """
+ Custom features generator template.
+
+ Note: The name you use to register the features generator is the name
+ you will specify on the command line when using the --features_generator <name> flag.
+ Ex. python train.py ... --features_generator custom ...
+
+ @register_features_generator('custom')
+ def custom_features_generator(mol: Molecule) -> np.ndarray:
+     # If you want to use the SMILES string
+     smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol
+
+     # If you want to use the RDKit molecule
+     mol = Chem.MolFromSmiles(mol) if type(mol) == str else mol
+
+     # Replace this with code which generates features from the molecule
+     features = np.array([0, 0, 1])
+
+     return features
+ """
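A short sketch (not part of this commit) of looking up a registered generator and featurizing a SMILES string.

# Illustration only: fetch the 'morgan' generator registered above and apply it to a SMILES string.
gen = get_features_generator('morgan')
fp = gen('CCO')  # 2048-element binary Morgan fingerprint as a numpy array
print(fp.shape, int(fp.sum()))
print(get_available_features_generators())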
grover/data/molgraph.py ADDED
@@ -0,0 +1,378 @@
+ """
+ The data structure of Molecules.
+ This implementation is adapted from
+ https://github.com/chemprop/chemprop/blob/master/chemprop/features/featurization.py
+ """
+ from argparse import Namespace
+ from typing import List, Tuple, Union
+
+ import numpy as np
+ import torch
+ from rdkit import Chem
+
+ # Atom feature sizes
+ MAX_ATOMIC_NUM = 100
+
+
+ ATOM_FEATURES = {
+     'atomic_num': list(range(MAX_ATOMIC_NUM)),
+     'degree': [0, 1, 2, 3, 4, 5],
+     'formal_charge': [-1, -2, 1, 2, 0],
+     'chiral_tag': [0, 1, 2, 3],
+     'num_Hs': [0, 1, 2, 3, 4],
+     'hybridization': [
+         Chem.rdchem.HybridizationType.SP,
+         Chem.rdchem.HybridizationType.SP2,
+         Chem.rdchem.HybridizationType.SP3,
+         Chem.rdchem.HybridizationType.SP3D,
+         Chem.rdchem.HybridizationType.SP3D2
+     ],
+ }
+
+ # len(choices) + 1 to include room for uncommon values; + 2 at end for IsAromatic and mass
+ ATOM_FDIM = sum(len(choices) + 1 for choices in ATOM_FEATURES.values()) + 2
+ BOND_FDIM = 14
+
+
+ def get_atom_fdim() -> int:
+     """
+     Gets the dimensionality of atom features.
+     """
+     return ATOM_FDIM + 18
+
+
+ def get_bond_fdim() -> int:
+     """
+     Gets the dimensionality of bond features.
+     """
+     return BOND_FDIM
+
+
+ def onek_encoding_unk(value: int, choices: List[int]) -> List[int]:
+     """
+     Creates a one-hot encoding.
+
+     :param value: The value for which the encoding should be one.
+     :param choices: A list of possible values.
+     :return: A one-hot encoding of the value in a list of length len(choices) + 1.
+              If value is not in the list of choices, then the final element in the encoding is 1.
+     """
+     encoding = [0] * (len(choices) + 1)
+     if min(choices) < 0:
+         index = value
+     else:
+         index = choices.index(value) if value in choices else -1
+     encoding[index] = 1
+
+     return encoding
+
+
+ class MolGraph:
+     """
+     A MolGraph represents the graph structure and featurization of a single molecule.
+
+     A MolGraph computes the following attributes:
+     - smiles: Smiles string.
+     - n_atoms: The number of atoms in the molecule.
+     - n_bonds: The number of bonds in the molecule.
+     - f_atoms: A mapping from an atom index to a list of atom features.
+     - f_bonds: A mapping from a bond index to a list of bond features.
+     - a2b: A mapping from an atom index to a list of incoming bond indices.
+     - b2a: A mapping from a bond index to the index of the atom the bond originates from.
+     - b2revb: A mapping from a bond index to the index of the reverse bond.
+     """
+
+     def __init__(self, smiles: str, args: Namespace):
+         """
+         Computes the graph structure and featurization of a molecule.
+
+         :param smiles: A smiles string.
+         :param args: Arguments.
+         """
+         self.smiles = smiles
+         self.args = args
+         self.n_atoms = 0  # number of atoms
+         self.n_bonds = 0  # number of bonds
+         self.f_atoms = []  # mapping from atom index to atom features
+         self.f_bonds = []  # mapping from bond index to concat(in_atom, bond) features
+         self.a2b = []  # mapping from atom index to incoming bond indices
+         self.b2a = []  # mapping from bond index to the index of the atom the bond is coming from
+         self.b2revb = []  # mapping from bond index to the index of the reverse bond
+
+         # Convert smiles to molecule
+         mol = Chem.MolFromSmiles(smiles)
+
+         self.hydrogen_donor = Chem.MolFromSmarts("[$([N;!H0;v3,v4&+1]),$([O,S;H1;+0]),n&H1&+0]")
+         self.hydrogen_acceptor = Chem.MolFromSmarts(
+             "[$([O,S;H1;v2;!$(*-*=[O,N,P,S])]),$([O,S;H0;v2]),$([O,S;-]),$([N;v3;!$(N-*=[O,N,P,S])]),"
+             "n&H0&+0,$([o,s;+0;!$([o,s]:n);!$([o,s]:c:n)])]")
+         self.acidic = Chem.MolFromSmarts("[$([C,S](=[O,S,P])-[O;H1,-1])]")
+         self.basic = Chem.MolFromSmarts(
+             "[#7;+,$([N;H2&+0][$([C,a]);!$([C,a](=O))]),$([N;H1&+0]([$([C,a]);!$([C,a](=O))])[$([C,a]);"
+             "!$([C,a](=O))]),$([N;H0&+0]([C;!$(C(=O))])([C;!$(C(=O))])[C;!$(C(=O))])]")
+
+         self.hydrogen_donor_match = sum(mol.GetSubstructMatches(self.hydrogen_donor), ())
+         self.hydrogen_acceptor_match = sum(mol.GetSubstructMatches(self.hydrogen_acceptor), ())
+         self.acidic_match = sum(mol.GetSubstructMatches(self.acidic), ())
+         self.basic_match = sum(mol.GetSubstructMatches(self.basic), ())
+         self.ring_info = mol.GetRingInfo()
+
+         # fake the number of "atoms" if we are collapsing substructures
+         self.n_atoms = mol.GetNumAtoms()
+
+         # Get atom features
+         for _, atom in enumerate(mol.GetAtoms()):
+             self.f_atoms.append(self.atom_features(atom))
+         self.f_atoms = [self.f_atoms[i] for i in range(self.n_atoms)]
+
+         for _ in range(self.n_atoms):
+             self.a2b.append([])
+
+         # Get bond features
+         for a1 in range(self.n_atoms):
+             for a2 in range(a1 + 1, self.n_atoms):
+                 bond = mol.GetBondBetweenAtoms(a1, a2)
+
+                 if bond is None:
+                     continue
+
+                 if args.bond_drop_rate > 0:
+                     if np.random.binomial(1, args.bond_drop_rate):
+                         continue
+
+                 f_bond = self.bond_features(bond)
+
+                 # Always treat the bond as directed.
+                 self.f_bonds.append(self.f_atoms[a1] + f_bond)
+                 self.f_bonds.append(self.f_atoms[a2] + f_bond)
+
+                 # Update index mappings
+                 b1 = self.n_bonds
+                 b2 = b1 + 1
+                 self.a2b[a2].append(b1)  # b1 = a1 --> a2
+                 self.b2a.append(a1)
+                 self.a2b[a1].append(b2)  # b2 = a2 --> a1
+                 self.b2a.append(a2)
+                 self.b2revb.append(b2)
+                 self.b2revb.append(b1)
+                 self.n_bonds += 2
+
+     def atom_features(self, atom: Chem.rdchem.Atom) -> List[Union[bool, int, float]]:
+         """
+         Builds a feature vector for an atom.
+
+         :param atom: An RDKit atom.
+         :return: A list containing the atom features.
+         """
+         features = onek_encoding_unk(atom.GetAtomicNum() - 1, ATOM_FEATURES['atomic_num']) + \
+                    onek_encoding_unk(atom.GetTotalDegree(), ATOM_FEATURES['degree']) + \
+                    onek_encoding_unk(atom.GetFormalCharge(), ATOM_FEATURES['formal_charge']) + \
+                    onek_encoding_unk(int(atom.GetChiralTag()), ATOM_FEATURES['chiral_tag']) + \
+                    onek_encoding_unk(int(atom.GetTotalNumHs()), ATOM_FEATURES['num_Hs']) + \
+                    onek_encoding_unk(int(atom.GetHybridization()), ATOM_FEATURES['hybridization']) + \
+                    [1 if atom.GetIsAromatic() else 0] + \
+                    [atom.GetMass() * 0.01]
+         atom_idx = atom.GetIdx()
+         features = features + \
+                    onek_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6]) + \
+                    [atom_idx in self.hydrogen_acceptor_match] + \
+                    [atom_idx in self.hydrogen_donor_match] + \
+                    [atom_idx in self.acidic_match] + \
+                    [atom_idx in self.basic_match] + \
+                    [self.ring_info.IsAtomInRingOfSize(atom_idx, 3),
+                     self.ring_info.IsAtomInRingOfSize(atom_idx, 4),
+                     self.ring_info.IsAtomInRingOfSize(atom_idx, 5),
+                     self.ring_info.IsAtomInRingOfSize(atom_idx, 6),
+                     self.ring_info.IsAtomInRingOfSize(atom_idx, 7),
+                     self.ring_info.IsAtomInRingOfSize(atom_idx, 8)]
+         return features
+
+     def bond_features(self, bond: Chem.rdchem.Bond) -> List[Union[bool, int, float]]:
+         """
+         Builds a feature vector for a bond.
+
+         :param bond: An RDKit bond.
+         :return: A list containing the bond features.
+         """
+         if bond is None:
+             fbond = [1] + [0] * (BOND_FDIM - 1)
+         else:
+             bt = bond.GetBondType()
+             fbond = [
+                 0,  # bond is not None
+                 bt == Chem.rdchem.BondType.SINGLE,
+                 bt == Chem.rdchem.BondType.DOUBLE,
+                 bt == Chem.rdchem.BondType.TRIPLE,
+                 bt == Chem.rdchem.BondType.AROMATIC,
+                 (bond.GetIsConjugated() if bt is not None else 0),
+                 (bond.IsInRing() if bt is not None else 0)
+             ]
+             fbond += onek_encoding_unk(int(bond.GetStereo()), list(range(6)))
+         return fbond
+
+
+ class BatchMolGraph:
+     """
+     A BatchMolGraph represents the graph structure and featurization of a batch of molecules.
+
+     A BatchMolGraph contains the attributes of a MolGraph plus:
+     - smiles_batch: A list of smiles strings.
+     - n_mols: The number of molecules in the batch.
+     - atom_fdim: The dimensionality of the atom features.
+     - bond_fdim: The dimensionality of the bond features (technically the combined atom/bond features).
+     - a_scope: A list of tuples indicating the start and end atom indices for each molecule.
+     - b_scope: A list of tuples indicating the start and end bond indices for each molecule.
+     - max_num_bonds: The maximum number of bonds neighboring an atom in this batch.
+     - b2b: (Optional) A mapping from a bond index to incoming bond indices.
+     - a2a: (Optional) A mapping from an atom index to neighboring atom indices.
+     """
+
+     def __init__(self, mol_graphs: List[MolGraph], args: Namespace):
+         self.smiles_batch = [mol_graph.smiles for mol_graph in mol_graphs]
+         self.n_mols = len(self.smiles_batch)
+
+         self.atom_fdim = get_atom_fdim()
+         self.bond_fdim = get_bond_fdim() + self.atom_fdim
+
+         # Start n_atoms and n_bonds at 1 b/c zero padding
+         self.n_atoms = 1  # number of atoms (start at 1 b/c need index 0 as padding)
+         self.n_bonds = 1  # number of bonds (start at 1 b/c need index 0 as padding)
+         self.a_scope = []  # list of tuples indicating (start_atom_index, num_atoms) for each molecule
+         self.b_scope = []  # list of tuples indicating (start_bond_index, num_bonds) for each molecule
+
+         # All start with zero padding so that indexing with zero padding returns zeros
+         f_atoms = [[0] * self.atom_fdim]  # atom features
+         f_bonds = [[0] * self.bond_fdim]  # combined atom/bond features
+         a2b = [[]]  # mapping from atom index to incoming bond indices
+         b2a = [0]  # mapping from bond index to the index of the atom the bond is coming from
+         b2revb = [0]  # mapping from bond index to the index of the reverse bond
+
+         for mol_graph in mol_graphs:
+             f_atoms.extend(mol_graph.f_atoms)
+             f_bonds.extend(mol_graph.f_bonds)
+
+             for a in range(mol_graph.n_atoms):
+                 a2b.append([b + self.n_bonds for b in mol_graph.a2b[a]])
+
+             for b in range(mol_graph.n_bonds):
+                 b2a.append(self.n_atoms + mol_graph.b2a[b])
+                 b2revb.append(self.n_bonds + mol_graph.b2revb[b])
+
+             self.a_scope.append((self.n_atoms, mol_graph.n_atoms))
+             self.b_scope.append((self.n_bonds, mol_graph.n_bonds))
+             self.n_atoms += mol_graph.n_atoms
+             self.n_bonds += mol_graph.n_bonds
+
+         # max with 1 to fix a crash in rare case of all single-heavy-atom mols
+         self.max_num_bonds = max(1, max(len(in_bonds) for in_bonds in a2b))
+
+         self.f_atoms = torch.FloatTensor(f_atoms)
+         self.f_bonds = torch.FloatTensor(f_bonds)
+         self.a2b = torch.LongTensor([a2b[a] + [0] * (self.max_num_bonds - len(a2b[a])) for a in range(self.n_atoms)])
+         self.b2a = torch.LongTensor(b2a)
+         self.b2revb = torch.LongTensor(b2revb)
+         self.b2b = None  # try to avoid computing b2b b/c O(n_atoms^3)
+         self.a2a = self.b2a[self.a2b]  # only needed if using atom messages
+         self.a_scope = torch.LongTensor(self.a_scope)
+         self.b_scope = torch.LongTensor(self.b_scope)
+
+     def set_new_atom_feature(self, f_atoms):
+         """
+         Sets the new atom features. Does not update the bond features.
+         :param f_atoms:
+         """
+         self.f_atoms = f_atoms
+
+     def get_components(self) -> Tuple[torch.FloatTensor, torch.FloatTensor,
+                                       torch.LongTensor, torch.LongTensor, torch.LongTensor,
+                                       List[Tuple[int, int]], List[Tuple[int, int]]]:
+         """
+         Returns the components of the BatchMolGraph.
+
+         :return: A tuple containing PyTorch tensors with the atom features, bond features, and graph structure
+                  and two lists indicating the scope of the atoms and bonds (i.e. which molecules they belong to).
+         """
+         return self.f_atoms, self.f_bonds, self.a2b, self.b2a, self.b2revb, self.a_scope, self.b_scope, self.a2a
+
+     def get_b2b(self) -> torch.LongTensor:
+         """
+         Computes (if necessary) and returns a mapping from each bond index to all the incoming bond indices.
+
+         :return: A PyTorch tensor containing the mapping from each bond index to all the incoming bond indices.
+         """
+         if self.b2b is None:
+             b2b = self.a2b[self.b2a]  # num_bonds x max_num_bonds
+             # b2b includes the reverse edge for each bond, so it needs to be masked out
+             revmask = (b2b != self.b2revb.unsqueeze(1).repeat(1, b2b.size(1))).long()  # num_bonds x max_num_bonds
+             self.b2b = b2b * revmask
+
+         return self.b2b
+
+     def get_a2a(self) -> torch.LongTensor:
+         """
+         Computes (if necessary) and returns a mapping from each atom index to all neighboring atom indices.
+
+         :return: A PyTorch tensor containing the mapping from each atom index to all neighboring atom indices.
+         """
+         if self.a2a is None:
+             # b = a1 --> a2
+             # a2b maps a2 to all incoming bonds b
+             # b2a maps each bond b to the atom it comes from a1
+             # thus b2a[a2b] maps atom a2 to neighboring atoms a1
+             self.a2a = self.b2a[self.a2b]  # num_atoms x max_num_bonds
+
+         return self.a2a
+
+
+ def mol2graph(smiles_batch: List[str], shared_dict,
+               args: Namespace) -> BatchMolGraph:
+     """
+     Converts a list of SMILES strings to a BatchMolGraph containing the batch of molecular graphs.
+
+     :param smiles_batch: A list of SMILES strings.
+     :param args: Arguments.
+     :return: A BatchMolGraph containing the combined molecular graph for the molecules.
+     """
+     mol_graphs = []
+     for smiles in smiles_batch:
+         if smiles in shared_dict:
+             mol_graph = shared_dict[smiles]
+         else:
+             mol_graph = MolGraph(smiles, args)
+             if not args.no_cache:
+                 shared_dict[smiles] = mol_graph
+         mol_graphs.append(mol_graph)
+
+     return BatchMolGraph(mol_graphs, args)
+
+
+ class MolCollator(object):
+     """
+     Collator for the PyTorch DataLoader.
+     :param shared_dict: a shared dict for multiprocessing.
+     :param args: Arguments.
+     """
+     def __init__(self, shared_dict, args):
+         self.args = args
+         self.shared_dict = shared_dict
+
+     def __call__(self, batch):
+         smiles_batch = [d.smiles for d in batch]
+         features_batch = [d.features for d in batch]
+         target_batch = [d.targets for d in batch]
+         batch_mol_graph = mol2graph(smiles_batch, self.shared_dict, self.args)
+         batch = batch_mol_graph.get_components()
+         mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
+         targets = torch.Tensor([[0 if x is None else x for x in tb] for tb in target_batch])
+         return smiles_batch, batch, features_batch, mask, targets
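A small sketch (not part of this commit) of building a batched graph; the Namespace fields below (bond_drop_rate, no_cache) appear to be the only ones these code paths read.

# Illustration only: featurize two molecules into one BatchMolGraph.
from argparse import Namespace

args = Namespace(bond_drop_rate=0, no_cache=True)
batch = mol2graph(["CCO", "c1ccccc1"], shared_dict={}, args=args)
f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a = batch.get_components()
print(f_atoms.shape)  # (1 + total atoms) x get_atom_fdim(), row 0 is zero padding
print(a_scope)        # per-molecule (start_atom_index, num_atoms)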
grover/data/scaler.py ADDED
@@ -0,0 +1,70 @@
+ """
+ The scaler for the regression task.
+ This implementation is adapted from
+ https://github.com/chemprop/chemprop/blob/master/chemprop/data/scaler.py
+ """
+ from typing import Any, List
+ import numpy as np
+
+
+ class StandardScaler:
+     """A StandardScaler normalizes a dataset.
+
+     When fit on a dataset, the StandardScaler learns the mean and standard deviation across the 0th axis.
+     When transforming a dataset, the StandardScaler subtracts the means and divides by the standard deviations.
+     """
+
+     def __init__(self, means: np.ndarray = None, stds: np.ndarray = None, replace_nan_token: Any = None):
+         """
+         Initialize StandardScaler, optionally with means and standard deviations precomputed.
+
+         :param means: An optional 1D numpy array of precomputed means.
+         :param stds: An optional 1D numpy array of precomputed standard deviations.
+         :param replace_nan_token: The token to use in place of nans.
+         """
+         self.means = means
+         self.stds = stds
+         self.replace_nan_token = replace_nan_token
+
+     def fit(self, X: List[List[float]]) -> 'StandardScaler':
+         """
+         Learns means and standard deviations across the 0th axis.
+
+         :param X: A list of lists of floats.
+         :return: The fitted StandardScaler.
+         """
+         X = np.array(X).astype(float)
+         self.means = np.nanmean(X, axis=0)
+         self.stds = np.nanstd(X, axis=0)
+         self.means = np.where(np.isnan(self.means), np.zeros(self.means.shape), self.means)
+         self.stds = np.where(np.isnan(self.stds), np.ones(self.stds.shape), self.stds)
+         self.stds = np.where(self.stds == 0, np.ones(self.stds.shape), self.stds)
+
+         return self
+
+     def transform(self, X: List[List[float]]):
+         """
+         Transforms the data by subtracting the means and dividing by the standard deviations.
+
+         :param X: A list of lists of floats.
+         :return: The transformed data.
+         """
+         X = np.array(X).astype(float)
+         transformed_with_nan = (X - self.means) / self.stds
+         transformed_with_none = np.where(np.isnan(transformed_with_nan), self.replace_nan_token, transformed_with_nan)
+
+         return transformed_with_none
+
+     def inverse_transform(self, X: List[List[float]]):
+         """
+         Performs the inverse transformation by multiplying by the standard deviations and adding the means.
+
+         :param X: A list of lists of floats.
+         :return: The inverse transformed data.
+         """
+         if isinstance(X, np.ndarray) or isinstance(X, list):
+             X = np.array(X).astype(float)
+         transformed_with_nan = X * self.stds + self.means
+         transformed_with_none = np.where(np.isnan(transformed_with_nan),
+                                          self.replace_nan_token, transformed_with_nan)
+         return transformed_with_none
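A minimal sketch (not part of this commit) of fitting the scaler and round-tripping a small feature matrix.

# Illustration only: fit on two feature rows, then transform and invert.
import numpy as np

scaler = StandardScaler(replace_nan_token=0)
X = np.array([[1.0, 2.0], [3.0, 6.0]])
scaler.fit(X)
Z = scaler.transform(X)             # zero mean, unit variance per column
print(Z)
print(scaler.inverse_transform(Z))  # recovers the original values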
grover/data/task_labels.py ADDED
@@ -0,0 +1,116 @@
+ """
+ The label generator for the pretraining.
+ """
+ from collections import Counter
+ from typing import Callable, Union
+
+ import numpy as np
+ from rdkit import Chem
+ from descriptastorus.descriptors import rdDescriptors
+
+ from grover.data.molfeaturegenerator import register_features_generator
+
+ Molecule = Union[str, Chem.Mol]
+ FeaturesGenerator = Callable[[Molecule], np.ndarray]
+
+ # The functional group descriptors in RDKit.
+ RDKIT_PROPS = ['fr_Al_COO', 'fr_Al_OH', 'fr_Al_OH_noTert', 'fr_ArN',
+                'fr_Ar_COO', 'fr_Ar_N', 'fr_Ar_NH', 'fr_Ar_OH', 'fr_COO', 'fr_COO2',
+                'fr_C_O', 'fr_C_O_noCOO', 'fr_C_S', 'fr_HOCCN', 'fr_Imine', 'fr_NH0',
+                'fr_NH1', 'fr_NH2', 'fr_N_O', 'fr_Ndealkylation1', 'fr_Ndealkylation2',
+                'fr_Nhpyrrole', 'fr_SH', 'fr_aldehyde', 'fr_alkyl_carbamate', 'fr_alkyl_halide',
+                'fr_allylic_oxid', 'fr_amide', 'fr_amidine', 'fr_aniline', 'fr_aryl_methyl',
+                'fr_azide', 'fr_azo', 'fr_barbitur', 'fr_benzene', 'fr_benzodiazepine',
+                'fr_bicyclic', 'fr_diazo', 'fr_dihydropyridine', 'fr_epoxide', 'fr_ester',
+                'fr_ether', 'fr_furan', 'fr_guanido', 'fr_halogen', 'fr_hdrzine', 'fr_hdrzone',
+                'fr_imidazole', 'fr_imide', 'fr_isocyan', 'fr_isothiocyan', 'fr_ketone',
+                'fr_ketone_Topliss', 'fr_lactam', 'fr_lactone', 'fr_methoxy', 'fr_morpholine',
+                'fr_nitrile', 'fr_nitro', 'fr_nitro_arom', 'fr_nitro_arom_nonortho',
+                'fr_nitroso', 'fr_oxazole', 'fr_oxime', 'fr_para_hydroxylation', 'fr_phenol',
+                'fr_phenol_noOrthoHbond', 'fr_phos_acid', 'fr_phos_ester', 'fr_piperdine',
+                'fr_piperzine', 'fr_priamide', 'fr_prisulfonamd', 'fr_pyridine', 'fr_quatN',
+                'fr_sulfide', 'fr_sulfonamd', 'fr_sulfone', 'fr_term_acetylene', 'fr_tetrazole',
+                'fr_thiazole', 'fr_thiocyan', 'fr_thiophene', 'fr_unbrch_alkane', 'fr_urea']
+
+ BOND_FEATURES = ['BondType', 'Stereo', 'BondDir']
+
+
+ # BOND_FEATURES = ['BondType', 'Stereo']
+ # BOND_FEATURES = ['Stereo']
+
+ @register_features_generator('fgtasklabel')
+ def rdkit_functional_group_label_features_generator(mol: Molecule) -> np.ndarray:
+     """
+     Generates a functional group label for a molecule using RDKit.
+
+     :param mol: A molecule (i.e. either a SMILES string or an RDKit molecule).
+     :return: A 1D numpy array containing the binary functional group labels.
+     """
+     smiles = Chem.MolToSmiles(mol, isomericSmiles=True) if type(mol) != str else mol
+     generator = rdDescriptors.RDKit2D(RDKIT_PROPS)
+     features = generator.process(smiles)[1:]
+     features = np.array(features)
+     features[features != 0] = 1
+     return features
+
+
+ def atom_to_vocab(mol, atom):
+     """
+     Convert an atom to a vocabulary entry. The convention is based on atom type and bond type.
+     :param mol: the molecule.
+     :param atom: the target atom.
+     :return: the generated atom vocabulary entry with its contexts.
+     """
+     nei = Counter()
+     for a in atom.GetNeighbors():
+         bond = mol.GetBondBetweenAtoms(atom.GetIdx(), a.GetIdx())
+         nei[str(a.GetSymbol()) + "-" + str(bond.GetBondType())] += 1
+     keys = nei.keys()
+     keys = list(keys)
+     keys.sort()
+     output = atom.GetSymbol()
+     for k in keys:
+         output = "%s_%s%d" % (output, k, nei[k])
+
+     # The generated atom_vocab is too long?
+     return output
+
+
+ def bond_to_vocab(mol, bond):
+     """
+     Convert a bond to a vocabulary entry. The convention is based on atom type and bond type.
+     Considers one-hop neighbor atoms.
+     :param mol: the molecule.
+     :param bond: the target bond.
+     :return: the generated bond vocabulary entry with its contexts.
+     """
+     nei = Counter()
+     two_neighbors = (bond.GetBeginAtom(), bond.GetEndAtom())
+     two_indices = [a.GetIdx() for a in two_neighbors]
+     for nei_atom in two_neighbors:
+         for a in nei_atom.GetNeighbors():
+             a_idx = a.GetIdx()
+             if a_idx in two_indices:
+                 continue
+             tmp_bond = mol.GetBondBetweenAtoms(nei_atom.GetIdx(), a_idx)
+             nei[str(nei_atom.GetSymbol()) + '-' + get_bond_feature_name(tmp_bond)] += 1
+     keys = list(nei.keys())
+     keys.sort()
+     output = get_bond_feature_name(bond)
+     for k in keys:
+         output = "%s_%s%d" % (output, k, nei[k])
+     return output
+
+
+ def get_bond_feature_name(bond):
+     """
+     Return the string format of the bond features.
+     Bond features are surrounded with ().
+     """
+     ret = []
+     for bond_feature in BOND_FEATURES:
+         fea = eval(f"bond.Get{bond_feature}")()
+         ret.append(str(fea))
+
+     return '(' + '-'.join(ret) + ')'
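A short sketch (not part of this commit) of the contextual vocabulary strings these helpers produce for a simple molecule.

# Illustration only: context vocabulary entries for ethanol's atoms and bonds.
from rdkit import Chem

mol = Chem.MolFromSmiles("CCO")
print([atom_to_vocab(mol, atom) for atom in mol.GetAtoms()])
# e.g. ['C_C-SINGLE1', 'C_C-SINGLE1_O-SINGLE1', 'O_C-SINGLE1']
print([bond_to_vocab(mol, bond) for bond in mol.GetBonds()])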
grover/data/torchvocab.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The contextual property.
3
+ """
4
+ import pickle
5
+ from collections import Counter
6
+ from multiprocessing import Pool
7
+
8
+ import tqdm
9
+ from rdkit import Chem
10
+
11
+ from grover.data.task_labels import atom_to_vocab
12
+ from grover.data.task_labels import bond_to_vocab
13
+
14
+
15
+ class TorchVocab(object):
16
+ """
17
+ Defines the vocabulary for atoms/bonds in molecular.
18
+ """
19
+
20
+ def __init__(self, counter, max_size=None, min_freq=1, specials=('<pad>', '<other>'), vocab_type='atom'):
21
+ """
22
+
23
+ :param counter:
24
+ :param max_size:
25
+ :param min_freq:
26
+ :param specials:
27
+ :param vocab_type: 'atom': atom atom_vocab; 'bond': bond atom_vocab.
28
+ """
29
+ self.freqs = counter
30
+ counter = counter.copy()
31
+ min_freq = max(min_freq, 1)
32
+ if vocab_type in ('atom', 'bond'):
33
+ self.vocab_type = vocab_type
34
+ else:
35
+ raise ValueError('Wrong input for vocab_type!')
36
+ self.itos = list(specials)
37
+
38
+ max_size = None if max_size is None else max_size + len(self.itos)
39
+ # sort by frequency, then alphabetically
40
+ words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
41
+ words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
42
+
43
+ for word, freq in words_and_frequencies:
44
+ if freq < min_freq or len(self.itos) == max_size:
45
+ break
46
+ self.itos.append(word)
47
+ # stoi is simply a reverse dict for itos
48
+ self.stoi = {tok: i for i, tok in enumerate(self.itos)}
49
+ self.other_index = 1
50
+ self.pad_index = 0
51
+
52
+ def __eq__(self, other):
53
+ if self.freqs != other.freqs:
54
+ return False
55
+ if self.stoi != other.stoi:
56
+ return False
57
+ if self.itos != other.itos:
58
+ return False
59
+ # if self.vectors != other.vectors:
60
+ # return False
61
+ return True
62
+
63
+ def __len__(self):
64
+ return len(self.itos)
65
+
66
+ def vocab_rerank(self):
67
+ self.stoi = {word: i for i, word in enumerate(self.itos)}
68
+
69
+ def extend(self, v, sort=False):
70
+ words = sorted(v.itos) if sort else v.itos
71
+ for w in words:
72
+ if w not in self.stoi:
73
+ self.itos.append(w)
74
+ self.stoi[w] = len(self.itos) - 1
75
+ self.freqs[w] = 0
76
+ self.freqs[w] += v.freqs[w]
77
+
78
+ def mol_to_seq(self, mol, with_len=False):
79
+ mol = Chem.MolFromSmiles(mol) if type(mol) == str else mol
80
+ if self.vocab_type == 'atom':
81
+ seq = [self.stoi.get(atom_to_vocab(mol, atom), self.other_index) for i, atom in enumerate(mol.GetAtoms())]
82
+ else:
83
+ seq = [self.stoi.get(bond_to_vocab(mol, bond), self.other_index) for i, bond in enumerate(mol.GetBonds())]
84
+ return (seq, len(seq)) if with_len else seq
85
+
86
+ @staticmethod
87
+ def load_vocab(vocab_path: str) -> 'TorchVocab':
88
+ with open(vocab_path, "rb") as f:
89
+ return pickle.load(f)
90
+
91
+ def save_vocab(self, vocab_path):
92
+ with open(vocab_path, "wb") as f:
93
+ pickle.dump(self, f)
94
+
95
+
96
+ class MolVocab(TorchVocab):
97
+ def __init__(self, smiles, max_size=None, min_freq=1, vocab_type='atom'):
98
+ if vocab_type in ('atom', 'bond'):
99
+ self.vocab_type = vocab_type
100
+ else:
101
+ raise ValueError('Wrong input for vocab_type!')
102
+
103
+ print("Building %s vocab from smiles: %d" % (self.vocab_type, len(smiles)))
104
+ counter = Counter()
105
+
106
+ for smi in tqdm.tqdm(smiles):
107
+ mol = Chem.MolFromSmiles(smi)
108
+ if self.vocab_type == 'atom':
109
+ for _, atom in enumerate(mol.GetAtoms()):
110
+ v = atom_to_vocab(mol, atom)
111
+ counter[v] += 1
112
+ else:
113
+ for _, bond in enumerate(mol.GetBonds()):
114
+ v = bond_to_vocab(mol, bond)
115
+ counter[v] += 1
116
+ super().__init__(counter, max_size=max_size, min_freq=min_freq, vocab_type=vocab_type)
117
+
118
+ def __init__(self, file_path, max_size=None, min_freq=1, num_workers=1, total_lines=None, vocab_type='atom'):
119
+ if vocab_type in ('atom', 'bond'):
120
+ self.vocab_type = vocab_type
121
+ else:
122
+ raise ValueError('Wrong input for vocab_type!')
123
+ print("Building %s vocab from file: %s" % (self.vocab_type, file_path))
124
+
125
+ from rdkit import RDLogger
126
+ lg = RDLogger.logger()
127
+ lg.setLevel(RDLogger.CRITICAL)
128
+
129
+ if total_lines is None:
130
+ def file_len(fname):
131
+ f_len = 0
132
+ with open(fname) as f:
133
+ for f_len, _ in enumerate(f):
134
+ pass
135
+ return f_len + 1
136
+
137
+ total_lines = file_len(file_path)
138
+
139
+ counter = Counter()
140
+ pbar = tqdm.tqdm(total=total_lines)
141
+ pool = Pool(num_workers)
142
+ res = []
143
+ batch = 50000
144
+ callback = lambda a: pbar.update(batch)
145
+ for i in range(int(total_lines / batch + 1)):
146
+ start = int(batch * i)
147
+ end = min(total_lines, batch * (i + 1))
148
+ # print("Start: %d, End: %d"%(start, end))
149
+ res.append(pool.apply_async(MolVocab.read_smiles_from_file,
150
+ args=(file_path, start, end, vocab_type,),
151
+ callback=callback))
152
+ # read_smiles_from_file(lock, file_path, start, end)
153
+ pool.close()
154
+ pool.join()
155
+ for r in res:
156
+ sub_counter = r.get()
157
+ for k in sub_counter:
158
+ if k not in counter:
159
+ counter[k] = 0
160
+ counter[k] += sub_counter[k]
161
+ # print(counter)
162
+ super().__init__(counter, max_size=max_size, min_freq=min_freq, vocab_type=vocab_type)
163
+
164
+ @staticmethod
165
+ def read_smiles_from_file(file_path, start, end, vocab_type):
166
+ # print("start")
167
+ smiles = open(file_path, "r")
168
+ smiles.readline()
169
+ sub_counter = Counter()
170
+ for i, smi in enumerate(smiles):
171
+ if i < start:
172
+ continue
173
+ if i >= end:
174
+ break
175
+ mol = Chem.MolFromSmiles(smi)
176
+ if vocab_type == 'atom':
177
+ for atom in mol.GetAtoms():
178
+ v = atom_to_vocab(mol, atom)
179
+ sub_counter[v] += 1
180
+ else:
181
+ for bond in mol.GetBonds():
182
+ v = bond_to_vocab(mol, bond)
183
+ sub_counter[v] += 1
184
+ # print("end")
185
+ return sub_counter
186
+
187
+ @staticmethod
188
+ def load_vocab(vocab_path: str) -> 'MolVocab':
189
+ with open(vocab_path, "rb") as f:
190
+ return pickle.load(f)
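Note that MolVocab defines __init__ twice, so only the second, file-based constructor is active at runtime; it expects a text file whose first line is a header and whose remaining lines are SMILES strings, and it counts atom or bond contexts in 50 000-line chunks via a multiprocessing pool. A minimal usage sketch (the file names are hypothetical, not part of the commit):

# Editorial sketch, not part of the commit.
from grover.data.torchvocab import MolVocab

atom_vocab = MolVocab("smiles.csv", min_freq=1, num_workers=4, vocab_type="atom")
print(len(atom_vocab))                      # entries including the '<pad>' and '<other>' specials
atom_vocab.save_vocab("atom_vocab.pkl")

# Reload later and map a molecule to vocabulary indices; unseen atoms fall back to '<other>'.
vocab = MolVocab.load_vocab("atom_vocab.pkl")
print(vocab.mol_to_seq("CCO"))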
grover/model/layers.py ADDED
@@ -0,0 +1,902 @@
1
+ """
2
+ The basic building blocks in model.
3
+ """
4
+ import math
5
+ from argparse import Namespace
6
+ from typing import Union
7
+
8
+ import numpy
9
+ import scipy.stats as stats
10
+ import torch
11
+ from torch import nn as nn
12
+ from torch.nn import LayerNorm, functional as F
13
+
14
+ from grover.util.nn_utils import get_activation_function, select_neighbor_and_aggregate
15
+
16
+
17
+ class SelfAttention(nn.Module):
18
+ """
19
+ Self-attention layer.
20
+ Given $X \in \mathbb{R}^{n \times in\_feature}$, the attention is calculated by $a = Softmax(W_2 \tanh(W_1 X^T))$, where
21
+ $W_1 \in \mathbb{R}^{hidden \times in\_feature}$ and $W_2 \in \mathbb{R}^{out\_feature \times hidden}$.
22
+ The final output is $out = aX$, whose size is independent of the number of inputs $n$.
23
+ """
24
+
25
+ def __init__(self, *, hidden, in_feature, out_feature):
26
+ """
27
+ The init function.
28
+ :param hidden: the hidden dimension, can be viewed as the number of experts.
29
+ :param in_feature: the input feature dimension.
30
+ :param out_feature: the output feature dimension.
31
+ """
32
+ super(SelfAttention, self).__init__()
33
+ self.w1 = torch.nn.Parameter(torch.FloatTensor(hidden, in_feature))
34
+ self.w2 = torch.nn.Parameter(torch.FloatTensor(out_feature, hidden))
35
+ self.reset_parameters()
36
+
37
+ def reset_parameters(self):
38
+ """
39
+ Use xavier_normal method to initialize parameters.
40
+ """
41
+ nn.init.xavier_normal_(self.w1)
42
+ nn.init.xavier_normal_(self.w2)
43
+
44
+ def forward(self, X):
45
+ """
46
+ The forward function.
47
+ :param X: The input feature map. $X \in \mathbb{R}^{n \times in_feature}$.
48
+ :return: The final embeddings and attention matrix.
49
+ """
50
+ x = torch.tanh(torch.matmul(self.w1, X.transpose(1, 0)))
51
+ x = torch.matmul(self.w2, x)
52
+ attn = torch.nn.functional.softmax(x, dim=-1)
53
+ x = torch.matmul(attn, X)
54
+ return x, attn
55
+
56
+
57
+ class Readout(nn.Module):
58
+ """The readout function. Convert the node embeddings to the graph embeddings."""
59
+
60
+ def __init__(self,
61
+ rtype: str = "none",
62
+ hidden_size: int = 0,
63
+ attn_hidden: int = None,
64
+ attn_out: int = None,
65
+ ):
66
+ """
67
+ The readout function.
68
+ :param rtype: readout type, can be "mean" and "self_attention".
69
+ :param hidden_size: input hidden size
70
+ :param attn_hidden: only valid if rtype == "self_attention". The attention hidden size.
71
+ :param attn_out: only valid if rtype == "self_attention". The attention out size.
72
+ :param args: legacy use.
73
+ """
74
+ super(Readout, self).__init__()
75
+ # Cached zeros
76
+ self.cached_zero_vector = nn.Parameter(torch.zeros(hidden_size), requires_grad=False)
77
+ self.rtype = "mean"
78
+
79
+ if rtype == "self_attention":
80
+ self.attn = SelfAttention(hidden=attn_hidden,
81
+ in_feature=hidden_size,
82
+ out_feature=attn_out)
83
+ self.rtype = "self_attention"
84
+
85
+ def forward(self, embeddings, scope):
86
+ """
87
+ The forward function, given a batch node/edge embedding and a scope list,
88
+ produce the graph-level embedding by a scope.
89
+ :param embeddings: The embedding matrix, num_atoms or num_bonds \times hidden_size.
90
+ :param scope: a list whose elements are pairs [start, size]: `start` is the index of the first atom/bond of a molecule in `embeddings` and `size` is the number of its atoms/bonds.
91
+ :return:
92
+ """
93
+ # Readout
94
+ mol_vecs = []
95
+ self.attns = []
96
+ for _, (a_start, a_size) in enumerate(scope):
97
+ if a_size == 0:
98
+ mol_vecs.append(self.cached_zero_vector)
99
+ else:
100
+ cur_hiddens = embeddings.narrow(0, a_start, a_size)
101
+ if self.rtype == "self_attention":
102
+ cur_hiddens, attn = self.attn(cur_hiddens)
103
+ cur_hiddens = cur_hiddens.flatten()
104
+ # Temporarily disable. Enable it if you want to save attentions.
105
+ # self.attns.append(attn.cpu().detach().numpy())
106
+ else:
107
+ cur_hiddens = cur_hiddens.sum(dim=0) / a_size
108
+ mol_vecs.append(cur_hiddens)
109
+
110
+ mol_vecs = torch.stack(mol_vecs, dim=0) # (num_molecules, hidden_size)
111
+ return mol_vecs
112
+
113
+
114
+ class MPNEncoder(nn.Module):
115
+ """A message passing neural network for encoding a molecule."""
116
+
117
+ def __init__(self, args: Namespace,
118
+ atom_messages: bool,
119
+ init_message_dim: int,
120
+ attached_fea_fdim: int,
121
+ hidden_size: int,
122
+ bias: bool,
123
+ depth: int,
124
+ dropout: float,
125
+ undirected: bool,
126
+ dense: bool,
127
+ aggregate_to_atom: bool,
128
+ attach_fea: bool,
129
+ input_layer="fc",
130
+ dynamic_depth='none'
131
+ ):
132
+ """
133
+ Initializes the MPNEncoder.
134
+ :param args: the arguments.
135
+ :param atom_messages: enables atom_messages or not.
136
+ :param init_message_dim: the initial input message dimension.
137
+ :param attached_fea_fdim: the attached feature dimension.
138
+ :param hidden_size: the output message dimension during message passing.
139
+ :param bias: the bias in the message passing.
140
+ :param depth: the message passing depth.
141
+ :param dropout: the dropout rate.
142
+ :param undirected: the message passing is undirected or not.
143
+ :param dense: enables the dense connections.
144
+ :param attach_fea: enables the feature attachment during the message passing process.
145
+ :param dynamic_depth: enables the dynamic depth. Possible choices: "none", "uniform" and "truncnorm"
146
+ """
147
+ super(MPNEncoder, self).__init__()
148
+ self.init_message_dim = init_message_dim
149
+ self.attached_fea_fdim = attached_fea_fdim
150
+ self.hidden_size = hidden_size
151
+ self.bias = bias
152
+ self.depth = depth
153
+ self.dropout = dropout
154
+ self.input_layer = input_layer
155
+ self.layers_per_message = 1
156
+ self.undirected = undirected
157
+ self.atom_messages = atom_messages
158
+ self.dense = dense
159
+ self.aggreate_to_atom = aggregate_to_atom
160
+ self.attached_fea = attach_fea
161
+ self.dynamic_depth = dynamic_depth
162
+
163
+ # Dropout
164
+ self.dropout_layer = nn.Dropout(p=self.dropout)
165
+
166
+ # Activation
167
+ self.act_func = get_activation_function(args.activation)
168
+
169
+ # Input
170
+ if self.input_layer == "fc":
171
+ input_dim = self.init_message_dim
172
+ self.W_i = nn.Linear(input_dim, self.hidden_size, bias=self.bias)
173
+
174
+ if self.attached_fea:
175
+ w_h_input_size = self.hidden_size + self.attached_fea_fdim
176
+ else:
177
+ w_h_input_size = self.hidden_size
178
+
179
+ # Shared weight matrix across depths (default)
180
+ self.W_h = nn.Linear(w_h_input_size, self.hidden_size, bias=self.bias)
181
+
182
+ def forward(self,
183
+ init_messages,
184
+ init_attached_features,
185
+ a2nei,
186
+ a2attached,
187
+ b2a=None,
188
+ b2revb=None,
189
+ adjs=None
190
+ ) -> torch.FloatTensor:
191
+ """
192
+ The forward function.
193
+ :param init_messages: initial massages, can be atom features or bond features.
194
+ :param init_attached_features: initial attached_features.
195
+ :param a2nei: the relation of item to its neighbors. For the atom message passing, a2nei = a2a. For bond
196
+ messages a2nei = a2b
197
+ :param a2attached: the relation of item to the attached features during message passing. For the atom message
198
+ passing, a2attached = a2b. For the bond message passing a2attached = a2a
199
+ :param b2a: remove the reversed bond in bond message passing
200
+ :param b2revb: remove the revered atom in bond message passing
201
+ :return: if aggreate_to_atom or self.atom_messages, return num_atoms x hidden.
202
+ Otherwise, return num_bonds x hidden
203
+ """
204
+
205
+ # Input
206
+ if self.input_layer == 'fc':
207
+ input = self.W_i(init_messages) # num_bonds x hidden_size # f_bond
208
+ message = self.act_func(input) # num_bonds x hidden_size
209
+ elif self.input_layer == 'none':
210
+ input = init_messages
211
+ message = input
212
+
213
+ attached_fea = init_attached_features # f_atom / f_bond
214
+
215
+ # dynamic depth
216
+ # uniform sampling from depth - 1 to depth + 1
217
+ # only works in training.
218
+ if self.training and self.dynamic_depth != "none":
219
+ if self.dynamic_depth == "uniform":
220
+ # uniform sampling
221
+ ndepth = numpy.random.randint(self.depth - 3, self.depth + 3)
222
+ else:
223
+ # truncnorm
224
+ mu = self.depth
225
+ sigma = 1
226
+ lower = mu - 3 * sigma
227
+ upper = mu + 3 * sigma
228
+ X = stats.truncnorm((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma)
229
+ ndepth = int(X.rvs(1))
230
+ else:
231
+ ndepth = self.depth
232
+
233
+ # Message passing
234
+ for _ in range(ndepth - 1):
235
+ if self.undirected:
236
+ # two directions should be the same
237
+ message = (message + message[b2revb]) / 2
238
+
239
+ nei_message = select_neighbor_and_aggregate(message, a2nei)
240
+ a_message = nei_message
241
+ if self.attached_fea:
242
+ attached_nei_fea = select_neighbor_and_aggregate(attached_fea, a2attached)
243
+ a_message = torch.cat((nei_message, attached_nei_fea), dim=1)
244
+
245
+ if not self.atom_messages:
246
+ rev_message = message[b2revb]
247
+ if self.attached_fea:
248
+ atom_rev_message = attached_fea[b2a[b2revb]]
249
+ rev_message = torch.cat((rev_message, atom_rev_message), dim=1)
250
+ # Exclude the reverse bond itself (w): \sum_{k \in N(u) \setminus \{w\}}
251
+ message = a_message[b2a] - rev_message # num_bonds x hidden
252
+ else:
253
+ message = a_message
254
+
255
+ message = self.W_h(message)
256
+
257
+ # BUG: by default MPNEncoder uses the dense connection in the message passing step.
259
+ # The correct condition should be `if not self.dense`.
259
+ if self.dense:
260
+ message = self.act_func(message) # num_bonds x hidden_size
261
+ else:
262
+ message = self.act_func(input + message)
263
+ message = self.dropout_layer(message) # num_bonds x hidden
264
+
265
+ output = message
266
+
267
+ return output # num_atoms x hidden
268
+
269
+
270
+ class PositionwiseFeedForward(nn.Module):
271
+ """Implements FFN equation."""
272
+
273
+ def __init__(self, d_model, d_ff, activation="PReLU", dropout=0.1, d_out=None):
274
+ """Initialization.
275
+
276
+ :param d_model: the input dimension.
277
+ :param d_ff: the hidden dimension.
278
+ :param activation: the activation function.
279
+ :param dropout: the dropout rate.
280
+ :param d_out: the output dimension, the default value is equal to d_model.
281
+ """
282
+ super(PositionwiseFeedForward, self).__init__()
283
+ if d_out is None:
284
+ d_out = d_model
285
+ # By default, bias is on.
286
+ self.W_1 = nn.Linear(d_model, d_ff)
287
+ self.W_2 = nn.Linear(d_ff, d_out)
288
+ self.dropout = nn.Dropout(dropout)
289
+ self.act_func = get_activation_function(activation)
290
+
291
+ def forward(self, x):
292
+ """
293
+ The forward function
294
+ :param x: input tensor.
295
+ :return:
296
+ """
297
+ return self.W_2(self.dropout(self.act_func(self.W_1(x))))
298
+
299
+
300
+ class SublayerConnection(nn.Module):
301
+ """
302
+ A residual connection followed by a layer norm.
303
+ Note for code simplicity the norm is first as opposed to last.
304
+ """
305
+
306
+ def __init__(self, size, dropout):
307
+ """Initialization.
308
+
309
+ :param size: the input dimension.
310
+ :param dropout: the dropout ratio.
311
+ """
312
+ super(SublayerConnection, self).__init__()
313
+ self.norm = LayerNorm(size, elementwise_affine=True)
314
+ self.dropout = nn.Dropout(dropout)
315
+
316
+ def forward(self, inputs, outputs):
317
+ """Apply residual connection to any sublayer with the same size."""
318
+ # return x + self.dropout(self.norm(x))
319
+ if inputs is None:
320
+ return self.dropout(self.norm(outputs))
321
+ return inputs + self.dropout(self.norm(outputs))
322
+
323
+
324
+ class Attention(nn.Module):
325
+ """
326
+ Compute scaled dot-product attention.
327
+ """
328
+
329
+ def forward(self, query, key, value, mask=None, dropout=None):
330
+ """
331
+ :param query:
332
+ :param key:
333
+ :param value:
334
+ :param mask:
335
+ :param dropout:
336
+ :return:
337
+ """
338
+ scores = torch.matmul(query, key.transpose(-2, -1)) \
339
+ / math.sqrt(query.size(-1))
340
+
341
+ if mask is not None:
342
+ scores = scores.masked_fill(mask == 0, -1e9)
343
+
344
+ p_attn = F.softmax(scores, dim=-1)
345
+
346
+ if dropout is not None:
347
+ p_attn = dropout(p_attn)
348
+
349
+ return torch.matmul(p_attn, value), p_attn
350
+
351
+
352
+ class MultiHeadedAttention(nn.Module):
353
+ """
354
+ The multi-head attention module. Take in model size and number of heads.
355
+ """
356
+
357
+ def __init__(self, h, d_model, dropout=0.1, bias=False):
358
+ """
359
+
360
+ :param h:
361
+ :param d_model:
362
+ :param dropout:
363
+ :param bias:
364
+ """
365
+ super().__init__()
366
+ assert d_model % h == 0
367
+
368
+ # We assume d_v always equals d_k
369
+ self.d_k = d_model // h
370
+ self.h = h # number of heads
371
+
372
+ self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)]) # why 3: query, key, value
373
+ self.output_linear = nn.Linear(d_model, d_model, bias)
374
+ self.attention = Attention()
375
+
376
+ self.dropout = nn.Dropout(p=dropout)
377
+
378
+ def forward(self, query, key, value, mask=None):
379
+ """
380
+
381
+ :param query:
382
+ :param key:
383
+ :param value:
384
+ :param mask:
385
+ :return:
386
+ """
387
+ batch_size = query.size(0)
388
+
389
+ # 1) Do all the linear projections in batch from d_model => h x d_k
390
+ query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
391
+ for l, x in zip(self.linear_layers, (query, key, value))]
392
+
393
+ # 2) Apply attention on all the projected vectors in batch.
394
+ x, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)
395
+
396
+ # 3) "Concat" using a view and apply a final linear.
397
+ x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
398
+
399
+ return self.output_linear(x)
400
+
401
+
402
+ class Head(nn.Module):
403
+ """
404
+ One head for multi-headed attention.
405
+ :return: (query, key, value)
406
+ """
407
+
408
+ def __init__(self, args, hidden_size, atom_messages=False):
409
+ """
410
+ Initialization.
411
+ :param args: The argument.
412
+ :param hidden_size: the dimension of hidden layer in Head.
413
+ :param atom_messages: the MPNEncoder type.
414
+ """
415
+ super(Head, self).__init__()
416
+ atom_fdim = hidden_size
417
+ bond_fdim = hidden_size
418
+ hidden_size = hidden_size
419
+ self.atom_messages = atom_messages
420
+ if self.atom_messages:
421
+ init_message_dim = atom_fdim
422
+ attached_fea_dim = bond_fdim
423
+ else:
424
+ init_message_dim = bond_fdim
425
+ attached_fea_dim = atom_fdim
426
+
427
+ # Here we use the message passing network as query, key and value.
428
+ self.mpn_q = MPNEncoder(args=args,
429
+ atom_messages=atom_messages,
430
+ init_message_dim=init_message_dim,
431
+ attached_fea_fdim=attached_fea_dim,
432
+ hidden_size=hidden_size,
433
+ bias=args.bias,
434
+ depth=args.depth,
435
+ dropout=args.dropout,
436
+ undirected=args.undirected,
437
+ dense=args.dense,
438
+ aggregate_to_atom=False,
439
+ attach_fea=False,
440
+ input_layer="none",
441
+ dynamic_depth="truncnorm")
442
+ self.mpn_k = MPNEncoder(args=args,
443
+ atom_messages=atom_messages,
444
+ init_message_dim=init_message_dim,
445
+ attached_fea_fdim=attached_fea_dim,
446
+ hidden_size=hidden_size,
447
+ bias=args.bias,
448
+ depth=args.depth,
449
+ dropout=args.dropout,
450
+ undirected=args.undirected,
451
+ dense=args.dense,
452
+ aggregate_to_atom=False,
453
+ attach_fea=False,
454
+ input_layer="none",
455
+ dynamic_depth="truncnorm")
456
+ self.mpn_v = MPNEncoder(args=args,
457
+ atom_messages=atom_messages,
458
+ init_message_dim=init_message_dim,
459
+ attached_fea_fdim=attached_fea_dim,
460
+ hidden_size=hidden_size,
461
+ bias=args.bias,
462
+ depth=args.depth,
463
+ dropout=args.dropout,
464
+ undirected=args.undirected,
465
+ dense=args.dense,
466
+ aggregate_to_atom=False,
467
+ attach_fea=False,
468
+ input_layer="none",
469
+ dynamic_depth="truncnorm")
470
+
471
+ def forward(self, f_atoms, f_bonds, a2b, a2a, b2a, b2revb):
472
+ """
473
+ The forward function.
474
+ :param f_atoms: the atom features, num_atoms * atom_dim
475
+ :param f_bonds: the bond features, num_bonds * bond_dim
476
+ :param a2b: mapping from atom index to incoming bond indices.
477
+ :param a2a: mapping from atom index to its neighbors. num_atoms * max_num_bonds
478
+ :param b2a: mapping from bond index to the index of the atom the bond is coming from.
479
+ :param b2revb: mapping from bond index to the index of the reverse bond.
480
+ :return:
481
+ """
482
+ if self.atom_messages:
483
+ init_messages = f_atoms
484
+ init_attached_features = f_bonds
485
+ a2nei = a2a
486
+ a2attached = a2b
487
+ b2a = b2a
488
+ b2revb = b2revb
489
+ else:
490
+ init_messages = f_bonds
491
+ init_attached_features = f_atoms
492
+ a2nei = a2b
493
+ a2attached = a2a
494
+ b2a = b2a
495
+ b2revb = b2revb
496
+
497
+ q = self.mpn_q(init_messages=init_messages,
498
+ init_attached_features=init_attached_features,
499
+ a2nei=a2nei,
500
+ a2attached=a2attached,
501
+ b2a=b2a,
502
+ b2revb=b2revb)
503
+ k = self.mpn_k(init_messages=init_messages,
504
+ init_attached_features=init_attached_features,
505
+ a2nei=a2nei,
506
+ a2attached=a2attached,
507
+ b2a=b2a,
508
+ b2revb=b2revb)
509
+ v = self.mpn_v(init_messages=init_messages,
510
+ init_attached_features=init_attached_features,
511
+ a2nei=a2nei,
512
+ a2attached=a2attached,
513
+ b2a=b2a,
514
+ b2revb=b2revb)
515
+ return q, k, v
516
+
517
+
518
+ class MTBlock(nn.Module):
519
+ """
520
+ The Multi-headed attention block.
521
+ """
522
+
523
+ def __init__(self,
524
+ args,
525
+ num_attn_head,
526
+ input_dim,
527
+ hidden_size,
528
+ activation="ReLU",
529
+ dropout=0.0,
530
+ bias=True,
531
+ atom_messages=False,
532
+ cuda=True,
533
+ res_connection=False):
534
+ """
535
+
536
+ :param args: the arguments.
537
+ :param num_attn_head: the number of attention head.
538
+ :param input_dim: the input dimension.
539
+ :param hidden_size: the hidden size of the model.
540
+ :param activation: the activation function.
541
+ :param dropout: the dropout ratio
542
+ :param bias: if true: all linear layer contains bias term.
543
+ :param atom_messages: the MPNEncoder type
544
+ :param cuda: if true, the model run with GPU.
545
+ :param res_connection: enables the skip-connection in MTBlock.
546
+ """
547
+ super(MTBlock, self).__init__()
548
+ # self.args = args
549
+ self.atom_messages = atom_messages
550
+ self.hidden_size = hidden_size
551
+ self.heads = nn.ModuleList()
552
+ self.input_dim = input_dim
553
+ self.cuda = cuda
554
+ self.res_connection = res_connection
555
+ self.act_func = get_activation_function(activation)
556
+ self.dropout_layer = nn.Dropout(p=dropout)
557
+ # Note: elementwise_affine has to be consistent with the pre-training phase
558
+ self.layernorm = nn.LayerNorm(self.hidden_size, elementwise_affine=True)
559
+
560
+ self.W_i = nn.Linear(self.input_dim, self.hidden_size, bias=bias)
561
+ self.attn = MultiHeadedAttention(h=num_attn_head,
562
+ d_model=self.hidden_size,
563
+ bias=bias,
564
+ dropout=dropout)
565
+ self.W_o = nn.Linear(self.hidden_size * num_attn_head, self.hidden_size, bias=bias)
566
+ self.sublayer = SublayerConnection(self.hidden_size, dropout)
567
+ for _ in range(num_attn_head):
568
+ self.heads.append(Head(args, hidden_size=hidden_size, atom_messages=atom_messages))
569
+
570
+ def forward(self, batch, features_batch=None):
571
+ """
572
+
573
+ :param batch: the graph batch generated by GroverCollator.
574
+ :param features_batch: the additional features of molecules. (deprecated)
575
+ :return:
576
+ """
577
+ f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a = batch
578
+
579
+ if self.atom_messages:
580
+ # Only add linear transformation in the input feature.
581
+ if f_atoms.shape[1] != self.hidden_size:
582
+ f_atoms = self.W_i(f_atoms)
583
+ f_atoms = self.dropout_layer(self.layernorm(self.act_func(f_atoms)))
584
+
585
+ else: # bond messages
586
+ if f_bonds.shape[1] != self.hidden_size:
587
+ f_bonds = self.W_i(f_bonds)
588
+ f_bonds = self.dropout_layer(self.layernorm(self.act_func(f_bonds)))
589
+
590
+ queries = []
591
+ keys = []
592
+ values = []
593
+ for head in self.heads:
594
+ q, k, v = head(f_atoms, f_bonds, a2b, a2a, b2a, b2revb)
595
+ queries.append(q.unsqueeze(1))
596
+ keys.append(k.unsqueeze(1))
597
+ values.append(v.unsqueeze(1))
598
+ queries = torch.cat(queries, dim=1)
599
+ keys = torch.cat(keys, dim=1)
600
+ values = torch.cat(values, dim=1)
601
+
602
+ x_out = self.attn(queries, keys, values) # multi-headed attention
603
+ x_out = x_out.view(x_out.shape[0], -1)
604
+ x_out = self.W_o(x_out)
605
+
606
+ x_in = None
607
+ # support no residual connection in MTBlock.
608
+ if self.res_connection:
609
+ if self.atom_messages:
610
+ x_in = f_atoms
611
+ else:
612
+ x_in = f_bonds
613
+
614
+ if self.atom_messages:
615
+ f_atoms = self.sublayer(x_in, x_out)
616
+ else:
617
+ f_bonds = self.sublayer(x_in, x_out)
618
+
619
+ batch = f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a
620
+ features_batch = features_batch
621
+ return batch, features_batch
622
+
623
+
624
+ class GTransEncoder(nn.Module):
625
+ def __init__(self,
626
+ args,
627
+ hidden_size,
628
+ edge_fdim,
629
+ node_fdim,
630
+ dropout=0.0,
631
+ activation="ReLU",
632
+ num_mt_block=1,
633
+ num_attn_head=4,
634
+ atom_emb_output: Union[bool, str] = False, # options: True, False, None, "atom", "bond", "both"
635
+ bias=False,
636
+ cuda=True,
637
+ res_connection=False):
638
+ """
639
+
640
+ :param args: the arguments.
641
+ :param hidden_size: the hidden size of the model.
642
+ :param edge_fdim: the dimension of additional feature for edge/bond.
643
+ :param node_fdim: the dimension of additional feature for node/atom.
644
+ :param dropout: the dropout ratio
645
+ :param activation: the activation function
646
+ :param num_mt_block: the number of mt block.
647
+ :param num_attn_head: the number of attention head.
648
+ :param atom_emb_output: enable the output aggregation after message passing.
649
+ atom_messages: True False
650
+ -False: no aggregating to atom. output size: (num_atoms, hidden_size) (num_bonds, hidden_size)
651
+ -True: aggregating to atom. output size: (num_atoms, hidden_size) (num_atoms, hidden_size)
652
+ -None: same as False
653
+ -"atom": same as True
654
+ -"bond": aggragating to bond. output size: (num_bonds, hidden_size) (num_bonds, hidden_size)
655
+ -"both": aggregating to atom&bond. output size: (num_atoms, hidden_size) (num_bonds, hidden_size)
656
+ (num_bonds, hidden_size) (num_atoms, hidden_size)
657
+ :param bias: enable bias term in all linear layers.
658
+ :param cuda: run with cuda.
659
+ :param res_connection: enables the skip-connection in MTBlock.
660
+ """
661
+ super(GTransEncoder, self).__init__()
662
+
663
+ # For the compatibility issue.
664
+ if atom_emb_output is False:
665
+ atom_emb_output = None
666
+ if atom_emb_output is True:
667
+ atom_emb_output = 'atom'
668
+
669
+ self.hidden_size = hidden_size
670
+ self.dropout = dropout
671
+ self.activation = activation
672
+ self.cuda = cuda
673
+ self.bias = bias
674
+ self.res_connection = res_connection
675
+ self.edge_blocks = nn.ModuleList()
676
+ self.node_blocks = nn.ModuleList()
677
+
678
+ edge_input_dim = edge_fdim
679
+ node_input_dim = node_fdim
680
+ edge_input_dim_i = edge_input_dim
681
+ node_input_dim_i = node_input_dim
682
+
683
+ for i in range(num_mt_block):
684
+ if i != 0:
685
+ edge_input_dim_i = self.hidden_size
686
+ node_input_dim_i = self.hidden_size
687
+ self.edge_blocks.append(MTBlock(args=args,
688
+ num_attn_head=num_attn_head,
689
+ input_dim=edge_input_dim_i,
690
+ hidden_size=self.hidden_size,
691
+ activation=activation,
692
+ dropout=dropout,
693
+ bias=self.bias,
694
+ atom_messages=False,
695
+ cuda=cuda))
696
+ self.node_blocks.append(MTBlock(args=args,
697
+ num_attn_head=num_attn_head,
698
+ input_dim=node_input_dim_i,
699
+ hidden_size=self.hidden_size,
700
+ activation=activation,
701
+ dropout=dropout,
702
+ bias=self.bias,
703
+ atom_messages=True,
704
+ cuda=cuda))
705
+
706
+ self.atom_emb_output = atom_emb_output
707
+
708
+ self.ffn_atom_from_atom = PositionwiseFeedForward(self.hidden_size + node_fdim,
709
+ self.hidden_size * 4,
710
+ activation=self.activation,
711
+ dropout=self.dropout,
712
+ d_out=self.hidden_size)
713
+
714
+ self.ffn_atom_from_bond = PositionwiseFeedForward(self.hidden_size + node_fdim,
715
+ self.hidden_size * 4,
716
+ activation=self.activation,
717
+ dropout=self.dropout,
718
+ d_out=self.hidden_size)
719
+
720
+ self.ffn_bond_from_atom = PositionwiseFeedForward(self.hidden_size + edge_fdim,
721
+ self.hidden_size * 4,
722
+ activation=self.activation,
723
+ dropout=self.dropout,
724
+ d_out=self.hidden_size)
725
+
726
+ self.ffn_bond_from_bond = PositionwiseFeedForward(self.hidden_size + edge_fdim,
727
+ self.hidden_size * 4,
728
+ activation=self.activation,
729
+ dropout=self.dropout,
730
+ d_out=self.hidden_size)
731
+
732
+ self.atom_from_atom_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)
733
+ self.atom_from_bond_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)
734
+ self.bond_from_atom_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)
735
+ self.bond_from_bond_sublayer = SublayerConnection(size=self.hidden_size, dropout=self.dropout)
736
+
737
+ self.act_func_node = get_activation_function(self.activation)
738
+ self.act_func_edge = get_activation_function(self.activation)
739
+
740
+ self.dropout_layer = nn.Dropout(p=args.dropout)
741
+
742
+ def pointwise_feed_forward_to_atom_embedding(self, emb_output, atom_fea, index, ffn_layer):
743
+ """
744
+ The point-wise feed forward and long-range residual connection for atom view.
745
+ aggregate to atom.
746
+ :param emb_output: the output embedding from the previous multi-head attentions.
747
+ :param atom_fea: the atom/node feature embedding.
748
+ :param index: the index of neighborhood relations.
749
+ :param ffn_layer: the feed forward layer
750
+ :return:
751
+ """
752
+ aggr_output = select_neighbor_and_aggregate(emb_output, index)
753
+ aggr_outputx = torch.cat([atom_fea, aggr_output], dim=1)
754
+ return ffn_layer(aggr_outputx), aggr_output
755
+
756
+ def pointwise_feed_forward_to_bond_embedding(self, emb_output, bond_fea, a2nei, b2revb, ffn_layer):
757
+ """
758
+ The point-wise feed forward and long-range residual connection for bond view.
759
+ aggregate to bond.
760
+ :param emb_output: the output embedding from the previous multi-head attentions.
761
+ :param bond_fea: the bond/edge feature embedding.
762
+ :param index: the index of neighborhood relations.
763
+ :param ffn_layer: the feed forward layer
764
+ :return:
765
+ """
766
+ aggr_output = select_neighbor_and_aggregate(emb_output, a2nei)
767
+ # remove rev bond / atom --- need for bond view
768
+ aggr_output = self.remove_rev_bond_message(emb_output, aggr_output, b2revb)
769
+ aggr_outputx = torch.cat([bond_fea, aggr_output], dim=1)
770
+ return ffn_layer(aggr_outputx), aggr_output
771
+
772
+ @staticmethod
773
+ def remove_rev_bond_message(orginal_message, aggr_message, b2revb):
774
+ """
775
+
776
+ :param orginal_message:
777
+ :param aggr_message:
778
+ :param b2revb:
779
+ :return:
780
+ """
781
+ rev_message = orginal_message[b2revb]
782
+ return aggr_message - rev_message
783
+
784
+ def atom_bond_transform(self,
785
+ to_atom=True, # False: to bond
786
+ atomwise_input=None,
787
+ bondwise_input=None,
788
+ original_f_atoms=None,
789
+ original_f_bonds=None,
790
+ a2a=None,
791
+ a2b=None,
792
+ b2a=None,
793
+ b2revb=None
794
+ ):
795
+ """
796
+ Transfer the output of atom/bond multi-head attention to the final atom/bond output.
797
+ :param to_atom: if true, the output is the atom embedding; otherwise, the output is the bond embedding.
798
+ :param atomwise_input: the input embedding of atom/node.
799
+ :param bondwise_input: the input embedding of bond/edge.
800
+ :param original_f_atoms: the initial atom features.
801
+ :param original_f_bonds: the initial bond features.
802
+ :param a2a: mapping from atom index to its neighbors. num_atoms * max_num_bonds
803
+ :param a2b: mapping from atom index to incoming bond indices.
804
+ :param b2a: mapping from bond index to the index of the atom the bond is coming from.
805
+ :param b2revb: mapping from bond index to the index of the reverse bond.
806
+ :return:
807
+ """
808
+
809
+ if to_atom:
810
+ # atom input to atom output
811
+ atomwise_input, _ = self.pointwise_feed_forward_to_atom_embedding(atomwise_input, original_f_atoms, a2a,
812
+ self.ffn_atom_from_atom)
813
+ atom_in_atom_out = self.atom_from_atom_sublayer(None, atomwise_input)
814
+ # bond to atom
815
+ bondwise_input, _ = self.pointwise_feed_forward_to_atom_embedding(bondwise_input, original_f_atoms, a2b,
816
+ self.ffn_atom_from_bond)
817
+ bond_in_atom_out = self.atom_from_bond_sublayer(None, bondwise_input)
818
+ return atom_in_atom_out, bond_in_atom_out
819
+ else: # to bond embeddings
820
+
821
+ # atom input to bond output
822
+ atom_list_for_bond = torch.cat([b2a.unsqueeze(dim=1), a2a[b2a]], dim=1)
823
+ atomwise_input, _ = self.pointwise_feed_forward_to_bond_embedding(atomwise_input, original_f_bonds,
824
+ atom_list_for_bond,
825
+ b2a[b2revb], self.ffn_bond_from_atom)
826
+ atom_in_bond_out = self.bond_from_atom_sublayer(None, atomwise_input)
827
+ # bond input to bond output
828
+ bond_list_for_bond = a2b[b2a]
829
+ bondwise_input, _ = self.pointwise_feed_forward_to_bond_embedding(bondwise_input, original_f_bonds,
830
+ bond_list_for_bond,
831
+ b2revb, self.ffn_bond_from_bond)
832
+ bond_in_bond_out = self.bond_from_bond_sublayer(None, bondwise_input)
833
+ return atom_in_bond_out, bond_in_bond_out
834
+
835
+ def forward(self, batch, features_batch = None):
836
+ f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a = batch
837
+ if self.cuda or next(self.parameters()).is_cuda:
838
+ f_atoms, f_bonds, a2b, b2a, b2revb = f_atoms.cuda(), f_bonds.cuda(), a2b.cuda(), b2a.cuda(), b2revb.cuda()
839
+ a2a = a2a.cuda()
840
+
841
+ node_batch = f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a
842
+ edge_batch = f_atoms, f_bonds, a2b, b2a, b2revb, a_scope, b_scope, a2a
843
+
844
+ # opt pointwise_feed_forward
845
+ original_f_atoms, original_f_bonds = f_atoms, f_bonds
846
+
847
+ # Note: features_batch is not used here.
848
+ for nb in self.node_blocks: # atom messages. Multi-headed attention
849
+ node_batch, features_batch = nb(node_batch, features_batch)
850
+ for eb in self.edge_blocks: # bond messages. Multi-headed attention
851
+ edge_batch, features_batch = eb(edge_batch, features_batch)
852
+
853
+ atom_output, _, _, _, _, _, _, _ = node_batch # atom hidden states
854
+ _, bond_output, _, _, _, _, _, _ = edge_batch # bond hidden states
855
+
856
+ if self.atom_emb_output is None:
857
+ # output the embedding from multi-head attention directly.
858
+ return atom_output, bond_output
859
+
860
+ if self.atom_emb_output == 'atom':
861
+ return self.atom_bond_transform(to_atom=True, # False: to bond
862
+ atomwise_input=atom_output,
863
+ bondwise_input=bond_output,
864
+ original_f_atoms=original_f_atoms,
865
+ original_f_bonds=original_f_bonds,
866
+ a2a=a2a,
867
+ a2b=a2b,
868
+ b2a=b2a,
869
+ b2revb=b2revb)
870
+ elif self.atom_emb_output == 'bond':
871
+ return self.atom_bond_transform(to_atom=False, # False: to bond
872
+ atomwise_input=atom_output,
873
+ bondwise_input=bond_output,
874
+ original_f_atoms=original_f_atoms,
875
+ original_f_bonds=original_f_bonds,
876
+ a2a=a2a,
877
+ a2b=a2b,
878
+ b2a=b2a,
879
+ b2revb=b2revb)
880
+ else: # 'both'
881
+ atom_embeddings = self.atom_bond_transform(to_atom=True, # False: to bond
882
+ atomwise_input=atom_output,
883
+ bondwise_input=bond_output,
884
+ original_f_atoms=original_f_atoms,
885
+ original_f_bonds=original_f_bonds,
886
+ a2a=a2a,
887
+ a2b=a2b,
888
+ b2a=b2a,
889
+ b2revb=b2revb)
890
+
891
+ bond_embeddings = self.atom_bond_transform(to_atom=False, # False: to bond
892
+ atomwise_input=atom_output,
893
+ bondwise_input=bond_output,
894
+ original_f_atoms=original_f_atoms,
895
+ original_f_bonds=original_f_bonds,
896
+ a2a=a2a,
897
+ a2b=a2b,
898
+ b2a=b2a,
899
+ b2revb=b2revb)
900
+ # Notice: need to be consistent with output format of DualMPNN encoder
901
+ return ((atom_embeddings[0], bond_embeddings[0]),
902
+ (atom_embeddings[1], bond_embeddings[1]))
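A property of the SelfAttention readout defined above that is worth making explicit: because $a = Softmax(W_2 \tanh(W_1 X^T))$ has shape (out_feature, n) and $out = aX$ has shape (out_feature, in_feature), the result does not depend on the number of atoms n, and Readout then flattens it into one fixed-length molecule vector. A small shape check, illustrative only:

# Editorial sketch, not part of the commit.
import torch
from grover.model.layers import SelfAttention

attn = SelfAttention(hidden=16, in_feature=128, out_feature=4)
for n_atoms in (5, 50):
    X = torch.randn(n_atoms, 128)            # node embeddings of one molecule
    out, a = attn(X)
    print(out.shape, a.shape)                 # torch.Size([4, 128]) torch.Size([4, n_atoms])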
grover/model/models.py ADDED
@@ -0,0 +1,506 @@
1
+ """
2
+ The GROVER models for pretraining, finetuning and fingerprint generation.
3
+ """
4
+ from argparse import Namespace
5
+ from typing import List, Dict, Callable
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch import nn as nn
10
+
11
+ from grover.data import get_atom_fdim, get_bond_fdim
12
+ from grover.model.layers import Readout, GTransEncoder
13
+ from grover.util.nn_utils import get_activation_function
14
+
15
+
16
+ class GROVEREmbedding(nn.Module):
17
+ """
18
+ The GROVER Embedding class. It contains the GTransEncoder.
19
+ This GTransEncoder can be replaced by any valid encoder.
20
+ """
21
+
22
+ def __init__(self, args: Namespace):
23
+ """
24
+ Initialize the GROVEREmbedding class.
25
+ :param args:
26
+ """
27
+ super(GROVEREmbedding, self).__init__()
28
+ self.embedding_output_type = args.embedding_output_type
29
+ edge_dim = get_bond_fdim() + get_atom_fdim()
30
+ node_dim = get_atom_fdim()
31
+ if not hasattr(args, "backbone"):
32
+ print("No backbone specified in args, use gtrans backbone.")
33
+ args.backbone = "gtrans"
34
+ if args.backbone == "gtrans" or args.backbone == "dualtrans":
35
+ # dualtrans is the old name.
36
+ self.encoders = GTransEncoder(args,
37
+ hidden_size=args.hidden_size,
38
+ edge_fdim=edge_dim,
39
+ node_fdim=node_dim,
40
+ dropout=args.dropout,
41
+ activation=args.activation,
42
+ num_mt_block=args.num_mt_block,
43
+ num_attn_head=args.num_attn_head,
44
+ atom_emb_output=self.embedding_output_type,
45
+ bias=args.bias,
46
+ cuda=args.cuda)
47
+
48
+ def forward(self, graph_batch: List) -> Dict:
49
+ """
50
+ The forward function takes graph_batch as input and output a dict. The content of the dict is decided by
51
+ self.embedding_output_type.
52
+
53
+ :param graph_batch: the input graph batch generated by MolCollator.
54
+ :return: a dict containing the embedding results.
55
+ """
56
+ output = self.encoders(graph_batch)
57
+ if self.embedding_output_type == 'atom':
58
+ return {"atom_from_atom": output[0], "atom_from_bond": output[1],
59
+ "bond_from_atom": None, "bond_from_bond": None} # atom_from_atom, atom_from_bond
60
+ elif self.embedding_output_type == 'bond':
61
+ return {"atom_from_atom": None, "atom_from_bond": None,
62
+ "bond_from_atom": output[0], "bond_from_bond": output[1]} # bond_from_atom, bond_from_bond
63
+ elif self.embedding_output_type == "both":
64
+ return {"atom_from_atom": output[0][0], "bond_from_atom": output[0][1],
65
+ "atom_from_bond": output[1][0], "bond_from_bond": output[1][1]}
66
+
67
+
68
+ class AtomVocabPrediction(nn.Module):
69
+ """
70
+ The atom-wise vocabulary prediction task. The atom vocabulary is constructed by the context.
71
+ """
72
+ def __init__(self, args, vocab_size, hidden_size=None):
73
+ """
74
+ :param args: the argument.
75
+ :param vocab_size: the size of atom vocabulary.
76
+ """
77
+ super(AtomVocabPrediction, self).__init__()
78
+ if not hidden_size:
79
+ hidden_size = args.hidden_size
80
+ self.linear = nn.Linear(hidden_size, vocab_size)
81
+ self.logsoftmax = nn.LogSoftmax(dim=1)
82
+
83
+ def forward(self, embeddings):
84
+ """
85
+ If embeddings is None: do not go through forward pass.
86
+ :param embeddings: the atom embeddings, num_atom X fea_dim.
87
+ :return: the prediction for each atom, num_atom X vocab_size.
88
+ """
89
+ if embeddings is None:
90
+ return None
91
+ return self.logsoftmax(self.linear(embeddings))
92
+
93
+
94
+ class BondVocabPrediction(nn.Module):
95
+ """
96
+ The bond-wise vocabulary prediction task. The bond vocabulary is constructed by the context.
97
+ """
98
+ def __init__(self, args, vocab_size, hidden_size=None):
99
+ """
100
+ Might need to use different architecture for bond vocab prediction.
101
+ :param args:
102
+ :param vocab_size: size of bond vocab.
103
+ :param hidden_size: hidden size
104
+ """
105
+ super(BondVocabPrediction, self).__init__()
106
+ if not hidden_size:
107
+ hidden_size = args.hidden_size
108
+ self.linear = nn.Linear(hidden_size, vocab_size)
109
+
110
+ # ad-hoc here
111
+ # If TWO_FC_4_BOND_VOCAB, we will use two distinct fc layer to deal with the bond and rev bond.
112
+ self.TWO_FC_4_BOND_VOCAB = True
113
+ if self.TWO_FC_4_BOND_VOCAB:
114
+ self.linear_rev = nn.Linear(hidden_size, vocab_size)
115
+ self.logsoftmax = nn.LogSoftmax(dim=1)
116
+
117
+ def forward(self, embeddings):
118
+ """
119
+ If embeddings is None: do not go through forward pass.
120
+ :param embeddings: the atom embeddings, num_bond X fea_dim.
121
+ :return: the prediction for each atom, num_bond X vocab_size.
122
+ """
123
+ if embeddings is None:
124
+ return None
125
+ nm_bonds = embeddings.shape[0] # must be an odd number
126
+ # The bond and rev bond have odd and even ids respectively. See definition in molgraph.
127
+ ids1 = [0] + list(range(1, nm_bonds, 2))
128
+ ids2 = list(range(0, nm_bonds, 2))
129
+ if self.TWO_FC_4_BOND_VOCAB:
130
+ logits = self.linear(embeddings[ids1]) + self.linear_rev(embeddings[ids2])
131
+ else:
132
+ logits = self.linear(embeddings[ids1] + embeddings[ids2])
133
+
134
+ return self.logsoftmax(logits)
135
+
136
+
137
+ class FunctionalGroupPrediction(nn.Module):
138
+ """
139
+ The functional group (semantic motifs) prediction task. This is a graph-level task.
140
+ """
141
+ def __init__(self, args, fg_size):
142
+ """
143
+ :param args: The arguments.
144
+ :param fg_size: The size of semantic motifs.
145
+ """
146
+ super(FunctionalGroupPrediction, self).__init__()
147
+ first_linear_dim = args.hidden_size
148
+ hidden_size = args.hidden_size
149
+
150
+ # In order to retain maximal information in the encoder, we use a simple readout function here.
151
+ self.readout = Readout(rtype="mean", hidden_size=hidden_size)
152
+ # We have four branches here. But the input with less than four branch is OK.
153
+ # Since we use BCEWithLogitsLoss as the loss function, we only need to output logits here.
154
+ self.linear_atom_from_atom = nn.Linear(first_linear_dim, fg_size)
155
+ self.linear_atom_from_bond = nn.Linear(first_linear_dim, fg_size)
156
+ self.linear_bond_from_atom = nn.Linear(first_linear_dim, fg_size)
157
+ self.linear_bond_from_bond = nn.Linear(first_linear_dim, fg_size)
158
+
159
+ def forward(self, embeddings: Dict, ascope: List, bscope: List) -> Dict:
160
+ """
161
+ The forward function of semantic motif prediction. It takes the node/bond embeddings, and the corresponding
162
+ atom/bond scope as input and produce the prediction logits for different branches.
163
+ :param embeddings: The input embeddings are organized as dict. The output of GROVEREmbedding.
164
+ :param ascope: The scope for atoms. Please refer to BatchMolGraph for more details.
165
+ :param bscope: The scope for bonds. Please refer to BatchMolGraph for more details.
166
+ :return: a dict contains the predicted logits.
167
+ """
168
+
169
+ preds_atom_from_atom, preds_atom_from_bond, preds_bond_from_atom, preds_bond_from_bond = \
170
+ None, None, None, None
171
+
172
+ if embeddings["bond_from_atom"] is not None:
173
+ preds_bond_from_atom = self.linear_bond_from_atom(self.readout(embeddings["bond_from_atom"], bscope))
174
+ if embeddings["bond_from_bond"] is not None:
175
+ preds_bond_from_bond = self.linear_bond_from_bond(self.readout(embeddings["bond_from_bond"], bscope))
176
+
177
+ if embeddings["atom_from_atom"] is not None:
178
+ preds_atom_from_atom = self.linear_atom_from_atom(self.readout(embeddings["atom_from_atom"], ascope))
179
+ if embeddings["atom_from_bond"] is not None:
180
+ preds_atom_from_bond = self.linear_atom_from_bond(self.readout(embeddings["atom_from_bond"], ascope))
181
+
182
+ return {"atom_from_atom": preds_atom_from_atom, "atom_from_bond": preds_atom_from_bond,
183
+ "bond_from_atom": preds_bond_from_atom, "bond_from_bond": preds_bond_from_bond}
184
+
185
+
186
+ class GroverTask(nn.Module):
187
+ """
188
+ The pretrain module.
189
+ """
190
+ def __init__(self, args, grover, atom_vocab_size, bond_vocab_size, fg_size):
191
+ super(GroverTask, self).__init__()
192
+ self.grover = grover
193
+ self.av_task_atom = AtomVocabPrediction(args, atom_vocab_size)
194
+ self.av_task_bond = AtomVocabPrediction(args, atom_vocab_size)
195
+ self.bv_task_atom = BondVocabPrediction(args, bond_vocab_size)
196
+ self.bv_task_bond = BondVocabPrediction(args, bond_vocab_size)
197
+
198
+ self.fg_task_all = FunctionalGroupPrediction(args, fg_size)
199
+
200
+ self.embedding_output_type = args.embedding_output_type
201
+
202
+ @staticmethod
203
+ def get_loss_func(args: Namespace) -> Callable:
204
+ """
205
+ The loss function generator.
206
+ :param args: the arguments.
207
+ :return: the loss fucntion for GroverTask.
208
+ """
209
+ def loss_func(preds, targets, dist_coff=args.dist_coff):
210
+ """
211
+ The loss function for GroverTask.
212
+ :param preds: the predictions.
213
+ :param targets: the targets.
214
+ :param dist_coff: the default disagreement coefficient for the distances between different branches.
215
+ :return:
216
+ """
217
+ av_task_loss = nn.NLLLoss(ignore_index=0, reduction="mean") # same for av and bv
218
+
219
+ fg_task_loss = nn.BCEWithLogitsLoss(reduction="mean")
220
+ # av_task_dist_loss = nn.KLDivLoss(reduction="mean")
221
+ av_task_dist_loss = nn.MSELoss(reduction="mean")
222
+ fg_task_dist_loss = nn.MSELoss(reduction="mean")
223
+ sigmoid = nn.Sigmoid()
224
+
225
+ av_atom_loss, av_bond_loss, av_dist_loss = 0.0, 0.0, 0.0
226
+ fg_atom_from_atom_loss, fg_atom_from_bond_loss, fg_atom_dist_loss = 0.0, 0.0, 0.0
227
+ bv_atom_loss, bv_bond_loss, bv_dist_loss = 0.0, 0.0, 0.0
228
+ fg_bond_from_atom_loss, fg_bond_from_bond_loss, fg_bond_dist_loss = 0.0, 0.0, 0.0
229
+
230
+ if preds["av_task"][0] is not None:
231
+ av_atom_loss = av_task_loss(preds['av_task'][0], targets["av_task"])
232
+ fg_atom_from_atom_loss = fg_task_loss(preds["fg_task"]["atom_from_atom"], targets["fg_task"])
233
+
234
+ if preds["av_task"][1] is not None:
235
+ av_bond_loss = av_task_loss(preds['av_task'][1], targets["av_task"])
236
+ fg_atom_from_bond_loss = fg_task_loss(preds["fg_task"]["atom_from_bond"], targets["fg_task"])
237
+
238
+ if preds["bv_task"][0] is not None:
239
+ bv_atom_loss = av_task_loss(preds['bv_task'][0], targets["bv_task"])
240
+ fg_bond_from_atom_loss = fg_task_loss(preds["fg_task"]["bond_from_atom"], targets["fg_task"])
241
+
242
+ if preds["bv_task"][1] is not None:
243
+ bv_bond_loss = av_task_loss(preds['bv_task'][1], targets["bv_task"])
244
+ fg_bond_from_bond_loss = fg_task_loss(preds["fg_task"]["bond_from_bond"], targets["fg_task"])
245
+
246
+ if preds["av_task"][0] is not None and preds["av_task"][1] is not None:
247
+ av_dist_loss = av_task_dist_loss(preds['av_task'][0], preds['av_task'][1])
248
+ fg_atom_dist_loss = fg_task_dist_loss(sigmoid(preds["fg_task"]["atom_from_atom"]),
249
+ sigmoid(preds["fg_task"]["atom_from_bond"]))
250
+
251
+ if preds["bv_task"][0] is not None and preds["bv_task"][1] is not None:
252
+ bv_dist_loss = av_task_dist_loss(preds['bv_task'][0], preds['bv_task'][1])
253
+ fg_bond_dist_loss = fg_task_dist_loss(sigmoid(preds["fg_task"]["bond_from_atom"]),
254
+ sigmoid(preds["fg_task"]["bond_from_bond"]))
255
+
256
+ av_loss = av_atom_loss + av_bond_loss
257
+ bv_loss = bv_atom_loss + bv_bond_loss
258
+ fg_atom_loss = fg_atom_from_atom_loss + fg_atom_from_bond_loss
259
+ fg_bond_loss = fg_bond_from_atom_loss + fg_bond_from_bond_loss
260
+
261
+ fg_loss = fg_atom_loss + fg_bond_loss
262
+ fg_dist_loss = fg_atom_dist_loss + fg_bond_dist_loss
263
+
264
+ # dist_loss = av_dist_loss + bv_dist_loss + fg_dist_loss
265
+ # print("%.4f %.4f %.4f %.4f %.4f %.4f"%(av_atom_loss,
266
+ # av_bond_loss,
267
+ # fg_atom_loss,
268
+ # fg_bond_loss,
269
+ # av_dist_loss,
270
+ # fg_dist_loss))
271
+ # return av_loss + fg_loss + dist_coff * dist_loss
272
+ overall_loss = av_loss + bv_loss + fg_loss + dist_coff * av_dist_loss + \
273
+ dist_coff * bv_dist_loss + fg_dist_loss
274
+
275
+ return overall_loss, av_loss, bv_loss, fg_loss, av_dist_loss, bv_dist_loss, fg_dist_loss
276
+
277
+ return loss_func
278
+
279
+ def forward(self, graph_batch: List):
280
+ """
281
+ The forward function.
282
+ :param graph_batch:
283
+ :return:
284
+ """
285
+ _, _, _, _, _, a_scope, b_scope, _ = graph_batch
286
+ a_scope = a_scope.data.cpu().numpy().tolist()
287
+
288
+ embeddings = self.grover(graph_batch)
289
+
290
+ av_task_pred_atom = self.av_task_atom(
291
+ embeddings["atom_from_atom"]) # if None: means not go through this fowward
292
+ av_task_pred_bond = self.av_task_bond(embeddings["atom_from_bond"])
293
+
294
+ bv_task_pred_atom = self.bv_task_atom(embeddings["bond_from_atom"])
295
+ bv_task_pred_bond = self.bv_task_bond(embeddings["bond_from_bond"])
296
+
297
+ fg_task_pred_all = self.fg_task_all(embeddings, a_scope, b_scope)
298
+
299
+ return {"av_task": (av_task_pred_atom, av_task_pred_bond),
300
+ "bv_task": (bv_task_pred_atom, bv_task_pred_bond),
301
+ "fg_task": fg_task_pred_all}
302
+
303
+
304
+ class GroverFpGeneration(nn.Module):
305
+ """
306
+ GroverFpGeneration class.
307
+ It loads the pre-trained model and produce the fingerprints for input molecules.
308
+ """
309
+ def __init__(self, args):
310
+ """
311
+ Init function.
312
+ :param args: the arguments.
313
+ """
314
+ super(GroverFpGeneration, self).__init__()
315
+
316
+ self.fingerprint_source = args.fingerprint_source
317
+ self.iscuda = args.cuda
318
+
319
+ self.grover = GROVEREmbedding(args)
320
+ self.readout = Readout(rtype="mean", hidden_size=args.hidden_size)
321
+
322
+ def forward(self, batch, features_batch):
323
+ """
324
+ The forward function.
325
+ It takes graph batch and molecular feature batch as input and produce the fingerprints of this molecules.
326
+ :param batch:
327
+ :param features_batch:
328
+ :return:
329
+ """
330
+ _, _, _, _, _, a_scope, b_scope, _ = batch
331
+
332
+ output = self.grover(batch)
333
+ # Share readout
334
+ mol_atom_from_bond_output = self.readout(output["atom_from_bond"], a_scope)
335
+ mol_atom_from_atom_output = self.readout(output["atom_from_atom"], a_scope)
336
+
337
+ if self.fingerprint_source == "bond" or self.fingerprint_source == "both":
338
+ mol_bond_from_atom_output = self.readout(output["bond_from_atom"], b_scope)
339
+ mol_bond_from_bond_output = self.readout(output["bond_from_bond"], b_scope)
340
+
341
+ if features_batch[0] is not None:
342
+ features_batch = torch.from_numpy(np.stack(features_batch)).float()
343
+ if self.iscuda:
344
+ features_batch = features_batch.cuda()
345
+ features_batch = features_batch.to(output["atom_from_atom"])
346
+ if len(features_batch.shape) == 1:
347
+ features_batch = features_batch.view([1, features_batch.shape[0]])
348
+ else:
349
+ features_batch = None
350
+
351
+ if self.fingerprint_source == "atom":
352
+ fp = torch.cat([mol_atom_from_atom_output, mol_atom_from_bond_output], 1)
353
+ elif self.fingerprint_source == "bond":
354
+ fp = torch.cat([mol_bond_from_atom_output, mol_bond_from_bond_output], 1)
355
+ else:
356
+ # the both case.
357
+ fp = torch.cat([mol_atom_from_atom_output, mol_atom_from_bond_output,
358
+ mol_bond_from_atom_output, mol_bond_from_bond_output], 1)
359
+ if features_batch is not None:
360
+ fp = torch.cat([fp, features_batch], 1)
361
+ return fp
362
+
363
+
364
+ class GroverFinetuneTask(nn.Module):
365
+ """
366
+ The finetune task: loads the pre-trained GROVER embedding and adds readout plus FFN heads for property prediction.
367
+ """
368
+ def __init__(self, args):
369
+ super(GroverFinetuneTask, self).__init__()
370
+
371
+ self.hidden_size = args.hidden_size
372
+ self.iscuda = args.cuda
373
+
374
+ self.grover = GROVEREmbedding(args)
375
+
376
+ if args.self_attention:
377
+ self.readout = Readout(rtype="self_attention", hidden_size=self.hidden_size,
378
+ attn_hidden=args.attn_hidden,
379
+ attn_out=args.attn_out)
380
+ else:
381
+ self.readout = Readout(rtype="mean", hidden_size=self.hidden_size)
382
+
383
+ self.mol_atom_from_atom_ffn = self.create_ffn(args)
384
+ self.mol_atom_from_bond_ffn = self.create_ffn(args)
385
+ #self.ffn = nn.ModuleList()
386
+ #self.ffn.append(self.mol_atom_from_atom_ffn)
387
+ #self.ffn.append(self.mol_atom_from_bond_ffn)
388
+
389
+ self.classification = args.dataset_type == 'classification'
390
+ if self.classification:
391
+ self.sigmoid = nn.Sigmoid()
392
+
393
+ def create_ffn(self, args: Namespace):
394
+ """
395
+ Creates the feed-forward network for the model.
396
+
397
+ :param args: Arguments.
398
+ """
399
+ # Note: args.features_dim is set according the real loaded features data
400
+ if args.features_only:
401
+ first_linear_dim = args.features_size + args.features_dim
402
+ else:
403
+ if args.self_attention:
404
+ first_linear_dim = args.hidden_size * args.attn_out
405
+ # TODO: Ad-hoc!
406
+ # if args.use_input_features:
407
+ first_linear_dim += args.features_dim
408
+ else:
409
+ first_linear_dim = args.hidden_size + args.features_dim
410
+
411
+ dropout = nn.Dropout(args.dropout)
412
+ activation = get_activation_function(args.activation)
413
+ # TODO: ffn_hidden_size
414
+ # Create FFN layers
415
+ if args.ffn_num_layers == 1:
416
+ ffn = [
417
+ dropout,
418
+ nn.Linear(first_linear_dim, args.output_size)
419
+ ]
420
+ else:
421
+ ffn = [
422
+ dropout,
423
+ nn.Linear(first_linear_dim, args.ffn_hidden_size)
424
+ ]
425
+ for _ in range(args.ffn_num_layers - 2):
426
+ ffn.extend([
427
+ activation,
428
+ dropout,
429
+ nn.Linear(args.ffn_hidden_size, args.ffn_hidden_size),
430
+ ])
431
+ ffn.extend([
432
+ activation,
433
+ dropout,
434
+ nn.Linear(args.ffn_hidden_size, args.output_size),
435
+ ])
436
+
437
+ # Create FFN model
438
+ return nn.Sequential(*ffn)
439
+
440
+ @staticmethod
441
+ def get_loss_func(args):
442
+ def loss_func(preds, targets,
443
+ dt=args.dataset_type,
444
+ dist_coff=args.dist_coff):
445
+
446
+ if dt == 'classification':
447
+ pred_loss = nn.BCEWithLogitsLoss(reduction='none')
448
+ elif dt == 'regression':
449
+ pred_loss = nn.MSELoss(reduction='none')
450
+ else:
451
+ raise ValueError(f'Dataset type "{args.dataset_type}" not supported.')
452
+
453
+ # print(type(preds))
454
+ # TODO: Here, should we need to involve the model status? Using len(preds) is just a hack.
455
+ if type(preds) is not tuple:
456
+ # in eval mode.
457
+ return pred_loss(preds, targets)
458
+
459
+ # in train mode.
460
+ dist_loss = nn.MSELoss(reduction='none')
461
+ # dist_loss = nn.CosineSimilarity(dim=0)
462
+ # print(pred_loss)
463
+
464
+ dist = dist_loss(preds[0], preds[1])
465
+ pred_loss1 = pred_loss(preds[0], targets)
466
+ pred_loss2 = pred_loss(preds[1], targets)
467
+ return pred_loss1 + pred_loss2 + dist_coff * dist
468
+
469
+ return loss_func
470
+
471
+ def forward(self, batch, features_batch):
472
+ _, _, _, _, _, a_scope, _, _ = batch
473
+
474
+ output = self.grover(batch)
475
+ # Share readout
476
+ mol_atom_from_bond_output = self.readout(output["atom_from_bond"], a_scope)
477
+ mol_atom_from_atom_output = self.readout(output["atom_from_atom"], a_scope)
478
+
479
+ if features_batch[0] is not None:
480
+ features_batch = torch.from_numpy(np.stack(features_batch)).float()
481
+ if self.iscuda:
482
+ features_batch = features_batch.cuda()
483
+ features_batch = features_batch.to(output["atom_from_atom"])
484
+ if len(features_batch.shape) == 1:
485
+ features_batch = features_batch.view([1, features_batch.shape[0]])
486
+ else:
487
+ features_batch = None
488
+
489
+
490
+ if features_batch is not None:
491
+ mol_atom_from_atom_output = torch.cat([mol_atom_from_atom_output, features_batch], 1)
492
+ mol_atom_from_bond_output = torch.cat([mol_atom_from_bond_output, features_batch], 1)
493
+
494
+ if self.training:
495
+ atom_ffn_output = self.mol_atom_from_atom_ffn(mol_atom_from_atom_output)
496
+ bond_ffn_output = self.mol_atom_from_bond_ffn(mol_atom_from_bond_output)
497
+ return atom_ffn_output, bond_ffn_output
498
+ else:
499
+ atom_ffn_output = self.mol_atom_from_atom_ffn(mol_atom_from_atom_output)
500
+ bond_ffn_output = self.mol_atom_from_bond_ffn(mol_atom_from_bond_output)
501
+ if self.classification:
502
+ atom_ffn_output = self.sigmoid(atom_ffn_output)
503
+ bond_ffn_output = self.sigmoid(bond_ffn_output)
504
+ output = (atom_ffn_output + bond_ffn_output) / 2
505
+
506
+ return output
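A minimal sketch (not part of the commit) of how the finetune loss above combines the two branch outputs during training and how evaluation averages them; every tensor below is a dummy placeholder.

import torch
import torch.nn as nn

# Dummy stand-ins for the atom-branch and bond-branch FFN outputs and the targets.
preds_atom = torch.randn(4, 12)
preds_bond = torch.randn(4, 12)
targets = torch.randint(0, 2, (4, 12)).float()
dist_coff = 0.1

pred_loss = nn.BCEWithLogitsLoss(reduction='none')  # classification case
dist_loss = nn.MSELoss(reduction='none')

# Training: supervise both branches and penalise their disagreement.
loss = (pred_loss(preds_atom, targets) + pred_loss(preds_bond, targets)
        + dist_coff * dist_loss(preds_atom, preds_bond))

# Evaluation: average the sigmoid-activated outputs of the two branches.
output = (torch.sigmoid(preds_atom) + torch.sigmoid(preds_bond)) / 2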
grover/util/metrics.py ADDED
@@ -0,0 +1,122 @@
1
+ """
2
+ The evaluation metrics.
3
+ """
4
+ import math
5
+ from typing import List, Callable, Union
6
+
7
+ from sklearn.metrics import accuracy_score, mean_squared_error, roc_auc_score, mean_absolute_error, r2_score, \
8
+ precision_recall_curve, auc, recall_score, confusion_matrix
9
+
10
+
11
+ def accuracy(targets: List[int], preds: List[float], threshold: float = 0.5) -> float:
12
+ """
13
+ Computes the accuracy of a binary prediction task using a given threshold for generating hard predictions.
14
+
15
+ :param targets: A list of binary targets.
16
+ :param preds: A list of prediction probabilities.
17
+ :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0
18
+ :return: The computed accuracy.
19
+ """
20
+ hard_preds = [1 if p > threshold else 0 for p in preds]
21
+ return accuracy_score(targets, hard_preds)
22
+
23
+
24
+ def recall(targets: List[int], preds: List[float], threshold: float = 0.5) -> float:
25
+ """
26
+ Computes the recall of a binary prediction task using a given threshold for generating hard predictions.
27
+
28
+ :param targets: A list of binary targets.
29
+ :param preds: A list of prediction probabilities.
30
+ :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0
31
+ :return: The computed recall.
32
+ """
33
+ hard_preds = [1 if p > threshold else 0 for p in preds]
34
+ return recall_score(targets, hard_preds)
35
+
36
+
37
+ def sensitivity(targets: List[int], preds: List[float], threshold: float = 0.5) -> float:
38
+ """
39
+ Computes the sensitivity of a binary prediction task using a given threshold for generating hard predictions.
40
+
41
+ :param targets: A list of binary targets.
42
+ :param preds: A list of prediction probabilities.
43
+ :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0
44
+ :return: The computed sensitivity.
45
+ """
46
+ return recall(targets, preds, threshold)
47
+
48
+
49
+ def specificity(targets: List[int], preds: List[float], threshold: float = 0.5) -> float:
50
+ """
51
+ Computes the specificity of a binary prediction task using a given threshold for generating hard predictions.
52
+
53
+ :param targets: A list of binary targets.
54
+ :param preds: A list of prediction probabilities.
55
+ :param threshold: The threshold above which a prediction is a 1 and below which (inclusive) a prediction is a 0
56
+ :return: The computed specificity.
57
+ """
58
+ hard_preds = [1 if p > threshold else 0 for p in preds]
59
+ tn, fp, _, _ = confusion_matrix(targets, hard_preds).ravel()
60
+ return tn / float(tn + fp)
61
+
62
+
63
+
64
+ def rmse(targets: List[float], preds: List[float]) -> float:
65
+ """
66
+ Computes the root mean squared error.
67
+
68
+ :param targets: A list of targets.
69
+ :param preds: A list of predictions.
70
+ :return: The computed rmse.
71
+ """
72
+ return math.sqrt(mean_squared_error(targets, preds))
73
+
74
+
75
+ def get_metric_func(metric: str) -> Callable[[Union[List[int], List[float]], List[float]], float]:
76
+ """
77
+ Gets the metric function corresponding to a given metric name.
78
+
79
+ :param metric: Metric name.
80
+ :return: A metric function which takes as arguments a list of targets and a list of predictions and returns the computed metric.
81
+ """
82
+ # Note: If you want to add a new metric, please also update the parser argument --metric in parsing.py.
83
+ if metric == 'auc':
84
+ return roc_auc_score
85
+
86
+ if metric == 'prc-auc':
87
+ return prc_auc
88
+
89
+ if metric == 'rmse':
90
+ return rmse
91
+
92
+ if metric == 'mae':
93
+ return mean_absolute_error
94
+
95
+ if metric == 'r2':
96
+ return r2_score
97
+
98
+ if metric == 'accuracy':
99
+ return accuracy
100
+
101
+ if metric == 'recall':
102
+ return recall
103
+
104
+ if metric == 'sensitivity':
105
+ return sensitivity
106
+
107
+ if metric == 'specificity':
108
+ return specificity
109
+
110
+ raise ValueError(f'Metric "{metric}" not supported.')
111
+
112
+
113
+ def prc_auc(targets: List[int], preds: List[float]) -> float:
114
+ """
115
+ Computes the area under the precision-recall curve.
116
+
117
+ :param targets: A list of binary targets.
118
+ :param preds: A list of prediction probabilities.
119
+ :return: The computed prc-auc.
120
+ """
121
+ precision, recall, _ = precision_recall_curve(targets, preds)
122
+ return auc(recall, precision)
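A short usage sketch (not part of the commit) of the metric helpers above; the target and prediction lists are made up.

from grover.util.metrics import get_metric_func, accuracy

targets = [0, 1, 1, 0, 1]
preds = [0.2, 0.8, 0.6, 0.4, 0.9]

metric_func = get_metric_func('prc-auc')
print(metric_func(targets, preds))  # area under the precision-recall curve
print(accuracy(targets, preds))     # hard predictions thresholded at 0.5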
grover/util/multi_gpu_wrapper.py ADDED
@@ -0,0 +1,110 @@
1
+ """
2
+ Wrapper for multi-GPU training.
3
+ """
4
+ # use Horovod for multi-GPU PyTorch training
5
+ try:
6
+ import horovod.torch as mgw
7
+ import torch
8
+
9
+ print('using Horovod for multi-GPU training')
10
+ except ImportError:
11
+ print('[WARNING] Horovod cannot be imported; multi-GPU training is unsupported')
12
+ pass
13
+
14
+
15
+ class MultiGpuWrapper(object):
16
+ """Wrapper for multi-GPU training."""
17
+
18
+ def __init__(self):
19
+ """Constructor function."""
20
+ pass
21
+
22
+ @classmethod
23
+ def init(cls, *args):
24
+ """Initialization."""
25
+
26
+ try:
27
+ return mgw.init(*args)
28
+ except NameError:
29
+ raise NameError('module <mgw> not imported')
30
+
31
+ @classmethod
32
+ def size(cls, *args):
33
+ """Get the number of workers at all nodes."""
34
+
35
+ try:
36
+ return mgw.size(*args)
37
+ except NameError:
38
+ raise NameError('module <mgw> not imported')
39
+
40
+ @classmethod
41
+ def rank(cls, *args):
42
+ """Get the rank of current worker at all nodes."""
43
+
44
+ try:
45
+ return mgw.rank(*args)
46
+ except NameError:
47
+ raise NameError('module <mgw> not imported')
48
+
49
+ @classmethod
50
+ def local_size(cls, *args):
51
+ """Get the number of workers at the current node."""
52
+
53
+ try:
54
+ return mgw.local_size(*args)
55
+ except NameError:
56
+ raise NameError('module <mgw> not imported')
57
+
58
+ @classmethod
59
+ def local_rank(cls, *args):
60
+ """Get the rank of current worker at the current node."""
61
+
62
+ try:
63
+ return mgw.local_rank(*args)
64
+ except NameError:
65
+ raise NameError('module <mgw> not imported')
66
+
67
+ @classmethod
68
+ def DistributedOptimizer(cls, *args, **kwargs):
69
+ """Get a distributed optimizer from the base optimizer."""
70
+
71
+ try:
72
+ return mgw.DistributedOptimizer(*args, **kwargs)
73
+ except NameError:
74
+ raise NameError('module <mgw> not imported')
75
+
76
+ @classmethod
77
+ def broadcast_parameters(cls, *args, **kwargs):
78
+ """Get a operation to broadcast all the parameters."""
79
+
80
+ try:
81
+ return mgw.broadcast_parameters(*args, **kwargs)
82
+ except NameError:
83
+ raise NameError('module <mgw> not imported')
84
+
85
+ @classmethod
86
+ def broadcast_optimizer_state(cls, *args, **kwargs):
87
+ """Get a operation to broadcast all the optimizer state."""
88
+
89
+ try:
90
+ return mgw.broadcast_optimizer_state(*args, **kwargs)
91
+ except NameError:
92
+ raise NameError('module <mgw> not imported')
93
+
94
+ @classmethod
95
+ def broadcast(cls, *args, **kwargs):
96
+ """Get a operation to broadcast all the optimizer state."""
97
+
98
+ try:
99
+ return mgw.broadcast(*args, **kwargs)
100
+ except NameError:
101
+ raise NameError('module <mgw> not imported')
102
+
103
+ @classmethod
104
+ def barrier(cls):
105
+ """Add a barrier to synchronize different processes"""
106
+
107
+ try:
108
+ return mgw.allreduce(torch.tensor(0), name='barrier')
109
+ except NameError:
110
+ raise NameError('module <mgw> not imported')
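A usage sketch (not part of the commit) of the wrapper above. The wrapper simply forwards to horovod.torch, so the keyword arguments below are the standard Horovod ones; Horovod must be installed and the script launched with one process per GPU (e.g. via horovodrun).

import torch
from grover.util.multi_gpu_wrapper import MultiGpuWrapper as mgw

mgw.init()
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3 * mgw.size())
optimizer = mgw.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
mgw.broadcast_parameters(model.state_dict(), root_rank=0)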
grover/util/nn_utils.py ADDED
@@ -0,0 +1,96 @@
1
+ """
2
+ The utility function for model construction.
3
+ This implementation is adapted from
4
+ https://github.com/chemprop/chemprop/blob/master/chemprop/nn_utils.py
5
+ """
6
+ import torch
7
+ from torch import nn as nn
8
+
9
+
10
+ def param_count(model: nn.Module) -> int:
11
+ """
12
+ Determines number of trainable parameters.
13
+ :param model: An nn.Module.
14
+ :return: The number of trainable parameters.
15
+ """
16
+ return sum(param.numel() for param in model.parameters() if param.requires_grad)
17
+
18
+
19
+ def index_select_nd(source: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
20
+ """
21
+ Selects the message features from source corresponding to the atom or bond indices in index.
22
+
23
+ :param source: A tensor of shape (num_bonds, hidden_size) containing message features.
24
+ :param index: A tensor of shape (num_atoms/num_bonds, max_num_bonds) containing the atom or bond
25
+ indices to select from source.
26
+ :return: A tensor of shape (num_atoms/num_bonds, max_num_bonds, hidden_size) containing the message
27
+ features corresponding to the atoms/bonds specified in index.
28
+ """
29
+ index_size = index.size() # (num_atoms/num_bonds, max_num_bonds)
30
+ suffix_dim = source.size()[1:] # (hidden_size,)
31
+ final_size = index_size + suffix_dim # (num_atoms/num_bonds, max_num_bonds, hidden_size)
32
+
33
+ target = source.index_select(dim=0, index=index.view(-1)) # (num_atoms/num_bonds * max_num_bonds, hidden_size)
34
+ target = target.view(final_size) # (num_atoms/num_bonds, max_num_bonds, hidden_size)
35
+
36
+ return target
37
+
38
+
39
+ def get_activation_function(activation: str) -> nn.Module:
40
+ """
41
+ Gets an activation function module given the name of the activation.
42
+
43
+ :param activation: The name of the activation function.
44
+ :return: The activation function module.
45
+ """
46
+ if activation == 'ReLU':
47
+ return nn.ReLU()
48
+ elif activation == 'LeakyReLU':
49
+ return nn.LeakyReLU(0.1)
50
+ elif activation == 'PReLU':
51
+ return nn.PReLU()
52
+ elif activation == 'tanh':
53
+ return nn.Tanh()
54
+ elif activation == 'SELU':
55
+ return nn.SELU()
56
+ elif activation == 'ELU':
57
+ return nn.ELU()
58
+ elif activation == "Linear":
59
+ return lambda x: x
60
+ else:
61
+ raise ValueError(f'Activation "{activation}" not supported.')
62
+
63
+
64
+ def initialize_weights(model: nn.Module, distinct_init=False, model_idx=0):
65
+ """
66
+ Initializes the weights of a model in place.
67
+
68
+ :param model: An nn.Module.
69
+ """
70
+ init_fns = [nn.init.kaiming_normal_, nn.init.kaiming_uniform_,
71
+ nn.init.xavier_normal_, nn.init.xavier_uniform_]
72
+ for param in model.parameters():
73
+ if param.dim() == 1:
74
+ nn.init.constant_(param, 0)
75
+ else:
76
+ if distinct_init:
77
+ init_fn = init_fns[model_idx % 4]
78
+ if 'kaiming' in init_fn.__name__:
79
+ init_fn(param, nonlinearity='relu')
80
+ else:
81
+ init_fn(param)
82
+ else:
83
+ nn.init.xavier_normal_(param)
84
+
85
+
86
+ def select_neighbor_and_aggregate(feature, index):
87
+ """
88
+ The basic operation in message passing.
89
+ Caution: index_select_nd can cause reproducibility issues when training on CUDA.
90
+ See: https://pytorch.org/docs/stable/notes/randomness.html
91
+ :param feature: the candidate feature for aggregate. (n_nodes, hidden)
92
+ :param index: the selected index (neighbor indexes).
93
+ :return: The aggregated neighbor features, summed over the neighbor dimension.
94
+ """
95
+ neighbor = index_select_nd(feature, index)
96
+ return neighbor.sum(dim=1)
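A quick shape check (not part of the commit) for index_select_nd and select_neighbor_and_aggregate defined above; the sizes are arbitrary.

import torch
from grover.util.nn_utils import index_select_nd, select_neighbor_and_aggregate

source = torch.randn(7, 5)                    # e.g. 7 bond messages with hidden size 5
index = torch.tensor([[1, 2, 0], [3, 4, 6]])  # 2 atoms, up to 3 incoming bonds each
neighbors = index_select_nd(source, index)    # shape (2, 3, 5)
aggregated = select_neighbor_and_aggregate(source, index)  # shape (2, 5), summed over dim 1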
grover/util/parsing.py ADDED
@@ -0,0 +1,487 @@
1
+ """
2
+ The parsing functions for the argument input.
3
+ """
4
+ import os
5
+ import pickle
6
+ from argparse import ArgumentParser, Namespace
7
+ from tempfile import TemporaryDirectory
8
+
9
+ import torch
10
+
11
+ from grover.data.molfeaturegenerator import get_available_features_generators
12
+ from grover.util.utils import makedirs
13
+
14
+
15
+ def add_common_args(parser: ArgumentParser):
16
+ parser.add_argument('--no_cache', action='store_true', default=True,
17
+ help='Turn off caching mol2graph computation')
18
+ parser.add_argument('--gpu', type=int, default=0,
19
+ choices=list(range(torch.cuda.device_count())),
20
+ help='Which GPU to use')
21
+ parser.add_argument('--no_cuda', action='store_true', default=False,
22
+ help='Turn off cuda')
23
+ parser.add_argument('--batch_size', type=int, default=32,
24
+ help='Batch size')
25
+
26
+
27
+ def add_predict_args(parser: ArgumentParser):
28
+ """
29
+ Adds predict arguments to an ArgumentParser.
30
+
31
+ :param parser: An ArgumentParser.
32
+ """
33
+ add_common_args(parser)
34
+
35
+ parser.add_argument('--data_path', type=str,
36
+ help='Path to CSV file containing testing data for which predictions will be made')
37
+
38
+ parser.add_argument('--output_path', type=str,
39
+ help='Path to CSV file where predictions will be saved')
40
+ parser.add_argument('--checkpoint_dir', type=str,
41
+ help='Directory from which to load model checkpoints '
42
+ '(walks directory and ensembles all models that are found)')
43
+
44
+ parser.add_argument('--features_generator', type=str, nargs='*',
45
+ choices=get_available_features_generators(),
46
+ help='Method of generating additional features')
47
+ parser.add_argument('--features_path', type=str, nargs='*',
48
+ help='Path to features to use in FNN (instead of features_generator)')
49
+ parser.add_argument('--no_features_scaling', action='store_true', default=False,
50
+ help='Turn off scaling of features')
51
+
52
+
53
+ def add_fingerprint_args(parser):
54
+ add_common_args(parser)
55
+ # parameters for fingerprints generation
56
+ parser.add_argument('--data_path', type=str, help='Input csv file which contains SMILES')
57
+ parser.add_argument('--output_path', type=str,
58
+ help='Path to npz file where predictions will be saved')
59
+ parser.add_argument('--features_path', type=str, nargs='*',
60
+ help='Path to features to use in FNN (instead of features_generator)')
61
+ parser.add_argument('--fingerprint_source', type=str,
62
+ choices=['atom', 'bond', 'both'], default='both',
63
+ help='The source to generate the fingerprints.')
64
+ parser.add_argument('--checkpoint_path', type=str, help='model path')
65
+
66
+
67
+ def add_finetune_args(parser: ArgumentParser):
68
+ """
69
+ Adds training arguments to an ArgumentParser.
70
+
71
+ :param parser: An ArgumentParser.
72
+ """
73
+
74
+ # General arguments
75
+ add_common_args(parser)
76
+ parser.add_argument('--tensorboard', action='store_true', default=False, help='Add tensorboard logger')
77
+
78
+ # Data arguments
79
+ parser.add_argument('--data_path', type=str,
80
+ help='Path to data CSV file.')
81
+ parser.add_argument('--use_compound_names', action='store_true', default=False,
82
+ help='Use when test data file contains compound names in addition to SMILES strings')
83
+ parser.add_argument('--max_data_size', type=int,
84
+ help='Maximum number of data points to load')
85
+ # Disable this option due to some bugs.
86
+ # parser.add_argument('--test', action='store_true', default=False,
87
+ # help='Whether to skip training and only test the model')
88
+ parser.add_argument('--features_only', action='store_true', default=False,
89
+ help='Use only the additional features in an FFN, no graph network')
90
+ parser.add_argument('--features_generator', type=str, nargs='*',
91
+ choices=get_available_features_generators(),
92
+ help='Method of generating additional features.')
93
+ parser.add_argument('--features_path', type=str, nargs='*',
94
+ help='Path to features to use in FNN (instead of features_generator).')
95
+ parser.add_argument('--save_dir', type=str, default=None,
96
+ help='Directory where model checkpoints will be saved')
97
+ parser.add_argument('--save_smiles_splits', action='store_true', default=False,
98
+ help='Save smiles for each train/val/test splits for prediction convenience later')
99
+ parser.add_argument('--checkpoint_dir', type=str, default=None,
100
+ help='Directory from which to load model checkpoints '
101
+ '(walks directory and ensembles all models that are found)')
102
+ parser.add_argument('--checkpoint_path', type=str, default=None,
103
+ help='Path to model checkpoint (.pt file)')
104
+
105
+ # Data splitting.
106
+ parser.add_argument('--dataset_type', type=str,
107
+ choices=['classification', 'regression'], default='classification',
108
+ help='Type of dataset, e.g. classification or regression. '
109
+ 'This determines the loss function used during training.')
110
+ parser.add_argument('--separate_val_path', type=str,
111
+ help='Path to separate val set, optional')
112
+ parser.add_argument('--separate_val_features_path', type=str, nargs='*',
113
+ help='Path to file with features for separate val set')
114
+ parser.add_argument('--separate_test_path', type=str,
115
+ help='Path to separate test set, optional')
116
+ parser.add_argument('--separate_test_features_path', type=str, nargs='*',
117
+ help='Path to file with features for separate test set')
118
+ parser.add_argument('--split_type', type=str, default='random',
119
+ choices=['random', 'scaffold_balanced', 'predetermined', 'crossval', 'index_predetermined'],
120
+ help='Method of splitting the data into train/val/test')
121
+ parser.add_argument('--split_sizes', type=float, nargs=3, default=[0.8, 0.1, 0.1],
122
+ help='Split proportions for train/validation/test sets')
123
+ parser.add_argument('--num_folds', type=int, default=1,
124
+ help='Number of folds when performing cross validation')
125
+ parser.add_argument('--folds_file', type=str, default=None,
126
+ help='Optional file of fold labels')
127
+ parser.add_argument('--val_fold_index', type=int, default=None,
128
+ help='Which fold to use as val for leave-one-out cross val')
129
+ parser.add_argument('--test_fold_index', type=int, default=None,
130
+ help='Which fold to use as test for leave-one-out cross val')
131
+ parser.add_argument('--crossval_index_dir', type=str,
132
+ help='Directory in which to find cross validation index files')
133
+ parser.add_argument('--crossval_index_file', type=str,
134
+ help='Indices of files to use as train/val/test. '
135
+ 'Overrides --num_folds and --seed.')
136
+ parser.add_argument('--seed', type=int, default=0,
137
+ help='Random seed to use when splitting data into train/val/test sets. '
138
+ 'When `num_folds` > 1, the first fold uses this seed and all '
139
+ 'subsequent folds add 1 to the seed.')
140
+
141
+ # Metric
142
+ parser.add_argument('--metric', type=str, default=None,
143
+ choices=['auc',
144
+ 'prc-auc',
145
+ 'rmse',
146
+ 'mae',
147
+ 'r2',
148
+ 'accuracy',
149
+ 'recall',
150
+ 'sensitivity',
151
+ 'specificity',
152
+ 'matthews_corrcoef'],
153
+ help='Metric to use during evaluation. '
154
+ 'Note: Does NOT affect loss function used during training '
155
+ '(loss is determined by the `dataset_type` argument). '
156
+ 'Note: Defaults to "auc" for classification and "rmse" for regression.')
157
+ parser.add_argument('--show_individual_scores', action='store_true', default=False,
158
+ help='Show all scores for individual targets, not just average, at the end')
159
+
160
+
161
+
162
+
163
+ # Training arguments
164
+ parser.add_argument('--epochs', type=int, default=30,
165
+ help='Number of epochs to run')
166
+ parser.add_argument('--warmup_epochs', type=float, default=2.0,
167
+ help='Number of epochs during which learning rate increases linearly from '
168
+ 'init_lr to max_lr. Afterwards, learning rate decreases exponentially '
169
+ 'from max_lr to final_lr.')
170
+ parser.add_argument('--init_lr', type=float, default=1e-4,
171
+ help='Initial learning rate')
172
+ parser.add_argument('--max_lr', type=float, default=1e-3,
173
+ help='Maximum learning rate')
174
+ parser.add_argument('--final_lr', type=float, default=1e-4,
175
+ help='Final learning rate')
176
+ parser.add_argument('--no_features_scaling', action='store_true', default=False,
177
+ help='Turn off scaling of features')
178
+ parser.add_argument('--early_stop_epoch', type=int, default=1000, help='If the validation loss has not dropped within '
179
+ 'this many epochs, stop training')
180
+
181
+ # Model arguments
182
+ parser.add_argument('--ensemble_size', type=int, default=1,
183
+ help='Number of models for ensemble prediction.')
184
+ parser.add_argument('--dropout', type=float, default=0.0,
185
+ help='Dropout probability')
186
+ parser.add_argument('--activation', type=str, default='ReLU',
187
+ choices=['ReLU', 'LeakyReLU', 'PReLU', 'tanh', 'SELU', 'ELU'],
188
+ help='Activation function')
189
+ parser.add_argument('--ffn_hidden_size', type=int, default=None,
190
+ help='Hidden dim for higher-capacity FFN (defaults to hidden_size)')
191
+ parser.add_argument('--ffn_num_layers', type=int, default=2,
192
+ help='Number of layers in FFN after MPN encoding')
193
+ parser.add_argument('--weight_decay', type=float, default=0.0, help='weight_decay')
194
+ parser.add_argument('--select_by_loss', action='store_true', default=False,
195
+ help='Use validation loss as the reference standard to select the best model for prediction')
196
+
197
+ parser.add_argument("--embedding_output_type", default="atom", choices=["atom", "bond", "both"],
198
+ help="This the model parameters for pretrain model. The current finetuning task only use the "
199
+ "embeddings from atom branch. ")
200
+
201
+ # Self-attentive readout.
202
+ parser.add_argument('--self_attention', action='store_true', default=False, help='Use self attention layer. '
203
+ 'Otherwise use mean aggregation '
204
+ 'layer.')
205
+ parser.add_argument('--attn_hidden', type=int, default=4, nargs='?', help='Self attention layer '
206
+ 'hidden layer size.')
207
+ parser.add_argument('--attn_out', type=int, default=128, nargs='?', help='Self attention layer '
208
+ 'output feature size.')
209
+
210
+ parser.add_argument('--dist_coff', type=float, default=0.1, help='The coefficient of the distance loss between the outputs of the two branches.')
211
+
212
+
213
+ parser.add_argument('--bond_drop_rate', type=float, default=0, help='Drop-out rate for bonds in the molecular graph.')
214
+ parser.add_argument('--distinct_init', action='store_true', default=False,
215
+ help='Using distinct weight init for model ensemble')
216
+ parser.add_argument('--fine_tune_coff', type=float, default=1,
217
+ help='Enable a distinct fine-tuning learning rate for the fully-connected layers and the other layers')
218
+
219
+ # For multi-gpu finetune.
220
+ parser.add_argument('--enbl_multi_gpu', dest='enbl_multi_gpu',
221
+ action='store_true', default=False,
222
+ help='enable multi-GPU training')
223
+
224
+
225
+ def add_pretrain_args(parser: ArgumentParser):
226
+ parser.add_argument('--cuda', type=bool, default=True,
227
+ help='Enable GPU training or not.')
228
+ parser.add_argument('--enable_multi_gpu', dest='enable_multi_gpu',
229
+ action='store_true', default=False,
230
+ help='enable multi-GPU training')
231
+
232
+ # Data arguments
233
+ parser.add_argument('--data_path', type=str,
234
+ help='Path to data CSV file')
235
+ parser.add_argument('--fg_label_path', type=str, nargs='*',
236
+ help='Path to the label of fg task.')
237
+ parser.add_argument('--atom_vocab_path', type=str, help="Path to the vocabulary.")
238
+ parser.add_argument('--bond_vocab_path', type=str,
239
+ help="Path to the bond vocabulary.")
240
+
241
+ # Model arguments
242
+ parser.add_argument('--embedding_output_type', type=str, default='both', nargs='?',
243
+ choices=("atom", "bond", "both"),
244
+ help="Type of output embeddings. Options: atom, bond, both")
245
+
246
+ #parser.add_argument('--source_branch', type=str, default='both', nargs='?', choices=("atom", "bond", "both"),
247
+ # help="Type of source branch in gtrans. Options: atom, bond, both")
248
+
249
+ parser.add_argument('--save_dir', type=str, default=None,
250
+ help='Directory where model checkpoints will be saved')
251
+ parser.add_argument('--save_interval', type=int, default=9999999999, help='The model saving interval.')
252
+ parser.add_argument('--hidden_size', type=float, default=3,
253
+ help='Dimensionality of hidden layers. The actual dimension is hidden_size * 100.')
254
+ parser.add_argument('--bias', action='store_true', default=False,
255
+ help='Whether to add bias to linear layers')
256
+ parser.add_argument('--depth', type=int, default=3,
257
+ help='Number of message passing steps')
258
+ parser.add_argument('--dropout', type=float, default=0.0,
259
+ help='Dropout probability')
260
+ parser.add_argument('--activation', type=str, default='PReLU',
261
+ choices=['ReLU', 'LeakyReLU', 'PReLU', 'tanh', 'SELU', 'ELU'],
262
+ help='Activation function')
263
+ parser.add_argument('--undirected', action='store_true', default=False,
264
+ help='Undirected edges (always sum the two relevant bond vectors)')
265
+ parser.add_argument('--weight_decay', type=float, default=0.0, help='weight_decay')
266
+ parser.add_argument('--num_attn_head', type=int, default=4, help='The attention head in MTBlock.')
267
+ parser.add_argument('--num_mt_block', type=int, default=1, help="The number of MTBlock.")
268
+ parser.add_argument('--dist_coff', type=float, default=0.1, help='The disagreement coefficient for '
269
+ 'the atom and bond branch.')
270
+
271
+
272
+ # Training arguments
273
+ parser.add_argument("--backbone", default="gtrans", choices=["gtrans"])
274
+ parser.add_argument('--epochs', type=int, default=30,
275
+ help='Number of epochs to run')
276
+ parser.add_argument('--batch_size', type=int, default=32,
277
+ help='Batch size')
278
+ parser.add_argument('--warmup_epochs', type=float, default=2.0,
279
+ help='Number of epochs during which learning rate increases linearly from '
280
+ 'init_lr to max_lr. Afterwards, learning rate decreases exponentially '
281
+ 'from max_lr to final_lr.')
282
+ parser.add_argument('--init_lr', type=float, default=1e-4,
283
+ help='Initial learning rate')
284
+ parser.add_argument('--max_lr', type=float, default=1e-3,
285
+ help='Maximum learning rate')
286
+ parser.add_argument('--final_lr', type=float, default=1e-4,
287
+ help='Final learning rate')
288
+ parser.add_argument('--bond_drop_rate', type=float, default=0, help='Drop out bond in molecular')
289
+
290
+
291
+
292
+ def update_checkpoint_args(args: Namespace):
293
+ """
294
+ Walks the checkpoint directory to find all checkpoints, updating args.checkpoint_paths and args.ensemble_size.
295
+
296
+ :param args: Arguments.
297
+ """
298
+ if hasattr(args, 'checkpoint_paths') and args.checkpoint_paths is not None:
299
+ return
300
+ if not hasattr(args, 'checkpoint_path'):
301
+ args.checkpoint_path = None
302
+
303
+ if not hasattr(args, 'checkpoint_dir'):
304
+ args.checkpoint_dir = None
305
+
306
+ if args.checkpoint_dir is not None and args.checkpoint_path is not None:
307
+ raise ValueError('Only one of checkpoint_dir and checkpoint_path can be specified.')
308
+
309
+ if args.checkpoint_dir is None:
310
+ args.checkpoint_paths = [args.checkpoint_path] if args.checkpoint_path is not None else None
311
+ return
312
+
313
+ args.checkpoint_paths = []
314
+
315
+ for root, _, files in os.walk(args.checkpoint_dir):
316
+ for fname in files:
317
+ if fname.endswith('.pt'):
318
+ args.checkpoint_paths.append(os.path.join(root, fname))
319
+
320
+ if args.parser_name == "eval":
321
+ assert args.ensemble_size * args.num_folds == len(args.checkpoint_paths)
322
+
323
+ args.ensemble_size = len(args.checkpoint_paths)
324
+
325
+
326
+
327
+ if args.ensemble_size == 0:
328
+ raise ValueError(f'Failed to find any model checkpoints in directory "{args.checkpoint_dir}"')
329
+
330
+
331
+ def modify_predict_args(args: Namespace):
332
+ """
333
+ Modifies and validates predicting args in place.
334
+
335
+ :param args: Arguments.
336
+ """
337
+ assert args.data_path
338
+ assert args.output_path
339
+ assert args.checkpoint_dir is not None or args.checkpoint_path is not None or args.checkpoint_paths is not None
340
+
341
+ update_checkpoint_args(args)
342
+
343
+ args.cuda = not args.no_cuda and torch.cuda.is_available()
344
+ del args.no_cuda
345
+
346
+ # Create directory for preds path
347
+ makedirs(args.output_path, isfile=True)
348
+ setattr(args, 'fingerprint', False)
349
+
350
+
351
+ def modify_fingerprint_args(args):
352
+ assert args.data_path
353
+ assert args.output_path
354
+ assert args.checkpoint_path is not None or args.checkpoint_paths is not None
355
+
356
+
357
+ update_checkpoint_args(args)
358
+ args.cuda = not args.no_cuda and torch.cuda.is_available()
359
+ del args.no_cuda
360
+ makedirs(args.output_path, isfile=True)
361
+ setattr(args, 'fingerprint', True)
362
+
363
+
364
+ def get_newest_train_args():
365
+ """
366
+ For backward compatibility.
367
+
368
+ :return: A Namespace containing the newest training arguments
369
+ """
370
+ dummy_parser = ArgumentParser()
371
+ add_finetune_args(dummy_parser)
372
+ args = dummy_parser.parse_args(args=[])
373
+ args.data_path = ''
374
+ modify_train_args(args)
375
+ return args
376
+
377
+
378
+ def modify_train_args(args: Namespace):
379
+ """
380
+ Modifies and validates training arguments in place.
381
+
382
+ :param args: Arguments.
383
+ """
384
+ global TEMP_DIR # Prevents the temporary directory from being deleted upon function return
385
+
386
+ assert args.data_path is not None
387
+ assert args.dataset_type is not None
388
+
389
+ if args.save_dir is not None:
390
+ makedirs(args.save_dir)
391
+ else:
392
+ TEMP_DIR = TemporaryDirectory()
393
+ args.save_dir = TEMP_DIR.name
394
+
395
+ args.cuda = not args.no_cuda and torch.cuda.is_available()
396
+ del args.no_cuda
397
+
398
+ args.features_scaling = not args.no_features_scaling
399
+ del args.no_features_scaling
400
+
401
+ if args.metric is None:
402
+ if args.dataset_type == 'classification':
403
+ args.metric = 'auc'
404
+ else:
405
+ args.metric = 'rmse'
406
+
407
+ if not ((args.dataset_type == 'classification' and args.metric in ['auc', 'prc-auc', 'accuracy']) or
408
+ (args.dataset_type == 'regression' and args.metric in ['rmse', 'mae', 'r2'])):
409
+ raise ValueError(f'Metric "{args.metric}" invalid for dataset type "{args.dataset_type}".')
410
+
411
+ args.minimize_score = args.metric in ['rmse', 'mae']
412
+
413
+ update_checkpoint_args(args)
414
+
415
+ if args.features_only:
416
+ assert args.features_generator or args.features_path
417
+
418
+ args.use_input_features = args.features_generator or args.features_path
419
+
420
+ if args.features_generator is not None and 'rdkit_2d_normalized' in args.features_generator:
421
+ assert not args.features_scaling
422
+
423
+ args.num_lrs = 1
424
+
425
+
426
+
427
+ assert (args.split_type == 'predetermined') == (args.folds_file is not None) == (args.test_fold_index is not None)
428
+ assert (args.split_type == 'crossval') == (args.crossval_index_dir is not None)
429
+ assert (args.split_type in ['crossval', 'index_predetermined']) == (args.crossval_index_file is not None)
430
+ if args.split_type in ['crossval', 'index_predetermined']:
431
+ with open(args.crossval_index_file, 'rb') as rf:
432
+ args.crossval_index_sets = pickle.load(rf)
433
+ args.num_folds = len(args.crossval_index_sets)
434
+ args.seed = 0
435
+
436
+
437
+ if args.bond_drop_rate > 0:
438
+ args.no_cache = True
439
+
440
+ setattr(args, 'fingerprint', False)
441
+
442
+
443
+ def modify_pretrain_args(args: Namespace):
444
+ """
445
+ Modifies and validates pretraining arguments in place.
446
+ :param args: Arguments.
447
+ :return: None; args is updated in place.
448
+ """
449
+ args.dense = False
450
+ args.fine_tune_coff = 1
451
+ args.no_cache = True
452
+ args.hidden_size = int(args.hidden_size)
453
+
454
+
455
+ def parse_args() -> Namespace:
456
+ """
457
+ Parses arguments for training and testing (includes modifying/validating arguments).
458
+
459
+ :return: A Namespace containing the parsed, modified, and validated args.
460
+ """
461
+ parser = ArgumentParser()
462
+ subparser = parser.add_subparsers(title="subcommands",
463
+ dest="parser_name",
464
+ help="Subcommands for fintune, prediction, and fingerprint.")
465
+ parser_finetune = subparser.add_parser('finetune', help="Fine tune the pre-trained model.")
466
+ add_finetune_args(parser_finetune)
467
+ parser_eval = subparser.add_parser('eval', help="Evaluate the results of the pre-trained model.")
468
+ add_finetune_args(parser_eval)
469
+ parser_predict = subparser.add_parser('predict', help="Predict results from fine tuned model.")
470
+ add_predict_args(parser_predict)
471
+ parser_fp = subparser.add_parser('fingerprint', help="Get the fingerprints of SMILES.")
472
+ add_fingerprint_args(parser_fp)
473
+ parser_pretrain = subparser.add_parser('pretrain', help="Pretrain with unlabelled SMILES.")
474
+ add_pretrain_args(parser_pretrain)
475
+
476
+ args = parser.parse_args()
477
+
478
+ if args.parser_name == 'finetune' or args.parser_name == 'eval':
479
+ modify_train_args(args)
480
+ elif args.parser_name == "pretrain":
481
+ modify_pretrain_args(args)
482
+ elif args.parser_name == 'predict':
483
+ modify_predict_args(args)
484
+ elif args.parser_name == 'fingerprint':
485
+ modify_fingerprint_args(args)
486
+
487
+ return args
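A small usage sketch (not part of the commit) of the parser defined above, driven programmatically through sys.argv; the entry-script name main.py and all file paths are placeholders.

import sys
from grover.util.parsing import parse_args

# Simulate: python main.py finetune --data_path ./data/train.csv ...
sys.argv = ['main.py', 'finetune', '--data_path', './data/train.csv',
            '--dataset_type', 'classification', '--epochs', '30',
            '--save_dir', './ckpts', '--self_attention']
args = parse_args()
print(args.metric)  # defaults to 'auc' for classification datasets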
grover/util/scheduler.py ADDED
@@ -0,0 +1,97 @@
1
+ """
2
+ The learning rate scheduler.
3
+ This implementation is adapted from
4
+ https://github.com/chemprop/chemprop/blob/master/chemprop/nn_utils.py
5
+ """
6
+ from typing import List, Union
7
+
8
+ import numpy as np
9
+ from torch.optim.lr_scheduler import _LRScheduler
10
+
11
+
12
+ class NoamLR(_LRScheduler):
13
+ """
14
+ Noam learning rate scheduler with piecewise linear increase and exponential decay.
15
+
16
+ The learning rate increases linearly from init_lr to max_lr over the course of
17
+ the first warmup_steps (where warmup_steps = warmup_epochs * steps_per_epoch).
18
+ Then the learning rate decreases exponentially from max_lr to final_lr over the
19
+ course of the remaining total_steps - warmup_steps (where total_steps =
20
+ total_epochs * steps_per_epoch). This is roughly based on the learning rate
21
+ schedule from Attention Is All You Need, section 5.3 (https://arxiv.org/abs/1706.03762).
22
+ """
23
+ def __init__(self,
24
+ optimizer,
25
+ warmup_epochs: List[Union[float, int]],
26
+ total_epochs: List[int],
27
+ steps_per_epoch: int,
28
+ init_lr: List[float],
29
+ max_lr: List[float],
30
+ final_lr: List[float],
31
+ fine_tune_coff: float = 1.0,
32
+ fine_tune_param_idx: int = 0):
33
+ """
34
+ Initializes the learning rate scheduler.
35
+
36
+
37
+ :param optimizer: A PyTorch optimizer.
38
+ :param warmup_epochs: The number of epochs during which to linearly increase the learning rate.
39
+ :param total_epochs: The total number of epochs.
40
+ :param steps_per_epoch: The number of steps (batches) per epoch.
41
+ :param init_lr: The initial learning rate.
42
+ :param max_lr: The maximum learning rate (achieved after warmup_epochs).
43
+ :param final_lr: The final learning rate (achieved after total_epochs).
44
+ :param fine_tune_coff: The fine tune coefficient for the target param group. The true learning rate for the
45
+ target param group would be lr*fine_tune_coff.
46
+ :param fine_tune_param_idx: The index of target param group. Default is index 0.
47
+ """
48
+
49
+ # assert len(optimizer.param_groups) == len(warmup_epochs) == len(total_epochs) == len(init_lr) == \
50
+ # len(max_lr) == len(final_lr)
51
+
52
+ self.num_lrs = len(optimizer.param_groups)
53
+
54
+ self.optimizer = optimizer
55
+ self.warmup_epochs = np.array([warmup_epochs] * self.num_lrs)
56
+ self.total_epochs = np.array([total_epochs] * self.num_lrs)
57
+ self.steps_per_epoch = steps_per_epoch
58
+ self.init_lr = np.array([init_lr] * self.num_lrs)
59
+ self.max_lr = np.array([max_lr] * self.num_lrs)
60
+ self.final_lr = np.array([final_lr] * self.num_lrs)
61
+ self.lr_coff = np.array([1] * self.num_lrs)
62
+ self.fine_tune_param_idx = fine_tune_param_idx
63
+ self.lr_coff[self.fine_tune_param_idx] = fine_tune_coff
64
+
65
+ self.current_step = 0
66
+ self.lr = [init_lr] * self.num_lrs
67
+ self.warmup_steps = (self.warmup_epochs * self.steps_per_epoch).astype(int)
68
+ self.total_steps = self.total_epochs * self.steps_per_epoch
69
+ self.linear_increment = (self.max_lr - self.init_lr) / self.warmup_steps
70
+
71
+ self.exponential_gamma = (self.final_lr / self.max_lr) ** (1 / (self.total_steps - self.warmup_steps))
72
+ super(NoamLR, self).__init__(optimizer)
73
+
74
+ def get_lr(self) -> List[float]:
75
+ """Gets a list of the current learning rates."""
76
+ return list(self.lr)
77
+
78
+ def step(self, current_step: int = None):
79
+ """
80
+ Updates the learning rate by taking a step.
81
+
82
+ :param current_step: Optionally specify what step to set the learning rate to.
83
+ If None, current_step = self.current_step + 1.
84
+ """
85
+ if current_step is not None:
86
+ self.current_step = current_step
87
+ else:
88
+ self.current_step += 1
89
+ for i in range(self.num_lrs):
90
+ if self.current_step <= self.warmup_steps[i]:
91
+ self.lr[i] = self.init_lr[i] + self.current_step * self.linear_increment[i]
92
+ elif self.current_step <= self.total_steps[i]:
93
+ self.lr[i] = self.max_lr[i] * (self.exponential_gamma[i] ** (self.current_step - self.warmup_steps[i]))
94
+ else: # theoretically this case should never be reached since training should stop at total_steps
95
+ self.lr[i] = self.final_lr[i]
96
+ self.lr[i] *= self.lr_coff[i]
97
+ self.optimizer.param_groups[i]['lr'] = self.lr[i]
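A minimal usage sketch (not part of the commit) of the NoamLR scheduler above, using a single parameter group and made-up epoch/step counts.

import torch
from grover.util.scheduler import NoamLR

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': 1e-4}])
scheduler = NoamLR(optimizer,
                   warmup_epochs=2.0, total_epochs=30, steps_per_epoch=100,
                   init_lr=1e-4, max_lr=1e-3, final_lr=1e-4)

for step in range(30 * 100):
    optimizer.step()
    scheduler.step()  # linear warm-up for the first 200 steps, then exponential decay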
grover/util/utils.py ADDED
@@ -0,0 +1,797 @@
1
+ """
2
+ The general utility functions.
3
+ """
4
+ import csv
5
+ import logging
6
+ import os
7
+ import pickle
8
+ import random
9
+ from argparse import Namespace
10
+ from collections import defaultdict
11
+ from logging import Logger
12
+ from typing import List, Set, Tuple, Union, Dict
13
+
14
+ import numpy as np
15
+ import torch
16
+ from rdkit import Chem
17
+ from rdkit.Chem.Scaffolds import MurckoScaffold
18
+ from torch import nn as nn
19
+ from tqdm import tqdm as core_tqdm
20
+
21
+ from grover.data import MoleculeDatapoint, MoleculeDataset, StandardScaler
22
+ from grover.model.models import GroverFpGeneration, GroverFinetuneTask
23
+ from grover.util.nn_utils import initialize_weights
24
+ from grover.util.scheduler import NoamLR
25
+
26
+
27
+ np.float = float
28
+
29
+
30
+ def get_model_args():
31
+ """
32
+ Get model structure related parameters
33
+
34
+ :return: a list containing parameters
35
+ """
36
+ return ['model_type', 'ensemble_size', 'input_layer', 'hidden_size', 'bias', 'depth',
37
+ 'dropout', 'activation', 'undirected', 'ffn_hidden_size', 'ffn_num_layers',
38
+ 'atom_message', 'weight_decay', 'select_by_loss', 'skip_epoch', 'backbone',
39
+ 'embedding_output_type', 'self_attention', 'attn_hidden', 'attn_out', 'dense',
40
+ 'bond_drop_rate', 'distinct_init', 'aug_rate', 'fine_tune_coff', 'nencoders',
41
+ 'dist_coff', 'no_attach_fea', 'coord', "num_attn_head", "num_mt_block",
42
+ ]
43
+
44
+
45
+ def save_features(path: str, features: List[np.ndarray]):
46
+ """
47
+ Saves features to a compressed .npz file with array name "features".
48
+
49
+ :param path: Path to a .npz file where the features will be saved.
50
+ :param features: A list of 1D numpy arrays containing the features for molecules.
51
+ """
52
+ np.savez_compressed(path, features=features)
53
+
54
+
55
+ def load_features(path: str) -> np.ndarray:
56
+ """
57
+ Loads features saved in a variety of formats.
58
+
59
+ Supported formats:
60
+ - .npz compressed (assumes features are saved with name "features")
61
+
62
+ All formats assume that the SMILES strings loaded elsewhere in the code are in the same
63
+ order as the features loaded here.
64
+
65
+ :param path: Path to a file containing features.
66
+ :return: A 2D numpy array of size (num_molecules, features_size) containing the features.
67
+ """
68
+ extension = os.path.splitext(path)[1]
69
+
70
+ if extension == '.npz':
71
+ features = np.load(path)['features']
72
+ else:
73
+ raise ValueError(f'Features path extension {extension} not supported.')
74
+
75
+ return features
76
+
77
+
78
+ class tqdm(core_tqdm):
79
+ def __init__(self, *args, **kwargs):
80
+ kwargs.setdefault("ascii", True)
81
+ super(tqdm, self).__init__(*args, **kwargs)
82
+
83
+
84
+ def get_task_names(path: str, use_compound_names: bool = False) -> List[str]:
85
+ """
86
+ Gets the task names from a data CSV file.
87
+
88
+ :param path: Path to a CSV file.
89
+ :param use_compound_names: Whether file has compound names in addition to smiles strings.
90
+ :return: A list of task names.
91
+ """
92
+ index = 2 if use_compound_names else 1
93
+ task_names = get_header(path)[index:]
94
+
95
+ return task_names
96
+
97
+
98
+ def get_header(path: str) -> List[str]:
99
+ """
100
+ Returns the header of a data CSV file.
101
+
102
+ :param path: Path to a CSV file.
103
+ :return: A list of strings containing the strings in the comma-separated header.
104
+ """
105
+ with open(path) as f:
106
+ header = next(csv.reader(f))
107
+
108
+ return header
109
+
110
+
111
+ def get_num_tasks(path: str) -> int:
112
+ """
113
+ Gets the number of tasks in a data CSV file.
114
+
115
+ :param path: Path to a CSV file.
116
+ :return: The number of tasks.
117
+ """
118
+ return len(get_header(path)) - 1
119
+
120
+
121
+
122
+ def filter_invalid_smiles(data: MoleculeDataset) -> MoleculeDataset:
123
+ """
124
+ Filters out invalid SMILES.
125
+
126
+ :param data: A MoleculeDataset.
127
+ :return: A MoleculeDataset with only valid molecules.
128
+ """
129
+ datapoint_list = []
130
+ for idx, datapoint in enumerate(data):
131
+ if datapoint.smiles == '':
132
+ print(f'invalid smiles {idx}: {datapoint.smiles}')
133
+ continue
134
+ mol = Chem.MolFromSmiles(datapoint.smiles)
135
+ if mol.GetNumHeavyAtoms() == 0:
136
+ print(f'invalid heavy {idx}')
137
+ continue
138
+ datapoint_list.append(datapoint)
139
+ return MoleculeDataset(datapoint_list)
140
+
141
+
142
+ def get_data(path: str,
143
+ skip_invalid_smiles: bool = True,
144
+ args: Namespace = None,
145
+ features_path: List[str] = None,
146
+ max_data_size: int = None,
147
+ use_compound_names: bool = None,
148
+ logger: Logger = None) -> MoleculeDataset:
149
+ """
150
+ Gets smiles string and target values (and optionally compound names if provided) from a CSV file.
151
+
152
+ :param path: Path to a CSV file.
153
+ :param skip_invalid_smiles: Whether to skip and filter out invalid smiles.
154
+ :param args: Arguments.
155
+ :param features_path: A list of paths to files containing features. If provided, it is used
156
+ in place of args.features_path.
157
+ :param max_data_size: The maximum number of data points to load.
158
+ :param use_compound_names: Whether file has compound names in addition to smiles strings.
159
+ :param logger: Logger.
160
+ :return: A MoleculeDataset containing smiles strings and target values along
161
+ with other info such as additional features and compound names when desired.
162
+ """
163
+ debug = logger.debug if logger is not None else print
164
+
165
+ if args is not None:
166
+ # Prefer explicit function arguments but default to args if not provided
167
+ features_path = features_path if features_path is not None else args.features_path
168
+ max_data_size = max_data_size if max_data_size is not None else args.max_data_size
169
+ use_compound_names = use_compound_names if use_compound_names is not None else args.use_compound_names
170
+ else:
171
+ use_compound_names = False
172
+
173
+ max_data_size = max_data_size or float('inf')
174
+
175
+ # Load features
176
+ if features_path is not None:
177
+ features_data = []
178
+ for feat_path in features_path:
179
+ features_data.append(load_features(feat_path)) # each is num_data x num_features
180
+ features_data = np.concatenate(features_data, axis=1)
181
+ args.features_dim = len(features_data[0])
182
+ else:
183
+ features_data = None
184
+ if args is not None:
185
+ args.features_dim = 0
186
+
187
+ skip_smiles = set()
188
+
189
+ # Load data
190
+ with open(path) as f:
191
+ reader = csv.reader(f)
192
+ next(reader) # skip header
193
+
194
+ lines = []
195
+ for line in reader:
196
+ smiles = line[0]
197
+
198
+ if smiles in skip_smiles:
199
+ continue
200
+
201
+ lines.append(line)
202
+
203
+ if len(lines) >= max_data_size:
204
+ break
205
+
206
+ data = MoleculeDataset([
207
+ MoleculeDatapoint(
208
+ line=line,
209
+ args=args,
210
+ features=features_data[i] if features_data is not None else None,
211
+ use_compound_names=use_compound_names
212
+ ) for i, line in tqdm(enumerate(lines), total=len(lines), disable=True)
213
+ ])
214
+
215
+ # Filter out invalid SMILES
216
+ if skip_invalid_smiles:
217
+ original_data_len = len(data)
218
+ data = filter_invalid_smiles(data)
219
+
220
+ if len(data) < original_data_len:
221
+ debug(f'Warning: {original_data_len - len(data)} SMILES are invalid.')
222
+
223
+ return data
224
+
225
+
226
+ def get_data_from_smiles(smiles: List[str], skip_invalid_smiles: bool = True, logger: Logger = None,
227
+ args: Namespace = None) -> MoleculeDataset:
228
+ """
229
+ Converts SMILES to a MoleculeDataset.
230
+
231
+ :param smiles: A list of SMILES strings.
232
+ :param skip_invalid_smiles: Whether to skip and filter out invalid smiles.
233
+ :param logger: Logger.
234
+ :return: A MoleculeDataset with all of the provided SMILES.
235
+ """
236
+ debug = logger.debug if logger is not None else print
237
+
238
+ data = MoleculeDataset([MoleculeDatapoint(line=[smile], args=args) for smile in smiles])
239
+
240
+ # Filter out invalid SMILES
241
+ if skip_invalid_smiles:
242
+ original_data_len = len(data)
243
+ data = filter_invalid_smiles(data)
244
+
245
+ if len(data) < original_data_len:
246
+ debug(f'Warning: {original_data_len - len(data)} SMILES are invalid.')
247
+
248
+ return data
249
+
250
+
251
+ def split_data(data: MoleculeDataset,
252
+ split_type: str = 'random',
253
+ sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1),
254
+ seed: int = 0,
255
+ args: Namespace = None,
256
+ logger: Logger = None) -> Tuple[MoleculeDataset,
257
+ MoleculeDataset,
258
+ MoleculeDataset]:
259
+ """
260
+ Splits data into training, validation, and test splits.
261
+
262
+ :param data: A MoleculeDataset.
263
+ :param split_type: Split type.
264
+ :param sizes: A length-3 tuple with the proportions of data in the
265
+ train, validation, and test sets.
266
+ :param seed: The random seed to use before shuffling data.
267
+ :param args: Namespace of arguments.
268
+ :param logger: A logger.
269
+ :return: A tuple containing the train, validation, and test splits of the data.
270
+ """
271
+ assert len(sizes) == 3 and sum(sizes) == 1
272
+
273
+ if args is not None:
274
+ folds_file, val_fold_index, test_fold_index = \
275
+ args.folds_file, args.val_fold_index, args.test_fold_index
276
+ else:
277
+ folds_file = val_fold_index = test_fold_index = None
278
+
279
+ if split_type == 'crossval':
280
+ index_set = args.crossval_index_sets[args.seed]
281
+ data_split = []
282
+ for split in range(3):
283
+ split_indices = []
284
+ for index in index_set[split]:
285
+ with open(os.path.join(args.crossval_index_dir, f'{index}.pkl'), 'rb') as rf:
286
+ split_indices.extend(pickle.load(rf))
287
+ data_split.append([data[i] for i in split_indices])
288
+ train, val, test = tuple(data_split)
289
+ return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)
290
+
291
+ elif split_type == 'index_predetermined':
292
+ split_indices = args.crossval_index_sets[args.seed]
293
+ assert len(split_indices) == 3
294
+ data_split = []
295
+ for split in range(3):
296
+ data_split.append([data[i] for i in split_indices[split]])
297
+ train, val, test = tuple(data_split)
298
+ return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)
299
+
300
+ elif split_type == 'predetermined':
301
+ if not val_fold_index:
302
+ assert sizes[2] == 0 # test set is created separately so use all of the other data for train and val
303
+ assert folds_file is not None
304
+ assert test_fold_index is not None
305
+
306
+ try:
307
+ with open(folds_file, 'rb') as f:
308
+ all_fold_indices = pickle.load(f)
309
+ except UnicodeDecodeError:
310
+ with open(folds_file, 'rb') as f:
311
+ all_fold_indices = pickle.load(f, encoding='latin1') # in case we're loading indices from python2
312
+ # assert len(data) == sum([len(fold_indices) for fold_indices in all_fold_indices])
313
+
314
+ log_scaffold_stats(data, all_fold_indices, logger=logger)
315
+
316
+ folds = [[data[i] for i in fold_indices] for fold_indices in all_fold_indices]
317
+
318
+ test = folds[test_fold_index]
319
+ if val_fold_index is not None:
320
+ val = folds[val_fold_index]
321
+
322
+ train_val = []
323
+ for i in range(len(folds)):
324
+ if i != test_fold_index and (val_fold_index is None or i != val_fold_index):
325
+ train_val.extend(folds[i])
326
+
327
+ if val_fold_index is not None:
328
+ train = train_val
329
+ else:
330
+ random.seed(seed)
331
+ random.shuffle(train_val)
332
+ train_size = int(sizes[0] * len(train_val))
333
+ train = train_val[:train_size]
334
+ val = train_val[train_size:]
335
+
336
+ return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)
337
+
338
+ elif split_type == 'scaffold_balanced':
339
+ return scaffold_split(data, sizes=sizes, balanced=True, seed=seed, logger=logger)
340
+
341
+ elif split_type == 'random':
342
+ data.shuffle(seed=seed)
343
+
344
+ train_size = int(sizes[0] * len(data))
345
+ train_val_size = int((sizes[0] + sizes[1]) * len(data))
346
+
347
+ train = data[:train_size]
348
+ val = data[train_size:train_val_size]
349
+ test = data[train_val_size:]
350
+
351
+ return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)
352
+
353
+ else:
354
+ raise ValueError(f'split_type "{split_type}" not supported.')
355
+
356
+
357
+ def get_class_sizes(data: MoleculeDataset) -> List[List[float]]:
358
+ """
359
+ Determines the proportions of the different classes in the classification dataset.
360
+
361
+ :param data: A classification dataset
362
+ :return: A list of lists of class proportions. Each inner list contains the class proportions
363
+ for a task.
364
+ """
365
+ targets = data.targets()
366
+
367
+ # Filter out Nones
368
+ valid_targets = [[] for _ in range(data.num_tasks())]
369
+ for i in range(len(targets)):
370
+ for task_num in range(len(targets[i])):
371
+ if targets[i][task_num] is not None:
372
+ valid_targets[task_num].append(targets[i][task_num])
373
+
374
+ class_sizes = []
375
+ for task_targets in valid_targets:
376
+ # Make sure we're dealing with a binary classification task
377
+ assert set(np.unique(task_targets)) <= {0, 1}
378
+
379
+ try:
380
+ ones = np.count_nonzero(task_targets) / len(task_targets)
381
+ except ZeroDivisionError:
382
+ ones = float('nan')
383
+ print('Warning: class has no targets')
384
+ class_sizes.append([1 - ones, ones])
385
+
386
+ return class_sizes
387
+
388
+
389
+ def generate_scaffold(mol: Union[str, Chem.Mol], include_chirality: bool = False) -> str:
390
+ """
391
+ Compute the Bemis-Murcko scaffold for a SMILES string.
392
+
393
+ :param mol: A smiles string or an RDKit molecule.
394
+ :param include_chirality: Whether to include chirality.
395
+ :return:
396
+ """
397
+ mol = Chem.MolFromSmiles(mol) if type(mol) == str else mol
398
+ scaffold = MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=include_chirality)
399
+
400
+ return scaffold
401
+
402
+
403
+ def scaffold_to_smiles(mols: Union[List[str], List[Chem.Mol]],
404
+ use_indices: bool = False) -> Dict[str, Union[Set[str], Set[int]]]:
405
+ """
406
+ Computes scaffold for each smiles string and returns a mapping from scaffolds to sets of smiles.
407
+
408
+ :param mols: A list of smiles strings or RDKit molecules.
409
+ :param use_indices: Whether to map to the smiles' index in all_smiles rather than mapping
410
+ to the smiles string itself. This is necessary if there are duplicate smiles.
411
+ :return: A dictionary mapping each unique scaffold to all smiles (or smiles indices) which have that scaffold.
412
+ """
413
+ scaffolds = defaultdict(set)
414
+ for i, mol in tqdm(enumerate(mols), total=len(mols)):
415
+ scaffold = generate_scaffold(mol)
416
+ if use_indices:
417
+ scaffolds[scaffold].add(i)
418
+ else:
419
+ scaffolds[scaffold].add(mol)
420
+
421
+ return scaffolds
422
+
423
+
424
+ def scaffold_split(data: MoleculeDataset,
425
+ sizes: Tuple[float, float, float] = (0.8, 0.1, 0.1),
426
+ balanced: bool = False,
427
+ seed: int = 0,
428
+ logger: logging.Logger = None) -> Tuple[MoleculeDataset,
429
+ MoleculeDataset,
430
+ MoleculeDataset]:
431
+ """
432
+ Split a dataset by scaffold so that no molecules sharing a scaffold are in the same split.
433
+
434
+ :param data: A MoleculeDataset.
435
+ :param sizes: A length-3 tuple with the proportions of data in the
436
+ train, validation, and test sets.
437
+ :param balanced: Try to balance sizes of scaffolds in each set, rather than just putting smallest in test set.
438
+ :param seed: Seed for shuffling when doing balanced splitting.
439
+ :param logger: A logger.
440
+ :return: A tuple containing the train, validation, and test splits of the data.
441
+ """
442
+ assert sum(sizes) == 1
443
+
444
+ # Split
445
+ train_size, val_size, test_size = sizes[0] * len(data), sizes[1] * len(data), sizes[2] * len(data)
446
+ train, val, test = [], [], []
447
+ train_scaffold_count, val_scaffold_count, test_scaffold_count = 0, 0, 0
448
+
449
+ # Map from scaffold to index in the data
450
+ scaffold_to_indices = scaffold_to_smiles(data.smiles(), use_indices=True)
451
+
452
+ if balanced: # Put stuff that's bigger than half the val/test size into train, rest just order randomly
453
+ index_sets = list(scaffold_to_indices.values())
454
+ big_index_sets = []
455
+ small_index_sets = []
456
+ for index_set in index_sets:
457
+ if len(index_set) > val_size / 2 or len(index_set) > test_size / 2:
458
+ big_index_sets.append(index_set)
459
+ else:
460
+ small_index_sets.append(index_set)
461
+ random.seed(seed)
462
+ random.shuffle(big_index_sets)
463
+ random.shuffle(small_index_sets)
464
+ index_sets = big_index_sets + small_index_sets
465
+ else: # Sort from largest to smallest scaffold sets
466
+ index_sets = sorted(list(scaffold_to_indices.values()),
467
+ key=lambda index_set: len(index_set),
468
+ reverse=True)
469
+
470
+ for index_set in index_sets:
471
+ if len(train) + len(index_set) <= train_size:
472
+ train += index_set
473
+ train_scaffold_count += 1
474
+ elif len(val) + len(index_set) <= val_size:
475
+ val += index_set
476
+ val_scaffold_count += 1
477
+ else:
478
+ test += index_set
479
+ test_scaffold_count += 1
480
+
481
+ if logger is not None:
482
+ logger.debug(f'Total scaffolds = {len(scaffold_to_indices):,} | '
483
+ f'train scaffolds = {train_scaffold_count:,} | '
484
+ f'val scaffolds = {val_scaffold_count:,} | '
485
+ f'test scaffolds = {test_scaffold_count:,}')
486
+
487
+ log_scaffold_stats(data, index_sets, logger=logger)
488
+
489
+ # Map from indices to data
490
+ train = [data[i] for i in train]
491
+ val = [data[i] for i in val]
492
+ test = [data[i] for i in test]
493
+
494
+ return MoleculeDataset(train), MoleculeDataset(val), MoleculeDataset(test)
495
+
496
+
497
+ def log_scaffold_stats(data: MoleculeDataset,
498
+ index_sets: List[Set[int]],
499
+ num_scaffolds: int = 10,
500
+ num_labels: int = 20,
501
+ logger: logging.Logger = None) -> List[Tuple[List[float], List[int]]]:
502
+ """
503
+ Logs and returns statistics about counts and average target values in molecular scaffolds.
504
+
505
+ :param data: A MoleculeDataset.
506
+ :param index_sets: A list of sets of indices representing splits of the data.
507
+ :param num_scaffolds: The number of scaffolds about which to display statistics.
508
+ :param num_labels: The number of labels about which to display statistics.
509
+ :param logger: A Logger.
510
+ :return: A list of tuples where each tuple contains a list of average target values
511
+ across the first num_labels labels and a list of the number of non-zero values for
512
+ the first num_scaffolds scaffolds, sorted in decreasing order of scaffold frequency.
513
+ """
514
+ # print some statistics about scaffolds
515
+ target_avgs = []
516
+ counts = []
517
+ for index_set in index_sets:
518
+ data_set = [data[i] for i in index_set]
519
+ targets = [d.targets for d in data_set]
520
+ targets = np.array(targets, dtype=float)  # np.float was removed in recent NumPy releases
521
+ target_avgs.append(np.nanmean(targets, axis=0))
522
+ counts.append(np.count_nonzero(~np.isnan(targets), axis=0))
523
+ stats = [(target_avgs[i][:num_labels], counts[i][:num_labels]) for i in range(min(num_scaffolds, len(target_avgs)))]
524
+
525
+ if logger is not None:
526
+ logger.debug('Label averages per scaffold, in decreasing order of scaffold frequency, '
527
+ f'capped at {num_scaffolds} scaffolds and {num_labels} labels: {stats}')
528
+
529
+ return stats
530
+
531
+
532
+ def makedirs(path: str, isfile: bool = False):
533
+ """
534
+ Creates a directory given a path to either a directory or file.
535
+
536
+ If a directory is provided, creates that directory. If a file is provided (i.e. isfile == True),
537
+ creates the parent directory for that file.
538
+
539
+ :param path: Path to a directory or file.
540
+ :param isfile: Whether the provided path points to a file rather than a directory.
541
+ """
542
+ if isfile:
543
+ path = os.path.dirname(path)
544
+ if path != '':
545
+ os.makedirs(path, exist_ok=True)
546
+
547
+
548
+ def load_args(path: str) -> Namespace:
549
+ """
550
+ Loads the arguments a model was trained with.
551
+
552
+ :param path: Path where model checkpoint is saved.
553
+ :return: The arguments Namespace that the model was trained with.
554
+ """
555
+ return torch.load(path, map_location=lambda storage, loc: storage)['args']
556
+
557
+
558
+
559
+ def get_ffn_layer_id(model: GroverFinetuneTask):
560
+ """
561
+ Get the ids of the FFN-head parameters of a GroverFinetuneTask. (Ad hoc!)
562
+ :param model: A GroverFinetuneTask model.
563
+ :return: A list of ids of the FFN parameters (excluding the GROVER encoder).
564
+ """
565
+ return [id(param) for name, param in model.named_parameters() if "grover" not in name and "ffn" in name]
566
+
567
+
568
+ def build_optimizer(model: nn.Module, args: Namespace):
569
+ """
570
+ Builds an Optimizer.
571
+
572
+ :param model: The model to optimize.
573
+ :param args: Arguments.
574
+ :return: An initialized Optimizer.
575
+ """
576
+
577
+ # Only adjust the learning rate for the GroverFinetuneTask.
578
+ if type(model) == GroverFinetuneTask:
579
+ ffn_params = get_ffn_layer_id(model)
580
+ else:
581
+ # if not, init adam optimizer normally.
582
+ return torch.optim.Adam(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
583
+ base_params = filter(lambda p: id(p) not in ffn_params, model.parameters())
584
+ ffn_params = filter(lambda p: id(p) in ffn_params, model.parameters())
585
+ if args.fine_tune_coff == 0:
586
+ for param in base_params:
587
+ param.requires_grad = False
588
+
589
+ optimizer = torch.optim.Adam([
590
+ {'params': base_params, 'lr': args.init_lr * args.fine_tune_coff},
591
+ {'params': ffn_params, 'lr': args.init_lr}
592
+ ], lr=args.init_lr, weight_decay=args.weight_decay)
593
+
594
+ return optimizer
595
+
596
+
597
+ def build_lr_scheduler(optimizer, args: Namespace, total_epochs: List[int] = None):
598
+ """
599
+ Builds a learning rate scheduler.
600
+
601
+ :param optimizer: The Optimizer whose learning rate will be scheduled.
602
+ :param args: Arguments.
603
+ :param total_epochs: The total number of epochs for which the model will be trained.
604
+ :return: An initialized learning rate scheduler.
605
+ """
606
+
607
+ # Learning rate scheduler
608
+ # Divide the parameters into two groups for finetuning.
609
+ return NoamLR(
610
+ optimizer=optimizer,
611
+ warmup_epochs=args.warmup_epochs,
612
+ total_epochs=args.epochs,
613
+ steps_per_epoch=args.train_data_size // args.batch_size,
614
+ init_lr=args.init_lr,
615
+ max_lr=args.max_lr,
616
+ final_lr=args.final_lr,
617
+ fine_tune_coff=args.fine_tune_coff
618
+ )
619
+
620
+
621
+ def create_logger(name: str, save_dir: str = None, quiet: bool = False) -> logging.Logger:
622
+ """
623
+ Creates a logger with a stream handler and two file handlers.
624
+
625
+ The stream handler prints to the screen depending on the value of `quiet`.
626
+ One file handler (verbose.log) saves all logs, the other (quiet.log) only saves important info.
627
+
628
+ :param name: The name of the logger.
629
+ :param save_dir: The directory in which to save the logs.
630
+ :param quiet: Whether the stream handler should be quiet (i.e. print only important info).
631
+ :return: The logger.
632
+ """
633
+ logger = logging.getLogger(name)
634
+ logger.setLevel(logging.DEBUG)
635
+ logger.propagate = False
636
+
637
+ # Set logger depending on desired verbosity
638
+ ch = logging.StreamHandler()
639
+ if quiet:
640
+ ch.setLevel(logging.INFO)
641
+ else:
642
+ ch.setLevel(logging.DEBUG)
643
+ logger.addHandler(ch)
644
+
645
+ if save_dir is not None:
646
+ makedirs(save_dir)
647
+ fh_v = logging.FileHandler(os.path.join(save_dir, 'verbose.log'))
648
+ fh_v.setLevel(logging.DEBUG)
649
+ fh_q = logging.FileHandler(os.path.join(save_dir, 'quiet.log'))
650
+ fh_q.setLevel(logging.INFO)
651
+
652
+ logger.addHandler(fh_v)
653
+ logger.addHandler(fh_q)
654
+
655
+ return logger
656
+
657
+
658
+ def load_checkpoint(path: str,
659
+ current_args: Namespace = None,
660
+ cuda: bool = None,
661
+ logger: logging.Logger = None):
662
+ """
663
+ Loads a model checkpoint.
664
+
665
+ :param path: Path where checkpoint is saved.
666
+ :param current_args: The current arguments. Replaces the arguments loaded from the checkpoint if provided.
667
+ :param cuda: Whether to move model to cuda.
668
+ :param logger: A logger.
669
+ :return: The loaded MPNN.
670
+ """
671
+ debug = logger.debug if logger is not None else print
672
+
673
+ # Load model and args
674
+ state = torch.load(path, map_location=lambda storage, loc: storage)
675
+ args, loaded_state_dict = state['args'], state['state_dict']
676
+ model_related_args = get_model_args()
677
+
678
+ if current_args is not None:
679
+ for key, value in vars(args).items():
680
+ if key in model_related_args:
681
+ setattr(current_args, key, value)
682
+ else:
683
+ current_args = args
684
+
685
+ # args.cuda = cuda if cuda is not None else args.cuda
686
+
687
+ # Build model
688
+ model = build_model(current_args)
689
+ model_state_dict = model.state_dict()
690
+
691
+ # Skip missing parameters and parameters of mismatched size
692
+ pretrained_state_dict = {}
693
+ for param_name in loaded_state_dict.keys():
694
+ new_param_name = param_name
695
+ if new_param_name not in model_state_dict:
696
+ debug(f'Pretrained parameter "{param_name}" cannot be found in model parameters.')
697
+ elif model_state_dict[new_param_name].shape != loaded_state_dict[param_name].shape:
698
+ debug(f'Pretrained parameter "{param_name}" '
699
+ f'of shape {loaded_state_dict[param_name].shape} does not match corresponding '
700
+ f'model parameter of shape {model_state_dict[new_param_name].shape}.')
701
+ else:
702
+ debug(f'Loading pretrained parameter "{param_name}".')
703
+ pretrained_state_dict[new_param_name] = loaded_state_dict[param_name]
704
+ # Load pretrained weights
705
+ model_state_dict.update(pretrained_state_dict)
706
+ model.load_state_dict(model_state_dict)
707
+
708
+ if cuda:
709
+ debug('Moving model to cuda')
710
+ model = model.cuda()
711
+
712
+ return model
713
+
714
+
715
+ def get_loss_func(args: Namespace, model=None):
716
+ """
717
+ Gets the loss function corresponding to a given dataset type.
718
+
719
+ :param args: Namespace containing the dataset type ("classification" or "regression").
720
+ :return: A PyTorch loss function.
721
+ """
722
+ if hasattr(model, "get_loss_func"):
723
+ return model.get_loss_func(args)
724
+ if args.dataset_type == 'classification':
725
+ return nn.BCEWithLogitsLoss(reduction='none')
726
+ if args.dataset_type == 'regression':
727
+ return nn.MSELoss(reduction='none')
728
+
729
+ raise ValueError(f'Dataset type "{args.dataset_type}" not supported.')
730
+
731
+
732
+ def load_scalars(path: str):
733
+ """
734
+ Loads the scalars a model was trained with.
735
+
736
+ :param path: Path where model checkpoint is saved.
737
+ :return: A tuple with the data scaler and the features scaler.
738
+ """
739
+ state = torch.load(path, map_location=lambda storage, loc: storage)
740
+
741
+ scaler = StandardScaler(state['data_scaler']['means'],
742
+ state['data_scaler']['stds']) if state['data_scaler'] is not None else None
743
+ features_scaler = StandardScaler(state['features_scaler']['means'],
744
+ state['features_scaler']['stds'],
745
+ replace_nan_token=0) if state['features_scaler'] is not None else None
746
+
747
+ return scaler, features_scaler
748
+
749
+
750
+ def save_checkpoint(path: str,
751
+ model,
752
+ scaler,
753
+ features_scaler,
754
+ args: Namespace = None):
755
+ """
756
+ Saves a model checkpoint.
757
+
758
+ :param model: A MPNN.
759
+ :param scaler: A StandardScaler fitted on the data.
760
+ :param features_scaler: A StandardScaler fitted on the features.
761
+ :param args: Arguments namespace.
762
+ :param path: Path where checkpoint will be saved.
763
+ """
764
+ state = {
765
+ 'args': args,
766
+ 'state_dict': model.state_dict(),
767
+ 'data_scaler': {
768
+ 'means': scaler.means,
769
+ 'stds': scaler.stds
770
+ } if scaler is not None else None,
771
+ 'features_scaler': {
772
+ 'means': features_scaler.means,
773
+ 'stds': features_scaler.stds
774
+ } if features_scaler is not None else None
775
+ }
776
+ torch.save(state, path)
777
+
778
+
779
+ def build_model(args: Namespace, model_idx=0):
780
+ """
781
+ Builds an MPNN, i.e. a message passing neural network followed by feed-forward layers.
782
+
783
+ :param args: Arguments.
784
+ :return: A MPNN containing the MPN encoder along with final linear layers with parameters initialized.
785
+ """
786
+ if hasattr(args, 'num_tasks'):
787
+ args.output_size = args.num_tasks
788
+ else:
789
+ args.output_size = 1
790
+
791
+ if args.parser_name == "fingerprint":
792
+ model = GroverFpGeneration(args)
793
+ else:
794
+ # finetune and evaluation case.
795
+ model = GroverFinetuneTask(args)
796
+ initialize_weights(model=model, model_idx=model_idx)
797
+ return model
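
A minimal usage sketch of the checkpoint/optimizer helpers above may be useful. Everything below is illustrative: the hyperparameter values are assumptions rather than values used in this repository, and the model is expected to be a GroverFinetuneTask (any other nn.Module takes the plain-Adam path).

    from argparse import Namespace

    import torch.nn as nn

    from grover.util.utils import build_lr_scheduler, build_optimizer, get_loss_func

    def build_training_objects(model: nn.Module, train_data_size: int):
        # All hyperparameter values here are illustrative placeholders.
        args = Namespace(dataset_type='classification',      # -> BCEWithLogitsLoss
                         init_lr=1e-4, max_lr=1e-3, final_lr=1e-4,
                         weight_decay=1e-7,
                         fine_tune_coff=1.0,                  # 0 would freeze the GROVER backbone
                         warmup_epochs=2, epochs=30, batch_size=32,
                         train_data_size=train_data_size)
        loss_func = get_loss_func(args, model)                # per-element loss, masked by the caller
        optimizer = build_optimizer(model, args)              # separate LR group for the FFN head
        scheduler = build_lr_scheduler(optimizer, args)       # NoamLR warmup/decay schedule
        return loss_func, optimizer, scheduler
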
prepare_data.py CHANGED
@@ -17,4 +17,5 @@ val_path = "./tox21/tox21_validation.csv"
17
  train_path_clean = "./tox21/tox21_train_clean.csv"
18
  val_path_clean = "./tox21/tox21_validation_clean.csv"
19
 
20
- prepare_data(test_path, test_path_clean, "./tox21/valid_mask_test.npy")
 
 
17
  train_path_clean = "./tox21/tox21_train_clean.csv"
18
  val_path_clean = "./tox21/tox21_validation_clean.csv"
19
 
20
+ prepare_data(train_path, train_path_clean, "./tox21/valid_mask_train.npy")
21
+ prepare_data(val_path, val_path_clean, "./tox21/valid_mask_val.npy")
requirements.txt ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py @ file:///home/conda/feedstock_root/build_artifacts/absl-py_1751547525079/work
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.13.2
4
+ aiosignal==1.4.0
5
+ anyio==4.12.0
6
+ async-timeout==5.0.1
7
+ attrs==25.4.0
8
+ Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1749229842835/work
9
+ certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1754231422783/work/certifi
10
+ cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1725571112467/work
11
+ charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1754767332901/work
12
+ click==8.1.8
13
+ colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1733218098505/work
14
+ datasets==4.4.1
15
+ descriptastorus==2.8.0
16
+ dill==0.4.0
17
+ exceptiongroup==1.3.1
18
+ filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1755216263872/work
19
+ frozenlist==1.8.0
20
+ fsspec==2025.10.0
21
+ git-filter-repo @ file:///home/conda/feedstock_root/build_artifacts/git-filter-repo_1735551402582/work
22
+ gmpy2 @ file:///home/conda/feedstock_root/build_artifacts/gmpy2_1745509363867/work
23
+ grpcio @ file:///home/conda/feedstock_root/build_artifacts/grpc-split_1754634529307/work
24
+ h11==0.16.0
25
+ h2 @ file:///home/conda/feedstock_root/build_artifacts/h2_1738578511449/work
26
+ hf-xet==1.2.0
27
+ hpack @ file:///home/conda/feedstock_root/build_artifacts/hpack_1737618293087/work
28
+ httpcore==1.0.9
29
+ httpx==0.28.1
30
+ huggingface_hub==1.1.7
31
+ hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1737618333194/work
32
+ idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1733211830134/work
33
+ importlib_metadata @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_importlib-metadata_1747934053/work
34
+ Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1741263328855/work
35
+ joblib @ file:///home/conda/feedstock_root/build_artifacts/joblib_1748019130050/work
36
+ Markdown @ file:///home/conda/feedstock_root/build_artifacts/markdown_1750360292101/work
37
+ MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1733219680183/work
38
+ mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1733302684489/work
39
+ multidict==6.7.0
40
+ multiprocess==0.70.18
41
+ networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1698504735452/work
42
+ numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1707225342954/work/dist/numpy-1.26.4-cp39-cp39-linux_x86_64.whl#sha256=c799942b5898f6e6c60264d1663a6469a475290e758c654aeeb78e2596463abd
43
+ packaging @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_packaging_1745345660/work
44
+ pandas @ file:///home/conda/feedstock_root/build_artifacts/pandas_1752081702369/work
45
+ pandas_flavor==0.7.0
46
+ pillow @ file:///home/conda/feedstock_root/build_artifacts/pillow_1751482006338/work
47
+ propcache==0.4.1
48
+ protobuf @ file:///home/conda/feedstock_root/build_artifacts/protobuf_1751668301193/work/bazel-bin/python/dist/protobuf-6.31.1-cp39-abi3-linux_x86_64.whl#sha256=91a4a00a210b50fbca2de99b20633990d9f00a443829a9badc867ec313e0fecc
49
+ pyarrow==21.0.0
50
+ pycairo==1.28.0
51
+ pycparser @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycparser_1733195786/work
52
+ PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1733217236728/work
53
+ python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dateutil_1751104122/work
54
+ pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1742920838005/work
55
+ PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1737454647378/work
56
+ rdkit==2025.9.1
57
+ rdkit-pypi==2022.9.5
58
+ requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1755614211359/work
59
+ scikit-learn @ file:///home/conda/feedstock_root/build_artifacts/scikit-learn_1736496755362/work/dist/scikit_learn-1.6.1-cp39-cp39-linux_x86_64.whl#sha256=e8f978e37bb47e04e1337a63f75697b723d6d25f58e477734555faed033884ba
60
+ scipy==1.10.1
61
+ shellingham==1.5.4
62
+ six @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_six_1753199211/work
63
+ sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1745946051654/work
64
+ tensorboard @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_tensorboard_1752825441/work/tensorboard-2.20.0-py3-none-any.whl#sha256=9dc9f978cb84c0723acf9a345d96c184f0293d18f166bb8d59ee098e6cfaaba6
65
+ tensorboard_data_server @ file:///home/conda/feedstock_root/build_artifacts/tensorboard-data-server_1728639721704/work/tensorboard_data_server-0.7.0-py3-none-manylinux2014_x86_64.whl#sha256=3b7dc7cf17b685028f955453a839cc9b2871818de53e7911eae158fe6b3a80cf
66
+ threadpoolctl @ file:///home/conda/feedstock_root/build_artifacts/threadpoolctl_1741878222898/work
67
+ torch==2.4.0
68
+ torchvision==0.19.0
69
+ tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1735661334605/work
70
+ triton==3.0.0
71
+ typer-slim==0.20.0
72
+ typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1751643513/work
73
+ tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1742745135198/work
74
+ urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1750271362675/work
75
+ Werkzeug @ file:///home/conda/feedstock_root/build_artifacts/werkzeug_1733160440960/work
76
+ xarray==2024.7.0
77
+ xxhash==3.6.0
78
+ yarl==1.22.0
79
+ zipp @ file:///home/conda/feedstock_root/build_artifacts/zipp_1749421620841/work
80
+ zstandard==0.23.0
81
+ fastapi
82
+ uvicorn[standard]
scripts/__init__.py ADDED
File without changes
scripts/build_vocab.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The vocabulary building scripts.
3
+ """
4
+ import os
5
+
6
+ from grover.data.torchvocab import MolVocab
7
+
8
+
9
+ def build():
10
+ import argparse
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument('--data_path', default="../../dataset/grover_new_dataset/druglike_merged_refine2.csv", type=str)
13
+ parser.add_argument('--vocab_save_folder', default="../../dataset/grover_new_dataset", type=str)
14
+ parser.add_argument('--dataset_name', type=str, default=None,
15
+ help="Will be the first part of the vocab file name. If it is None,"
16
+ "the vocab files will be: atom_vocab.pkl and bond_vocab.pkl")
17
+ parser.add_argument('--vocab_max_size', type=int, default=None)
18
+ parser.add_argument('--vocab_min_freq', type=int, default=1)
19
+ args = parser.parse_args()
20
+
21
+ # fin = open(args.data_path, 'r')
22
+ # lines = fin.readlines()
23
+
24
+ for vocab_type in ['atom', 'bond']:
25
+ vocab_file = f"{vocab_type}_vocab.pkl"
26
+ if args.dataset_name is not None:
27
+ vocab_file = args.dataset_name + '_' + vocab_file
28
+ vocab_save_path = os.path.join(args.vocab_save_folder, vocab_file)
29
+
30
+ os.makedirs(os.path.dirname(vocab_save_path), exist_ok=True)
31
+ vocab = MolVocab(file_path=args.data_path,
32
+ max_size=args.vocab_max_size,
33
+ min_freq=args.vocab_min_freq,
34
+ num_workers=100,
35
+ vocab_type=vocab_type)
36
+ print(f"{vocab_type} vocab size", len(vocab))
37
+ vocab.save_vocab(vocab_save_path)
38
+
39
+
40
+ if __name__ == '__main__':
41
+ build()
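
The same vocabulary construction can also be driven from Python. A sketch, assuming the Tox21 training CSV used elsewhere in this repository (the paths, worker count and output location are illustrative):

    from grover.data.torchvocab import MolVocab

    # Build and save the atom-level vocabulary; the bond vocabulary works the same way
    # with vocab_type="bond".
    atom_vocab = MolVocab(file_path="./tox21/tox21_train_clean.csv",
                          max_size=None, min_freq=1,
                          num_workers=4, vocab_type="atom")
    print("atom vocab size:", len(atom_vocab))
    atom_vocab.save_vocab("./tox21/atom_vocab.pkl")
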
scripts/save_features.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Computes and saves molecular features for a dataset.
3
+ """
4
+ import os
5
+ import shutil
6
+ import sys
7
+ from argparse import ArgumentParser, Namespace
8
+ from multiprocessing import Pool
9
+ from typing import List, Tuple
10
+
11
+ from tqdm import tqdm
12
+
13
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
14
+
15
+ from grover.util.utils import get_data, makedirs, load_features, save_features
16
+ from grover.data.molfeaturegenerator import get_available_features_generators, \
17
+ get_features_generator
18
+ from grover.data.task_labels import rdkit_functional_group_label_features_generator
19
+
20
+
21
+
22
+ def load_temp(temp_dir: str) -> Tuple[List[List[float]], int]:
23
+ """
24
+ Loads all features saved as .npz files in load_dir.
25
+
26
+ Assumes temporary files are named in order 0.npz, 1.npz, ...
27
+
28
+ :param temp_dir: Directory in which temporary .npz files containing features are stored.
29
+ :return: A tuple with a list of molecule features, where each molecule's features is a list of floats,
30
+ and the number of temporary files.
31
+ """
32
+ features = []
33
+ temp_num = 0
34
+ temp_path = os.path.join(temp_dir, f'{temp_num}.npz')
35
+
36
+ while os.path.exists(temp_path):
37
+ features.extend(load_features(temp_path))
38
+ temp_num += 1
39
+ temp_path = os.path.join(temp_dir, f'{temp_num}.npz')
40
+
41
+ return features, temp_num
42
+
43
+
44
+ def generate_and_save_features(args: Namespace):
45
+ """
46
+ Computes and saves features for a dataset of molecules as a 2D array in a .npz file.
47
+
48
+ :param args: Arguments.
49
+ """
50
+ # Create directory for save_path
51
+ makedirs(args.save_path, isfile=True)
52
+
53
+ # Get data and features function
54
+ data = get_data(path=args.data_path, max_data_size=None)
55
+ features_generator = get_features_generator(args.features_generator)
56
+ temp_save_dir = args.save_path + '_temp'
57
+
58
+ # Load partially complete data
59
+ if args.restart:
60
+ if os.path.exists(args.save_path):
61
+ os.remove(args.save_path)
62
+ if os.path.exists(temp_save_dir):
63
+ shutil.rmtree(temp_save_dir)
64
+ else:
65
+ if os.path.exists(args.save_path):
66
+ raise ValueError(f'"{args.save_path}" already exists and args.restart is False.')
67
+
68
+ if os.path.exists(temp_save_dir):
69
+ features, temp_num = load_temp(temp_save_dir)
70
+
71
+ if not os.path.exists(temp_save_dir):
72
+ makedirs(temp_save_dir)
73
+ features, temp_num = [], 0
74
+
75
+ # Build features map function
76
+ data = data[len(features):] # restrict to data for which features have not been computed yet
77
+ mols = (d.smiles for d in data)
78
+
79
+ if args.sequential:
80
+ features_map = map(features_generator, mols)
81
+ else:
82
+ features_map = Pool(30).imap(features_generator, mols)
83
+
84
+ # Get features
85
+ temp_features = []
86
+ for i, feats in tqdm(enumerate(features_map), total=len(data)):
87
+ temp_features.append(feats)
88
+
89
+ # Save temporary features every save_frequency
90
+ if (i > 0 and (i + 1) % args.save_frequency == 0) or i == len(data) - 1:
91
+ save_features(os.path.join(temp_save_dir, f'{temp_num}.npz'), temp_features)
92
+ features.extend(temp_features)
93
+ temp_features = []
94
+ temp_num += 1
95
+
96
+ try:
97
+ # Save all features
98
+ save_features(args.save_path, features)
99
+
100
+ # Remove temporary features
101
+ shutil.rmtree(temp_save_dir)
102
+ except OverflowError:
103
+ print('Features array is too large to save as a single file. Instead keeping features as a directory of files.')
104
+
105
+
106
+ if __name__ == '__main__':
107
+
108
+ parser = ArgumentParser()
109
+ parser.add_argument('--data_path', type=str, required=True,
110
+ help='Path to data CSV')
111
+ parser.add_argument('--features_generator', type=str, required=True,
112
+ choices=get_available_features_generators(),
113
+ help='Type of features to generate')
114
+ parser.add_argument('--save_path', type=str, default=None,
115
+ help='Path to .npz file where features will be saved as a compressed numpy archive')
116
+ parser.add_argument('--save_frequency', type=int, default=10000,
117
+ help='Frequency with which to save the features')
118
+ parser.add_argument('--restart', action='store_true', default=False,
119
+ help='Whether to not load partially complete featurization and instead start from scratch')
120
+ parser.add_argument('--max_data_size', type=int,
121
+ help='Maximum number of data points to load')
122
+ parser.add_argument('--sequential', action='store_true', default=False,
123
+ help='Whether to run the featurization sequentially rather than in parallel')
124
+ args = parser.parse_args()
125
+ if args.save_path is None:
126
+ args.save_path = args.data_path.split('csv')[0] + 'npz'
127
+ generate_and_save_features(args)
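
For reference, a hedged sketch of invoking the featurization script above from Python. The generator name 'rdkit_2d_normalized' is an assumption (the valid names come from get_available_features_generators()), and the Tox21 paths mirror the ones used by generate_features.py:

    import subprocess

    subprocess.run([
        "python", "scripts/save_features.py",
        "--data_path", "./tox21/tox21_train_clean.csv",
        "--features_generator", "rdkit_2d_normalized",   # assumed generator name
        "--save_path", "./tox21/tox21_train_clean.npz",
        "--restart",
    ], check=True)
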
scripts/split_data.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The data splitting script for pretraining.
3
+ """
4
+ import os
5
+ from argparse import ArgumentParser
6
+ import csv
7
+ import shutil
8
+ import numpy as np
9
+
10
+
11
+ import grover.util.utils as fea_utils
12
+
13
+
14
+ parser = ArgumentParser()
15
+ parser.add_argument("--data_path", default="../drug_data/grover_data/delaneyfreesolvlipo.csv")
16
+ parser.add_argument("--features_path", default="../drug_data/grover_data/delaneyfreesolvlipo_molbert.npz")
17
+ parser.add_argument("--sample_per_file", type=int, default=1000)
18
+ parser.add_argument("--output_path", default="../drug_data/grover_data/delaneyfreesolvlipo")
19
+
20
+
21
+ def load_smiles(data_path):
22
+ with open(data_path) as f:
23
+ reader = csv.reader(f)
24
+ header = next(reader)
25
+ res = []
26
+ for line in reader:
27
+ res.append(line)
28
+ return res, header
29
+
30
+
31
+ def load_features(data_path):
32
+ fea = fea_utils.load_features(data_path)
33
+ return fea
34
+
35
+
36
+ def save_smiles(data_path, index, data, header):
37
+ fn = os.path.join(data_path, str(index) + ".csv")
38
+ with open(fn, "w") as f:
39
+ fw = csv.writer(f)
40
+ fw.writerow(header)
41
+ for d in data:
42
+ fw.writerow(d)
43
+
44
+
45
+ def save_features(data_path, index, data):
46
+ fn = os.path.join(data_path, str(index) + ".npz")
47
+ np.savez_compressed(fn, features=data)
48
+
49
+
50
+ def run():
51
+ args = parser.parse_args()
52
+ res, header = load_smiles(data_path=args.data_path)
53
+ fea = load_features(data_path=args.features_path)
54
+ assert len(res) == fea.shape[0]
55
+
56
+ n_graphs = len(res)
57
+ perm = np.random.permutation(n_graphs)
58
+
59
+ nfold = int(n_graphs / args.sample_per_file + 1)
60
+ print("Number of files: %d" % nfold)
61
+ if os.path.exists(args.output_path):
62
+ shutil.rmtree(args.output_path)
63
+ os.makedirs(args.output_path, exist_ok=True)
64
+ graph_path = os.path.join(args.output_path, "graph")
65
+ fea_path = os.path.join(args.output_path, "feature")
66
+ os.makedirs(graph_path, exist_ok=True)
67
+ os.makedirs(fea_path, exist_ok=True)
68
+
69
+ for i in range(nfold):
70
+ sidx = i * args.sample_per_file
71
+ eidx = min((i + 1) * args.sample_per_file, n_graphs)
72
+ indexes = perm[sidx:eidx]
73
+ sres = [res[j] for j in indexes]
74
+ sfea = fea[indexes]
75
+ save_smiles(graph_path, i, sres, header)
76
+ save_features(fea_path, i, sfea)
77
+
78
+ summary_path = os.path.join(args.output_path, "summary.txt")
79
+ summary_fout = open(summary_path, 'w')
80
+ summary_fout.write("n_files:%d\n" % nfold)
81
+ summary_fout.write("n_samples:%d\n" % n_graphs)
82
+ summary_fout.write("sample_per_file:%d\n" % args.sample_per_file)
83
+ summary_fout.close()
84
+
85
+
86
+ if __name__ == "__main__":
87
+ run()
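
A sketch of how the pretraining split script above would typically be invoked; all paths and the shard size are illustrative. The script writes graph/<i>.csv and feature/<i>.npz shards plus a summary.txt into the output folder:

    import subprocess

    subprocess.run([
        "python", "scripts/split_data.py",
        "--data_path", "./pretrain/molecules.csv",       # CSV with a header row and SMILES
        "--features_path", "./pretrain/molecules.npz",   # features aligned row-for-row with the CSV
        "--sample_per_file", "1000",
        "--output_path", "./pretrain/molecules_split",
    ], check=True)
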
src/commands.py CHANGED
@@ -16,6 +16,16 @@ def generate_features(data_path, save_path):
16
  f"--restart"
17
  )
18
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  def finetune(train_path, val_path, train_features_path, val_features_path,
21
  save_dir, checkpoint_path, args
 
16
  f"--restart"
17
  )
18
 
19
+ def predict_from_csv(data_path, features_path, checkpoint_dir, output_path):
20
+ predict_cmd = (
21
+ f"python main.py predict "
22
+ f"--data_path {data_path} "
23
+ f"--features_path {features_path} "
24
+ f"--checkpoint_dir {checkpoint_dir} "
25
+ f"--no_features_scaling "
26
+ f"--output {output_path}"
27
+ )
28
+
29
 
30
  def finetune(train_path, val_path, train_features_path, val_features_path,
31
  save_dir, checkpoint_path, args
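
Note that predict_from_csv above only assembles the command string; the diff does not show it being executed or returned. A hedged sketch of how it would presumably be used, with the os.system call and the Tox21 paths being assumptions:

    import os

    predict_cmd = (
        "python main.py predict "
        "--data_path ./tox21/tox21_validation_clean.csv "
        "--features_path ./tox21/tox21_validation_clean.npz "
        "--checkpoint_dir ./model "
        "--no_features_scaling "
        "--output ./tox21/predictions.csv"
    )
    os.system(predict_cmd)  # assumed execution step; predict_from_csv itself does not run the command
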
task/__init__.py ADDED
File without changes
task/cross_validate.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The cross validation function for finetuning.
3
+ This implementation is adapted from
4
+ https://github.com/chemprop/chemprop/blob/master/chemprop/train/cross_validate.py
5
+ """
6
+ import os
7
+ import time
8
+ from argparse import Namespace
9
+ from logging import Logger
10
+ from typing import Tuple
11
+
12
+ import numpy as np
13
+
14
+ from grover.util.utils import get_task_names
15
+ from grover.util.utils import makedirs
16
+ from task.run_evaluation import run_evaluation
17
+ from task.train import run_training
18
+
19
+
20
+ def cross_validate(args: Namespace, logger: Logger = None) -> Tuple[float, float]:
21
+ """
22
+ k-fold cross validation.
23
+
24
+ :return: A tuple of mean_score and std_score.
25
+ """
26
+ info = logger.info if logger is not None else print
27
+
28
+ # Initialize relevant variables
29
+ init_seed = args.seed
30
+ save_dir = args.save_dir
31
+ task_names = get_task_names(args.data_path)
32
+
33
+ # Run training with different random seeds for each fold
34
+ all_scores = []
35
+ time_start = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
36
+ for fold_num in range(args.num_folds):
37
+ info(f'Fold {fold_num}')
38
+ args.seed = init_seed + fold_num
39
+ args.save_dir = os.path.join(save_dir, f'fold_{fold_num}')
40
+ makedirs(args.save_dir)
41
+ if args.parser_name == "finetune":
42
+ model_scores = run_training(args, time_start, logger)
43
+ else:
44
+ model_scores = run_evaluation(args, logger)
45
+ all_scores.append(model_scores)
46
+ all_scores = np.array(all_scores)
47
+
48
+ # Report scores for each fold
49
+ info(f'{args.num_folds}-fold cross validation')
50
+
51
+ for fold_num, scores in enumerate(all_scores):
52
+ info(f'Seed {init_seed + fold_num} ==> test {args.metric} = {np.nanmean(scores):.6f}')
53
+
54
+ if args.show_individual_scores:
55
+ for task_name, score in zip(task_names, scores):
56
+ info(f'Seed {init_seed + fold_num} ==> test {task_name} {args.metric} = {score:.6f}')
57
+
58
+ # Report scores across models
59
+ avg_scores = np.nanmean(all_scores, axis=1) # average score for each model across tasks
60
+ mean_score, std_score = np.nanmean(avg_scores), np.nanstd(avg_scores)
61
+ info(f'overall_{args.split_type}_test_{args.metric}={mean_score:.6f}')
62
+ info(f'std={std_score:.6f}')
63
+
64
+ if args.show_individual_scores:
65
+ for task_num, task_name in enumerate(task_names):
66
+ info(f'Overall test {task_name} {args.metric} = '
67
+ f'{np.nanmean(all_scores[:, task_num]):.6f} +/- {np.nanstd(all_scores[:, task_num]):.6f}')
68
+
69
+ return mean_score, std_score
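
A usage sketch for cross_validate. The Namespace below lists only the fields this function reads directly (the metric name "auc" is an assumption); run_training / run_evaluation expect the full finetuning configuration on top of these:

    from argparse import Namespace

    from task.cross_validate import cross_validate

    args = Namespace(parser_name="finetune",
                     data_path="./tox21/tox21_train_clean.csv",
                     save_dir="./finetune_runs",
                     num_folds=3, seed=0,
                     metric="auc", split_type="scaffold_balanced",
                     show_individual_scores=False)
    mean_score, std_score = cross_validate(args)
    print(f"{args.metric}: {mean_score:.4f} +/- {std_score:.4f}")
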
task/fingerprint.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The fingerprint generation function.
3
+ """
4
+ from argparse import Namespace
5
+ from logging import Logger
6
+ from typing import List
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import DataLoader
11
+
12
+ from grover.data import MolCollator
13
+ from grover.data import MoleculeDataset
14
+ from grover.util.utils import get_data, create_logger, load_checkpoint
15
+
16
+
17
+ def do_generate(model: nn.Module,
18
+ data: MoleculeDataset,
19
+ args: Namespace,
20
+ ) -> List[List[float]]:
21
+ """
22
+ Do the fingerprint generation on a dataset using the pre-trained models.
23
+
24
+ :param model: A model.
25
+ :param data: A MoleculeDataset.
26
+ :param args: Arguments.
27
+ :return: A list of fingerprints.
28
+ """
29
+ model.eval()
30
+ args.bond_drop_rate = 0
31
+ preds = []
32
+
33
+ mol_collator = MolCollator(args=args, shared_dict={})
34
+
35
+ num_workers = 4
36
+ mol_loader = DataLoader(data,
37
+ batch_size=32,
38
+ shuffle=False,
39
+ num_workers=num_workers,
40
+ collate_fn=mol_collator)
41
+ for item in mol_loader:
42
+ _, batch, features_batch, _, _ = item
43
+ with torch.no_grad():
44
+ batch_preds = model(batch, features_batch)
45
+ preds.extend(batch_preds.data.cpu().numpy())
46
+ return preds
47
+
48
+
49
+ def generate_fingerprints(args: Namespace, logger: Logger = None) -> List[List[float]]:
50
+ """
51
+ Generate the fingerprints.
52
+
53
+ :param logger:
54
+ :param args: Arguments.
55
+ :return: A list of lists of target fingerprints.
56
+ """
57
+
58
+ checkpoint_path = args.checkpoint_paths[0]
59
+ if logger is None:
60
+ logger = create_logger('fingerprints', quiet=False)
61
+ print('Loading data')
62
+ test_data = get_data(path=args.data_path,
63
+ args=args,
64
+ use_compound_names=False,
65
+ max_data_size=float("inf"),
66
+ skip_invalid_smiles=False)
67
+ test_data = MoleculeDataset(test_data)
68
+
69
+ logger.info(f'Total size = {len(test_data):,}')
70
+ logger.info(f'Generating...')
71
+ # Load model
72
+ model = load_checkpoint(checkpoint_path, cuda=args.cuda, current_args=args, logger=logger)
73
+ model_preds = do_generate(
74
+ model=model,
75
+ data=test_data,
76
+ args=args
77
+ )
78
+
79
+ return model_preds
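
A sketch of extracting GROVER fingerprints with the function above, assuming a pretrained checkpoint on disk. Only the fields used directly in this file are shown; get_data and load_checkpoint read further options from args and from the checkpoint itself:

    from argparse import Namespace

    from task.fingerprint import generate_fingerprints

    args = Namespace(parser_name="fingerprint",                     # selects GroverFpGeneration in build_model
                     checkpoint_paths=["./model/grover_large.pt"],  # assumed checkpoint location
                     data_path="./tox21/tox21_validation_clean.csv",
                     cuda=False)
    fingerprints = generate_fingerprints(args)                      # one embedding vector per molecule
    print(len(fingerprints), len(fingerprints[0]))
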
task/grovertrainer.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The GROVER trainer.
3
+ """
4
+ import os
5
+ import time
6
+ from logging import Logger
7
+ from typing import List, Tuple
8
+ from collections.abc import Callable
9
+ import torch
10
+ from torch.nn import Module
11
+ from torch.utils.data import DataLoader
12
+
13
+ from grover.model.models import GroverTask
14
+ from grover.util.multi_gpu_wrapper import MultiGpuWrapper as mgw
15
+
16
+
17
+ class GROVERTrainer:
18
+ def __init__(self,
19
+ args,
20
+ embedding_model: Module,
21
+ atom_vocab_size: int, # atom vocab size
22
+ bond_vocab_size: int,
23
+ fg_szie: int,
24
+ train_dataloader: DataLoader,
25
+ test_dataloader: DataLoader,
26
+ optimizer_builder: Callable,
27
+ scheduler_builder: Callable,
28
+ logger: Logger = None,
29
+ with_cuda: bool = False,
30
+ enable_multi_gpu: bool = False):
31
+ """
32
+ The init function of GROVERTrainer
33
+ :param args: the input arguments.
34
+ :param embedding_model: the model to generate atom/bond embeddings.
35
+ :param atom_vocab_size: the vocabulary size of atoms.
36
+ :param bond_vocab_size: the vocabulary size of bonds.
37
+ :param fg_szie: the size of semantic motifs (functional groups)
38
+ :param train_dataloader: the data loader of train data.
39
+ :param test_dataloader: the data loader of validation data.
40
+ :param optimizer_builder: the function of building the optimizer.
41
+ :param scheduler_builder: the function of building the scheduler.
42
+ :param logger: the logger
43
+ :param with_cuda: enable gpu training.
44
+ :param enable_multi_gpu: enable multi-GPU training.
45
+ """
46
+
47
+ self.args = args
48
+ self.with_cuda = with_cuda
49
+ self.grover = embedding_model
50
+ self.model = GroverTask(args, embedding_model, atom_vocab_size, bond_vocab_size, fg_szie)
51
+ self.loss_func = self.model.get_loss_func(args)
52
+ self.enable_multi_gpu = enable_multi_gpu
53
+
54
+ self.atom_vocab_size = atom_vocab_size
55
+ self.bond_vocab_size = bond_vocab_size
56
+ self.debug = logger.debug if logger is not None else print
57
+
58
+ if self.with_cuda:
59
+ # print("Using %d GPUs for training." % (torch.cuda.device_count()))
60
+ self.model = self.model.cuda()
61
+
62
+ self.train_data = train_dataloader
63
+ self.test_data = test_dataloader
64
+
65
+ self.optimizer = optimizer_builder(self.model, self.args)
66
+ self.scheduler = scheduler_builder(self.optimizer, self.args)
67
+ if self.enable_multi_gpu:
68
+ self.optimizer = mgw.DistributedOptimizer(self.optimizer,
69
+ named_parameters=self.model.named_parameters())
70
+ self.args = args
71
+ self.n_iter = 0
72
+
73
+ def broadcast_parameters(self) -> None:
74
+ """
75
+ Broadcast parameters before training.
76
+ :return: no return.
77
+ """
78
+ if self.enable_multi_gpu:
79
+ # broadcast parameters & optimizer state.
80
+ mgw.broadcast_parameters(self.model.state_dict(), root_rank=0)
81
+ mgw.broadcast_optimizer_state(self.optimizer, root_rank=0)
82
+
83
+ def train(self, epoch: int) -> List:
84
+ """
85
+ The training iteration
86
+ :param epoch: the current epoch number.
87
+ :return: the loss terms of current epoch.
88
+ """
89
+ # return self.mock_iter(epoch, self.train_data, train=True)
90
+ return self.iter(epoch, self.train_data, train=True)
91
+
92
+ def test(self, epoch: int) -> List:
93
+ """
94
+ The test/validation iteration.
95
+ :param epoch: the current epoch number.
96
+ :return: the loss terms as a list
97
+ """
98
+ # return self.mock_iter(epoch, self.test_data, train=False)
99
+ return self.iter(epoch, self.test_data, train=False)
100
+
101
+ def mock_iter(self, epoch: int, data_loader: DataLoader, train: bool = True) -> List:
102
+ """
103
+ Perform a mock iteration. For test only.
104
+ :param epoch: the current epoch number.
105
+ :param data_loader: the data loader.
106
+ :param train: True for a training pass, False for a validation pass.
107
+ :return: the loss terms as a list
108
+ """
109
+
110
+ for _, _ in enumerate(data_loader):
111
+ self.scheduler.step()
112
+ cum_loss_sum = 0.0
113
+ self.n_iter += self.args.batch_size
114
+ return self.n_iter, cum_loss_sum, (0, 0, 0, 0, 0, 0)
115
+
116
+ def iter(self, epoch, data_loader, train=True) -> List:
117
+ """
118
+ Perform a training / validation iteration.
119
+ :param epoch: the current epoch number.
120
+ :param data_loader: the data loader.
121
+ :param train: True for a training pass, False for a validation pass.
122
+ :return: the loss terms as a list
123
+ """
124
+
125
+ if train:
126
+ self.model.train()
127
+ else:
128
+ self.model.eval()
129
+
130
+ loss_sum, iter_count = 0, 0
131
+ cum_loss_sum, cum_iter_count = 0, 0
132
+ av_loss_sum, bv_loss_sum, fg_loss_sum, av_dist_loss_sum, bv_dist_loss_sum, fg_dist_loss_sum = 0, 0, 0, 0, 0, 0
133
+ # loss_func = self.model.get_loss_func(self.args)
134
+
135
+ for _, item in enumerate(data_loader):
136
+ batch_graph = item["graph_input"]
137
+ targets = item["targets"]
138
+
139
+ if next(self.model.parameters()).is_cuda:
140
+ targets["av_task"] = targets["av_task"].cuda()
141
+ targets["bv_task"] = targets["bv_task"].cuda()
142
+ targets["fg_task"] = targets["fg_task"].cuda()
143
+
144
+ preds = self.model(batch_graph)
145
+
146
+ # # ad-hoc code, for visualizing a model, comment this block when it is not needed
147
+ # import dglt.contrib.grover.vis_model as vis_model
148
+ # for task in ['av_task', 'bv_task', 'fg_task']:
149
+ # vis_graph = vis_model.make_dot(self.model(batch_graph)[task],
150
+ # params=dict(self.model.named_parameters()))
151
+ # # vis_graph.view()
152
+ # vis_graph.render(f"{self.args.backbone}_model_{task}_vis.png", format="png")
153
+ # exit()
154
+
155
+ loss, av_loss, bv_loss, fg_loss, av_dist_loss, bv_dist_loss, fg_dist_loss = self.loss_func(preds, targets)
156
+
157
+ loss_sum += loss.item()
158
+ iter_count += self.args.batch_size
159
+
160
+ if train:
161
+ cum_loss_sum += loss.item()
162
+ # Run model
163
+ self.model.zero_grad()
164
+ self.optimizer.zero_grad()
165
+ loss.backward()
166
+ self.optimizer.step()
167
+ self.scheduler.step()
168
+ else:
169
+ # For the eval pass, only accumulate the losses of the three pretraining tasks.
170
+ cum_loss_sum += av_loss.item()
171
+ cum_loss_sum += bv_loss.item()
172
+ cum_loss_sum += fg_loss.item()
173
+
174
+ av_loss_sum += av_loss.item()
175
+ bv_loss_sum += bv_loss.item()
176
+ fg_loss_sum += fg_loss.item()
177
+ av_dist_loss_sum += av_dist_loss.item() if type(av_dist_loss) != float else av_dist_loss
178
+ bv_dist_loss_sum += bv_dist_loss.item() if type(bv_dist_loss) != float else bv_dist_loss
179
+ fg_dist_loss_sum += fg_dist_loss.item() if type(fg_dist_loss) != float else fg_dist_loss
180
+
181
+ cum_iter_count += 1
182
+ self.n_iter += self.args.batch_size
183
+
184
+ # Debug only.
185
+ # if i % 50 == 0:
186
+ # print(f"epoch: {epoch}, batch_id: {i}, av_loss: {av_loss}, bv_loss: {bv_loss}, "
187
+ # f"fg_loss: {fg_loss}, av_dist_loss: {av_dist_loss}, bv_dist_loss: {bv_dist_loss}, "
188
+ # f"fg_dist_loss: {fg_dist_loss}")
189
+
190
+ cum_loss_sum /= cum_iter_count
191
+ av_loss_sum /= cum_iter_count
192
+ bv_loss_sum /= cum_iter_count
193
+ fg_loss_sum /= cum_iter_count
194
+ av_dist_loss_sum /= cum_iter_count
195
+ bv_dist_loss_sum /= cum_iter_count
196
+ fg_dist_loss_sum /= cum_iter_count
197
+
198
+ return self.n_iter, cum_loss_sum, (av_loss_sum, bv_loss_sum, fg_loss_sum, av_dist_loss_sum,
199
+ bv_dist_loss_sum, fg_dist_loss_sum)
200
+
201
+ def save(self, epoch, file_path, name=None) -> str:
202
+ """
203
+ Save the intermediate models during training.
204
+ :param epoch: the epoch number.
205
+ :param file_path: the file_path to save the model.
206
+ :return: the output path.
207
+ """
208
+ # add a timestamp to the model file name in order to distinguish different saved models
209
+ now = time.localtime()
210
+ if name is None:
211
+ name = "_%04d_%02d_%02d_%02d_%02d_%02d" % (
212
+ now.tm_year, now.tm_mon, now.tm_mday, now.tm_hour, now.tm_min, now.tm_sec)
213
+ output_path = file_path + name + ".ep%d" % epoch
214
+ scaler = None
215
+ features_scaler = None
216
+ state = {
217
+ 'args': self.args,
218
+ 'state_dict': self.model.state_dict(),
219
+ 'optimizer': self.optimizer.state_dict(),
220
+ 'scheduler_step': self.scheduler.current_step,
221
+ "epoch": epoch,
222
+ 'data_scaler': {
223
+ 'means': scaler.means,
224
+ 'stds': scaler.stds
225
+ } if scaler is not None else None,
226
+ 'features_scaler': {
227
+ 'means': features_scaler.means,
228
+ 'stds': features_scaler.stds
229
+ } if features_scaler is not None else None
230
+ }
231
+ torch.save(state, output_path)
232
+
233
+ # Is this necessary?
234
+ # if self.with_cuda:
235
+ # self.model = self.model.cuda()
236
+ print("EP:%d Model Saved on:" % epoch, output_path)
237
+ return output_path
238
+
239
+ def save_tmp(self, epoch, file_path, rank=0):
240
+ """
241
+ Save the models for auto-restore during training.
242
+ The model is stored in the file_path/tmp folder and is replaced on each epoch.
243
+ :param epoch: the epoch number.
244
+ :param file_path: the file_path to store the model.
245
+ :param rank: the current rank (deprecated).
246
+ :return:
247
+ """
248
+ store_path = os.path.join(file_path, "tmp")
249
+ if not os.path.exists(store_path):
250
+ os.makedirs(store_path, exist_ok=True)
251
+ store_path = os.path.join(store_path, "model.%d" % rank)
252
+ state = {
253
+ 'args': self.args,
254
+ 'state_dict': self.model.state_dict(),
255
+ 'optimizer': self.optimizer.state_dict(),
256
+ 'scheduler_step': self.scheduler.current_step,
257
+ "epoch": epoch
258
+ }
259
+ torch.save(state, store_path)
260
+
261
+ def restore(self, file_path, rank=0) -> Tuple[int, int]:
262
+ """
263
+ Restore the training state saved by save_tmp.
264
+ :param file_path: the file_path to store the model.
265
+ :param rank: the current rank (deprecated).
266
+ :return: the restored epoch number and the scheduler_step in scheduler.
267
+ """
268
+ cpt_path = os.path.join(file_path, "tmp", "model.%d" % rank)
269
+ if not os.path.exists(cpt_path):
270
+ print("No checkpoint found %d")
271
+ return 0, 0
272
+ cpt = torch.load(cpt_path)
273
+ self.model.load_state_dict(cpt["state_dict"])
274
+ self.optimizer.load_state_dict(cpt["optimizer"])
275
+ epoch = cpt["epoch"]
276
+ scheduler_step = cpt["scheduler_step"]
277
+ self.scheduler.current_step = scheduler_step
278
+ print("Restore checkpoint, current epoch: %d" % (epoch))
279
+ return epoch, scheduler_step
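
A hedged sketch of the epoch loop around GROVERTrainer; the embedding model, vocabularies and data loaders are assumed to be built elsewhere (see task/pretrain.py), and only methods defined above are used:

    import os

    import torch

    from grover.util.utils import build_lr_scheduler, build_optimizer
    from task.grovertrainer import GROVERTrainer

    def pretrain_loop(args, grover_model, atom_vocab, bond_vocab, fg_size, train_loader, val_loader):
        trainer = GROVERTrainer(args=args, embedding_model=grover_model,
                                atom_vocab_size=len(atom_vocab), bond_vocab_size=len(bond_vocab),
                                fg_szie=fg_size,               # parameter name as defined above
                                train_dataloader=train_loader, test_dataloader=val_loader,
                                optimizer_builder=build_optimizer, scheduler_builder=build_lr_scheduler,
                                with_cuda=torch.cuda.is_available())
        trainer.broadcast_parameters()
        for epoch in range(args.epochs):
            _, train_loss, _ = trainer.train(epoch)
            _, val_loss, _ = trainer.test(epoch)
            trainer.save_tmp(epoch, args.save_dir)             # rolling restore point, overwritten each epoch
        return trainer.save(args.epochs, os.path.join(args.save_dir, "model"))
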
task/predict.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The predict function, which uses the finetuned model to make predictions.
3
+ """
4
+ from argparse import Namespace
5
+ from typing import List
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch.utils.data import DataLoader
12
+
13
+ from grover.data import MolCollator
14
+ from grover.data import MoleculeDataset
15
+ from grover.data import StandardScaler
16
+ from grover.util.utils import get_data, get_data_from_smiles, create_logger, load_args, get_task_names, tqdm, \
17
+ load_checkpoint, load_scalars
18
+
19
+
20
+ def predict(model: nn.Module,
21
+ data: MoleculeDataset,
22
+ args: Namespace,
23
+ batch_size: int,
24
+ loss_func,
25
+ logger,
26
+ shared_dict,
27
+ scaler: StandardScaler = None
28
+ ) -> List[List[float]]:
29
+ """
30
+ Makes predictions on a dataset using an ensemble of models.
31
+
32
+ :param model: A model.
33
+ :param data: A MoleculeDataset.
34
+ :param batch_size: Batch size.
35
+ :param scaler: A StandardScaler object fit on the training targets.
36
+ :return: A list of lists of predictions. The outer list is examples
37
+ while the inner list is tasks.
38
+ """
39
+ # debug = logger.debug if logger is not None else print
40
+ model.eval()
41
+ args.bond_drop_rate = 0
42
+ preds = []
43
+
44
+ # num_iters, iter_step = len(data), batch_size
45
+ loss_sum, iter_count = 0, 0
46
+
47
+ mol_collator = MolCollator(args=args, shared_dict=shared_dict)
48
+ # mol_dataset = MoleculeDataset(data)
49
+
50
+ num_workers = 4
51
+ mol_loader = DataLoader(data, batch_size=batch_size, shuffle=False, num_workers=num_workers,
52
+ collate_fn=mol_collator)
53
+ for _, item in enumerate(mol_loader):
54
+ _, batch, features_batch, mask, targets = item
55
+ class_weights = torch.ones(targets.shape)
56
+ if next(model.parameters()).is_cuda:
57
+ targets = targets.cuda()
58
+ mask = mask.cuda()
59
+ class_weights = class_weights.cuda()
60
+ with torch.no_grad():
61
+ batch_preds = model(batch, features_batch)
62
+ iter_count += 1
63
+ if args.fingerprint:
64
+ preds.extend(batch_preds.data.cpu().numpy())
65
+ continue
66
+
67
+ if loss_func is not None:
68
+ loss = loss_func(batch_preds, targets) * class_weights * mask
69
+ loss = loss.sum() / mask.sum()
70
+ loss_sum += loss.item()
71
+ # Collect vectors
72
+ batch_preds = batch_preds.data.cpu().numpy().tolist()
73
+ if scaler is not None:
74
+ batch_preds = scaler.inverse_transform(batch_preds)
75
+ preds.extend(batch_preds)
76
+
77
+ loss_avg = loss_sum / iter_count
78
+ return preds, loss_avg
79
+
80
+
81
+ def make_predictions(args: Namespace, newest_train_args=None, smiles: List[str] = None):
82
+ """
83
+ Makes predictions. If smiles is provided, makes predictions on smiles.
84
+ Otherwise makes predictions on args.test_data.
85
+
86
+ :param args: Arguments.
87
+ :param smiles: Smiles to make predictions on.
88
+ :return: A list of lists of target predictions.
89
+ """
90
+ if args.gpu is not None:
91
+ torch.cuda.set_device(args.gpu)
92
+
93
+ print('Loading training args')
94
+
95
+ path = args.checkpoint_paths[0]
96
+ scaler, features_scaler = load_scalars(path)
97
+ train_args = load_args(path)
98
+
99
+ # Update args with training arguments saved in checkpoint
100
+ for key, value in vars(train_args).items():
101
+ if not hasattr(args, key):
102
+ setattr(args, key, value)
103
+
104
+ # update args with newest training args
105
+ if newest_train_args is not None:
106
+ for key, value in vars(newest_train_args).items():
107
+ if not hasattr(args, key):
108
+ setattr(args, key, value)
109
+
110
+
111
+ # deal with multiprocess problem
112
+ args.debug = True
113
+
114
+ logger = create_logger('predict', quiet=False)
115
+ print('Loading data')
116
+ args.task_names = get_task_names(args.data_path)
117
+ if smiles is not None:
118
+ test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False)
119
+ else:
120
+ test_data = get_data(path=args.data_path, args=args,
121
+ use_compound_names=args.use_compound_names, skip_invalid_smiles=False)
122
+
123
+
124
+ args.num_tasks = test_data.num_tasks()
125
+ args.features_size = test_data.features_size()
126
+
127
+ print('Validating SMILES')
128
+ valid_indices = [i for i in range(len(test_data))]
129
+ full_data = test_data
130
+ # test_data = MoleculeDataset([test_data[i] for i in valid_indices])
131
+ test_data_list = []
132
+ for i in valid_indices:
133
+ test_data_list.append(test_data[i])
134
+ test_data = MoleculeDataset(test_data_list)
135
+
136
+ # Edge case if empty list of smiles is provided
137
+ if len(test_data) == 0:
138
+ return [None] * len(full_data)
139
+
140
+ print(f'Test size = {len(test_data):,}')
141
+
142
+ # Normalize features
143
+ if hasattr(train_args, 'features_scaling'):
144
+ if train_args.features_scaling:
145
+ test_data.normalize_features(features_scaler)
146
+
147
+ # Predict with each model individually and sum predictions
148
+ if hasattr(args, 'num_tasks'):
149
+ sum_preds = np.zeros((len(test_data), args.num_tasks))
150
+ print(f'Predicting...')
151
+ shared_dict = {}
152
+ # loss_func = torch.nn.BCEWithLogitsLoss()
153
+ count = 0
154
+ for checkpoint_path in tqdm(args.checkpoint_paths, total=len(args.checkpoint_paths)):
155
+ # Load model
156
+ model = load_checkpoint(checkpoint_path, cuda=args.cuda, current_args=args, logger=logger)
157
+ model_preds, _ = predict(
158
+ model=model,
159
+ data=test_data,
160
+ batch_size=args.batch_size,
161
+ scaler=scaler,
162
+ shared_dict=shared_dict,
163
+ args=args,
164
+ logger=logger,
165
+ loss_func=None
166
+ )
167
+
168
+ if args.fingerprint:
169
+ return model_preds
170
+
171
+ sum_preds += np.array(model_preds, dtype=float)
172
+ count += 1
173
+
174
+ # Ensemble predictions
175
+ avg_preds = sum_preds / len(args.checkpoint_paths)
176
+
177
+ # Save predictions
178
+ assert len(test_data) == len(avg_preds)
179
+
180
+ # Put Nones for invalid smiles
181
+ args.valid_indices = valid_indices
182
+ avg_preds = np.array(avg_preds)
183
+ test_smiles = full_data.smiles()
184
+ return avg_preds, test_smiles
185
+
186
+
187
+ def write_prediction(avg_preds, test_smiles, args):
188
+ """
189
+ write prediction to disk
190
+ :param avg_preds: prediction value
191
+ :param test_smiles: input smiles
192
+ :param args: Arguments
193
+ """
194
+ if args.dataset_type == 'multiclass':
195
+ avg_preds = np.argmax(avg_preds, -1)
196
+ full_preds = [[None]] * len(test_smiles)
197
+ for i, si in enumerate(args.valid_indices):
198
+ full_preds[si] = avg_preds[i]
199
+ result = pd.DataFrame(data=full_preds, index=test_smiles, columns=args.task_names)
200
+ result.to_csv(args.output_path)
201
+ print(f'Saving predictions to {args.output_path}')
202
+
203
+
204
+
205
+ def evaluate_predictions(preds: List[List[float]],
206
+ targets: List[List[float]],
207
+ num_tasks: int,
208
+ metric_func,
209
+ dataset_type: str,
210
+ logger = None) -> List[float]:
211
+ """
212
+ Evaluates predictions using a metric function and filtering out invalid targets.
213
+
214
+ :param preds: A list of lists of shape (data_size, num_tasks) with model predictions.
215
+ :param targets: A list of lists of shape (data_size, num_tasks) with targets.
216
+ :param num_tasks: Number of tasks.
217
+ :param metric_func: Metric function which takes in a list of targets and a list of predictions.
218
+ :param dataset_type: Dataset type.
219
+ :param logger: Logger.
220
+ :return: A list with the score for each task based on `metric_func`.
221
+ """
222
+ if dataset_type == 'multiclass':
223
+ results = metric_func(np.argmax(preds, -1), [i[0] for i in targets])
224
+ return [results]
225
+
226
+ # info = logger.info if logger is not None else print
227
+
228
+ if len(preds) == 0:
229
+ return [float('nan')] * num_tasks
230
+
231
+ # Filter out empty targets
232
+ # valid_preds and valid_targets have shape (num_tasks, data_size)
233
+ valid_preds = [[] for _ in range(num_tasks)]
234
+ valid_targets = [[] for _ in range(num_tasks)]
235
+ for i in range(num_tasks):
236
+ for j in range(len(preds)):
237
+ if targets[j][i] is not None: # Skip those without targets
238
+ valid_preds[i].append(preds[j][i])
239
+ valid_targets[i].append(targets[j][i])
240
+
241
+ # Compute metric
242
+ results = []
243
+ for i in range(num_tasks):
244
+ # # Skip if all targets or preds are identical, otherwise we'll crash during classification
245
+ if dataset_type == 'classification':
246
+ nan = False
247
+ if all(target == 0 for target in valid_targets[i]) or all(target == 1 for target in valid_targets[i]):
248
+ nan = True
249
+ # info('Warning: Found a task with targets all 0s or all 1s')
250
+ if all(pred == 0 for pred in valid_preds[i]) or all(pred == 1 for pred in valid_preds[i]):
251
+ nan = True
252
+ # info('Warning: Found a task with predictions all 0s or all 1s')
253
+
254
+ if nan:
255
+ results.append(float('nan'))
256
+ continue
257
+
258
+ if len(valid_targets[i]) == 0:
259
+ continue
260
+
261
+ results.append(metric_func(valid_targets[i], valid_preds[i]))
262
+
263
+ return results
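
A toy illustration of the per-task filtering that evaluate_predictions performs (standalone sketch, with sklearn's roc_auc_score standing in for metric_func): entries whose target is None are dropped per task before the metric is computed.

from sklearn.metrics import roc_auc_score

preds = [[0.9, 0.2], [0.1, 0.8], [0.7, 0.6]]
targets = [[1, 0], [0, None], [1, 1]]   # task 1 has one missing label
num_tasks = 2

scores = []
for t in range(num_tasks):
    valid = [(p[t], y[t]) for p, y in zip(preds, targets) if y[t] is not None]
    task_preds, task_targets = zip(*valid)
    scores.append(roc_auc_score(task_targets, task_preds))
print(scores)
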
264
+
265
+
266
+ def evaluate(model: nn.Module,
267
+ data: MoleculeDataset,
268
+ num_tasks: int,
269
+ metric_func,
270
+ loss_func,
271
+ batch_size: int,
272
+ dataset_type: str,
273
+ args: Namespace,
274
+ shared_dict,
275
+ scaler: StandardScaler = None,
276
+ logger = None) -> List[float]:
277
+ """
278
+ Evaluates a model on a dataset.
279
+
280
+ :param model: A model.
281
+ :param data: A MoleculeDataset.
282
+ :param num_tasks: Number of tasks.
283
+ :param metric_func: Metric function which takes in a list of targets and a list of predictions.
284
+ :param batch_size: Batch size.
285
+ :param dataset_type: Dataset type.
286
+ :param scaler: A StandardScaler object fit on the training targets.
287
+ :param logger: Logger.
288
+ :return: A list with the score for each task based on `metric_func`.
289
+ """
290
+ preds, loss_avg = predict(
291
+ model=model,
292
+ data=data,
293
+ loss_func=loss_func,
294
+ batch_size=batch_size,
295
+ scaler=scaler,
296
+ shared_dict=shared_dict,
297
+ logger=logger,
298
+ args=args
299
+ )
300
+
301
+ targets = data.targets()
302
+ if scaler is not None:
303
+ targets = scaler.inverse_transform(targets)
304
+
305
+
306
+
307
+ results = evaluate_predictions(
308
+ preds=preds,
309
+ targets=targets,
310
+ num_tasks=num_tasks,
311
+ metric_func=metric_func,
312
+ dataset_type=dataset_type,
313
+ logger=logger
314
+ )
315
+
316
+ return results, loss_avg
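
The scaler handling above follows the usual regression convention: targets are standardized for training and un-standardized before scoring. A minimal sketch, using sklearn's StandardScaler as a stand-in for grover.data.scaler.StandardScaler:

import numpy as np
from sklearn.preprocessing import StandardScaler

train_targets = np.array([[1.0], [2.0], [3.0]])
scaler = StandardScaler().fit(train_targets)

scaled = scaler.transform(train_targets)       # what the model is trained against
recovered = scaler.inverse_transform(scaled)   # what the metric is computed on
print(np.allclose(recovered, train_targets))   # True
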
task/pretrain.py ADDED
@@ -0,0 +1,241 @@
1
+ """
2
+ The GROVER pretrain function.
3
+ """
4
+ import os
5
+ import time
6
+ from argparse import Namespace
7
+ from logging import Logger
8
+
9
+ import torch
10
+ from torch.utils.data import DataLoader
11
+
12
+ from grover.data.dist_sampler import DistributedSampler
13
+ from grover.data.groverdataset import get_data, split_data, GroverCollator, BatchMolDataset
14
+ from grover.data.torchvocab import MolVocab
15
+ from grover.model.models import GROVEREmbedding
16
+ from grover.util.multi_gpu_wrapper import MultiGpuWrapper as mgw
17
+ from grover.util.nn_utils import param_count
18
+ from grover.util.utils import build_optimizer, build_lr_scheduler
19
+ from task.grovertrainer import GROVERTrainer
20
+
21
+
22
+ def pretrain_model(args: Namespace, logger: Logger = None):
23
+ """
24
+ The entry point of pretraining.
25
+ :param args: the arguments.
26
+ :param logger: the logger.
27
+ :return:
28
+ """
29
+
30
+ # avoid the unused import being auto-removed by PyCharm.
31
+ a = MolVocab
32
+ s_time = time.time()
33
+ run_training(args=args, logger=logger)
34
+ e_time = time.time()
35
+ print("Total Time: %.3f" % (e_time - s_time))
36
+
37
+
38
+ def pre_load_data(dataset: BatchMolDataset, rank: int, num_replicas: int, sample_per_file: int = None, epoch: int = 0):
39
+ """
40
+ Pre-load data at the beginning of each epoch.
41
+ :param dataset: the training dataset.
42
+ :param rank: the rank of the current worker.
43
+ :param num_replicas: the replicas.
44
+ :param sample_per_file: the number of the data points in each file. When sample_per_file is None, all data will be
45
+ loaded. It implies the testing phase. (TODO: bad design here.)
46
+ :param epoch: the epoch number.
47
+ :return:
48
+ """
49
+ mock_sampler = DistributedSampler(dataset, num_replicas=num_replicas, rank=rank, shuffle=False,
50
+ sample_per_file=sample_per_file)
51
+ mock_sampler.set_epoch(epoch)
52
+ pre_indices = mock_sampler.get_indices()
53
+ for i in pre_indices:
54
+ dataset.load_data(i)
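
A hypothetical sketch of what the per-rank index selection could look like (the actual partitioning lives in grover.data.dist_sampler.DistributedSampler.get_indices() and may differ): each worker only loads the shards assigned to its rank.

def rank_indices(num_shards, rank, num_replicas):
    # Round-robin assignment: rank r is responsible for shards r, r + num_replicas, ...
    return list(range(rank, num_shards, num_replicas))

print(rank_indices(10, rank=0, num_replicas=4))  # [0, 4, 8]
print(rank_indices(10, rank=1, num_replicas=4))  # [1, 5, 9]
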
55
+
56
+
57
+ def run_training(args, logger):
58
+ """
59
+ Run the pretrain task.
60
+ :param args:
61
+ :param logger:
62
+ :return:
63
+ """
64
+
65
+ # initialize the logger.
66
+ if logger is not None:
67
+ debug, _ = logger.debug, logger.info
68
+ else:
69
+ debug = print
70
+
71
+ # initialize the horovod library
72
+ if args.enable_multi_gpu:
73
+ mgw.init()
74
+
75
+ # binding training to GPUs.
76
+ master_worker = (mgw.rank() == 0) if args.enable_multi_gpu else True
77
+ # pin GPU to local rank. By default, we use gpu:0 for training.
78
+ local_gpu_idx = mgw.local_rank() if args.enable_multi_gpu else 0
79
+ with_cuda = args.cuda
80
+ if with_cuda:
81
+ torch.cuda.set_device(local_gpu_idx)
82
+
83
+ # get rank and number of workers
84
+ rank = mgw.rank() if args.enable_multi_gpu else 0
85
+ num_replicas = mgw.size() if args.enable_multi_gpu else 1
86
+ # print("Rank: %d Rep: %d" % (rank, num_replicas))
87
+
88
+ # load file paths of the data.
89
+ if master_worker:
90
+ print(args)
91
+ if args.enable_multi_gpu:
92
+ debug("Total workers: %d" % (mgw.size()))
93
+ debug('Loading data')
94
+ data, sample_per_file = get_data(data_path=args.data_path)
95
+
96
+ # data splitting
97
+ if master_worker:
98
+ debug(f'Splitting data with seed 0.')
99
+ train_data, test_data, _ = split_data(data=data, sizes=(0.9, 0.1, 0.0), seed=0, logger=logger)
100
+
101
+ # Here the true train data size is len(train_data) divided by the number of GPUs
102
+ if args.enable_multi_gpu:
103
+ args.train_data_size = len(train_data) // mgw.size()
104
+ else:
105
+ args.train_data_size = len(train_data)
106
+ if master_worker:
107
+ debug(f'Total size = {len(data):,} | '
108
+ f'train size = {len(train_data):,} | val size = {len(test_data):,}')
109
+
110
+ # load atom and bond vocabulary and the semantic motif labels.
111
+ atom_vocab = MolVocab.load_vocab(args.atom_vocab_path)
112
+ bond_vocab = MolVocab.load_vocab(args.bond_vocab_path)
113
+ atom_vocab_size, bond_vocab_size = len(atom_vocab), len(bond_vocab)
114
+
115
+ # Hard-coded here, since we haven't loaded any data yet!
116
+ fg_size = 85
117
+ shared_dict = {}
118
+ mol_collator = GroverCollator(shared_dict=shared_dict, atom_vocab=atom_vocab, bond_vocab=bond_vocab, args=args)
119
+ if master_worker:
120
+ debug("atom vocab size: %d, bond vocab size: %d, Number of FG tasks: %d" % (atom_vocab_size,
121
+ bond_vocab_size, fg_size))
122
+
123
+ # Define the distributed sampler. If using the single card, the sampler will be None.
124
+ train_sampler = None
125
+ test_sampler = None
126
+ shuffle = True
127
+ if args.enable_multi_gpu:
128
+ # Without shuffling, performance may degrade.
129
+ train_sampler = DistributedSampler(
130
+ train_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=True, sample_per_file=sample_per_file)
131
+ # Here sample_per_file in test_sampler is None, indicating the test sampler would not divide the test samples by
132
+ # rank. (TODO: bad design here.)
133
+ test_sampler = DistributedSampler(
134
+ test_data, num_replicas=mgw.size(), rank=mgw.rank(), shuffle=False)
135
+ train_sampler.set_epoch(args.epochs)
136
+ test_sampler.set_epoch(1)
137
+ # If multi-GPU training is enabled, shuffle should be disabled.
138
+ shuffle = False
139
+
140
+ # Pre-load data. (Maybe unnecessary.)
141
+ pre_load_data(train_data, rank, num_replicas, sample_per_file)
142
+ pre_load_data(test_data, rank, num_replicas)
143
+ if master_worker:
144
+ # print("Pre-loaded training data: %d" % train_data.count_loaded_datapoints())
145
+ print("Pre-loaded test data: %d" % test_data.count_loaded_datapoints())
146
+
147
+ # Build dataloader
148
+ train_data_dl = DataLoader(train_data,
149
+ batch_size=args.batch_size,
150
+ shuffle=shuffle,
151
+ num_workers=12,
152
+ sampler=train_sampler,
153
+ collate_fn=mol_collator)
154
+ test_data_dl = DataLoader(test_data,
155
+ batch_size=args.batch_size,
156
+ shuffle=shuffle,
157
+ num_workers=10,
158
+ sampler=test_sampler,
159
+ collate_fn=mol_collator)
160
+
161
+ # Build the embedding model.
162
+ grover_model = GROVEREmbedding(args)
163
+
164
+ # Build the trainer.
165
+ trainer = GROVERTrainer(args=args,
166
+ embedding_model=grover_model,
167
+ atom_vocab_size=atom_vocab_size,
168
+ bond_vocab_size=bond_vocab_size,
169
+ fg_szie=fg_size,
170
+ train_dataloader=train_data_dl,
171
+ test_dataloader=test_data_dl,
172
+ optimizer_builder=build_optimizer,
173
+ scheduler_builder=build_lr_scheduler,
174
+ logger=logger,
175
+ with_cuda=with_cuda,
176
+ enable_multi_gpu=args.enable_multi_gpu)
177
+
178
+ # Restore the interrupted training.
179
+ model_dir = os.path.join(args.save_dir, "model")
180
+ resume_from_epoch = 0
181
+ resume_scheduler_step = 0
182
+ if master_worker:
183
+ resume_from_epoch, resume_scheduler_step = trainer.restore(model_dir)
184
+ if args.enable_multi_gpu:
185
+ resume_from_epoch = mgw.broadcast(torch.tensor(resume_from_epoch), root_rank=0, name="resume_from_epoch").item()
186
+ resume_scheduler_step = mgw.broadcast(torch.tensor(resume_scheduler_step),
187
+ root_rank=0, name="resume_scheduler_step").item()
188
+ trainer.scheduler.current_step = resume_scheduler_step
189
+ print("Restored epoch: %d Restored scheduler step: %d" % (resume_from_epoch, trainer.scheduler.current_step))
190
+ trainer.broadcast_parameters()
191
+
192
+ # Print model details.
193
+ if master_worker:
194
+ # Change order here.
195
+ print(grover_model)
196
+ print("Total parameters: %d" % param_count(trainer.grover))
197
+
198
+ # Perform training.
199
+ for epoch in range(resume_from_epoch + 1, args.epochs):
200
+ s_time = time.time()
201
+
202
+ # Data pre-loading.
203
+ if args.enable_multi_gpu:
204
+ train_sampler.set_epoch(epoch)
205
+ train_data.clean_cache()
206
+ idxs = train_sampler.get_indices()
207
+ for local_gpu_idx in idxs:
208
+ train_data.load_data(local_gpu_idx)
209
+ d_time = time.time() - s_time
210
+
211
+ # perform training and validation.
212
+ s_time = time.time()
213
+ _, train_loss, _ = trainer.train(epoch)
214
+ t_time = time.time() - s_time
215
+ s_time = time.time()
216
+ _, val_loss, detailed_loss_val = trainer.test(epoch)
217
+ val_av_loss, val_bv_loss, val_fg_loss, _, _, _ = detailed_loss_val
218
+ v_time = time.time() - s_time
219
+
220
+ # print information.
221
+ if master_worker:
222
+ print('Epoch: {:04d}'.format(epoch),
223
+ 'loss_train: {:.6f}'.format(train_loss),
224
+ 'loss_val: {:.6f}'.format(val_loss),
225
+ 'loss_val_av: {:.6f}'.format(val_av_loss),
226
+ 'loss_val_bv: {:.6f}'.format(val_bv_loss),
227
+ 'loss_val_fg: {:.6f}'.format(val_fg_loss),
228
+ 'cur_lr: {:.5f}'.format(trainer.scheduler.get_lr()[0]),
229
+ 't_time: {:.4f}s'.format(t_time),
230
+ 'v_time: {:.4f}s'.format(v_time),
231
+ 'd_time: {:.4f}s'.format(d_time), flush=True)
232
+
233
+ if epoch % args.save_interval == 0:
234
+ trainer.save(epoch, model_dir)
235
+
236
+
237
+ trainer.save_tmp(epoch, model_dir, rank)
238
+
239
+ # Only save final version.
240
+ if master_worker:
241
+ trainer.save(args.epochs, model_dir, "")
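
The sampler/loader pairing above follows the standard PyTorch pattern: when a sampler is supplied, shuffle is False on the DataLoader and per-epoch shuffling is driven by sampler.set_epoch(). A self-contained sketch with PyTorch's built-in DistributedSampler standing in for the custom GROVER sampler:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10).float())
# rank/num_replicas passed explicitly, so no process group is needed for this demo
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=2, shuffle=False, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)   # changes the shuffling every epoch
    print([batch[0].tolist() for batch in loader])
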
task/run_evaluation.py ADDED
@@ -0,0 +1,157 @@
1
+ """
2
+ The evaluation function.
3
+ """
4
+ from argparse import Namespace
5
+ from logging import Logger
6
+ from typing import List
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.utils.data.distributed
11
+
12
+ from grover.data.scaler import StandardScaler
13
+ from grover.util.utils import get_class_sizes, get_data, split_data, get_task_names, get_loss_func
14
+ from grover.util.utils import load_checkpoint
15
+ from task.predict import evaluate_predictions
16
+ from grover.util.metrics import get_metric_func
17
+ from grover.util.nn_utils import param_count
18
+ from task.predict import predict
19
+
20
+
21
+ def run_evaluation(args: Namespace, logger: Logger = None) -> List[float]:
22
+ """
23
+ Evaluates a trained model checkpoint on the test split and returns test scores.
24
+
25
+ :param args: Arguments.
26
+ :param logger: Logger.
27
+ :return: A list of ensemble scores for each task.
28
+ """
29
+ if logger is not None:
30
+ debug, info = logger.debug, logger.info
31
+ else:
32
+ debug = info = print
33
+
34
+ torch.cuda.set_device(0)
35
+
36
+ # Get data
37
+ debug('Loading data')
38
+ args.task_names = get_task_names(args.data_path)
39
+ data = get_data(path=args.data_path, args=args, logger=logger)
40
+ args.num_tasks = data.num_tasks()
41
+ args.features_size = data.features_size()
42
+ debug(f'Number of tasks = {args.num_tasks}')
43
+
44
+ # Split data
45
+ debug(f'Splitting data with seed {args.seed}')
46
+
47
+ train_data, val_data, test_data = split_data(data=data,
48
+ split_type=args.split_type,
49
+ sizes=[0.8, 0.1, 0.1],
50
+ seed=args.seed,
51
+ args=args,
52
+ logger=logger)
53
+
54
+ if args.dataset_type == 'classification':
55
+ class_sizes = get_class_sizes(data)
56
+ debug('Class sizes')
57
+ for i, task_class_sizes in enumerate(class_sizes):
58
+ debug(f'{args.task_names[i]} '
59
+ f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}')
60
+
61
+ if args.features_scaling:
62
+ features_scaler = train_data.normalize_features(replace_nan_token=0)
63
+ val_data.normalize_features(features_scaler)
64
+ test_data.normalize_features(features_scaler)
65
+ else:
66
+ features_scaler = None
67
+
68
+ args.train_data_size = len(train_data)
69
+
70
+ debug(f'Total size = {len(data):,} | '
71
+ f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}')
72
+
73
+ # Initialize scaler (regression only)
74
+ scaler = None
75
+ if args.dataset_type == 'regression':
76
+ debug('Fitting scaler')
77
+ _, train_targets = train_data.smiles(), train_data.targets()
78
+ scaler = StandardScaler().fit(train_targets)
79
+ scaled_targets = scaler.transform(train_targets).tolist()
80
+ train_data.set_targets(scaled_targets)
81
+
82
+ val_targets = val_data.targets()
83
+ scaled_val_targets = scaler.transform(val_targets).tolist()
84
+ val_data.set_targets(scaled_val_targets)
85
+
86
+ metric_func = get_metric_func(metric=args.metric)
87
+
88
+ # Set up test set evaluation
89
+ test_smiles, test_targets = test_data.smiles(), test_data.targets()
90
+ sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))
91
+
92
+ # Load/build model
93
+ if args.checkpoint_paths is not None:
94
+ cur_model = args.seed
95
+ target_path = []
96
+ for path in args.checkpoint_paths:
97
+ if "fold_%d" % cur_model in path:
98
+ target_path = path
99
+ debug(f'Loading model {args.seed} from {target_path}')
100
+ model = load_checkpoint(target_path, current_args=args, cuda=args.cuda, logger=logger)
101
+ # Get loss and metric functions
102
+ loss_func = get_loss_func(args, model)
103
+
104
+ debug(f'Number of parameters = {param_count(model):,}')
105
+
106
+ test_preds, _ = predict(
107
+ model=model,
108
+ data=test_data,
109
+ batch_size=args.batch_size,
110
+ loss_func=loss_func,
111
+ logger=logger,
112
+ shared_dict={},
113
+ scaler=scaler,
114
+ args=args
115
+ )
116
+
117
+ test_scores = evaluate_predictions(
118
+ preds=test_preds,
119
+ targets=test_targets,
120
+ num_tasks=args.num_tasks,
121
+ metric_func=metric_func,
122
+ dataset_type=args.dataset_type,
123
+ logger=logger
124
+ )
125
+
126
+ if len(test_preds) != 0:
127
+ sum_test_preds += np.array(test_preds, dtype=float)
128
+
129
+ # Average test score
130
+ avg_test_score = np.nanmean(test_scores)
131
+ info(f'Model test {args.metric} = {avg_test_score:.6f}')
132
+
133
+ if args.show_individual_scores:
134
+ # Individual test scores
135
+ for task_name, test_score in zip(args.task_names, test_scores):
136
+ info(f'Model test {task_name} {args.metric} = {test_score:.6f}')
137
+
138
+ # Evaluate ensemble on test set
139
+ avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()
140
+
141
+ ensemble_scores = evaluate_predictions(
142
+ preds=avg_test_preds,
143
+ targets=test_targets,
144
+ num_tasks=args.num_tasks,
145
+ metric_func=metric_func,
146
+ dataset_type=args.dataset_type,
147
+ logger=logger
148
+ )
149
+
150
+ # If you want to save the prediction result, uncomment these lines.
151
+ # ind = [['preds'] * args.num_tasks + ['targets'] * args.num_tasks, args.task_names * 2]
152
+ # ind = pd.MultiIndex.from_tuples(list(zip(*ind)))
153
+ # data = np.concatenate([np.array(avg_test_preds), np.array(test_targets)], 1)
154
+ # test_result = pd.DataFrame(data, index=test_smiles, columns=ind)
155
+ # test_result.to_csv(os.path.join(args.save_dir, 'test_result.csv'))
156
+
157
+ return ensemble_scores
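
A toy sketch of the checkpoint lookup near the top of run_evaluation (illustration only, with hypothetical paths): the path whose name contains fold_<seed> is the one evaluated.

checkpoint_paths = [
    "ckpts/fold_0/model_0/model.pt",
    "ckpts/fold_1/model_0/model.pt",
    "ckpts/fold_2/model_0/model.pt",
]
seed = 1
target_path = None
for path in checkpoint_paths:
    if "fold_%d" % seed in path:
        target_path = path
print(target_path)  # ckpts/fold_1/model_0/model.pt
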
task/train.py ADDED
@@ -0,0 +1,454 @@
1
+ """
2
+ The training function used in the finetuning task.
3
+ """
4
+ import csv
5
+ import logging
6
+ import os
7
+ import pickle
8
+ import time
9
+ from argparse import Namespace
10
+ from logging import Logger
11
+ from typing import List
12
+
13
+ import numpy as np
14
+ import pandas as pd
15
+ import torch
16
+ from torch.optim.lr_scheduler import ExponentialLR
17
+ from torch.utils.data import DataLoader
18
+
19
+ from grover.data import MolCollator
20
+ from grover.data import StandardScaler
21
+ from grover.util.metrics import get_metric_func
22
+ from grover.util.nn_utils import initialize_weights, param_count
23
+ from grover.util.scheduler import NoamLR
24
+ from grover.util.utils import build_optimizer, build_lr_scheduler, makedirs, load_checkpoint, get_loss_func, \
25
+ save_checkpoint, build_model
26
+ from grover.util.utils import get_class_sizes, get_data, split_data, get_task_names
27
+ from task.predict import predict, evaluate, evaluate_predictions
28
+
29
+
30
+
31
+ def train(epoch, model, data, loss_func, optimizer, scheduler,
32
+ shared_dict, args: Namespace, n_iter: int = 0,
33
+ logger: logging.Logger = None):
34
+ """
35
+ Trains a model for an epoch.
36
+
37
+ :param model: Model.
38
+ :param data: A MoleculeDataset (or a list of MoleculeDatasets if using moe).
39
+ :param loss_func: Loss function.
40
+ :param optimizer: An Optimizer.
41
+ :param scheduler: A learning rate scheduler.
42
+ :param args: Arguments.
43
+ :param n_iter: The number of iterations (training examples) trained on so far.
44
+ :param logger: A logger for printing intermediate results.
45
+ :param writer: A tensorboardX SummaryWriter.
46
+ :return: The total number of iterations (training examples) trained on so far.
47
+ """
48
+ # debug = logger.debug if logger is not None else print
49
+
50
+ model.train()
51
+
52
+ # data.shuffle()
53
+
54
+ loss_sum, iter_count = 0, 0
55
+ cum_loss_sum, cum_iter_count = 0, 0
56
+
57
+
58
+ mol_collator = MolCollator(shared_dict=shared_dict, args=args)
59
+
60
+ num_workers = 4
61
+ if type(data) == DataLoader:
62
+ mol_loader = data
63
+ else:
64
+ mol_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
65
+ num_workers=num_workers, collate_fn=mol_collator)
66
+
67
+ for _, item in enumerate(mol_loader):
68
+ _, batch, features_batch, mask, targets = item
69
+ if next(model.parameters()).is_cuda:
70
+ mask, targets = mask.cuda(), targets.cuda()
71
+ class_weights = torch.ones(targets.shape)
72
+
73
+ if args.cuda:
74
+ class_weights = class_weights.cuda()
75
+
76
+ # Run model
77
+ model.zero_grad()
78
+ preds = model(batch, features_batch)
79
+ loss = loss_func(preds, targets) * class_weights * mask
80
+ loss = loss.sum() / mask.sum()
81
+
82
+ loss_sum += loss.item()
83
+ iter_count += args.batch_size
84
+
85
+ cum_loss_sum += loss.item()
86
+ cum_iter_count += 1
87
+
88
+ loss.backward()
89
+ optimizer.step()
90
+
91
+ if isinstance(scheduler, NoamLR):
92
+ scheduler.step()
93
+
94
+ n_iter += args.batch_size
95
+
96
+ #if (n_iter // args.batch_size) % args.log_frequency == 0:
97
+ # lrs = scheduler.get_lr()
98
+ # loss_avg = loss_sum / iter_count
99
+ # loss_sum, iter_count = 0, 0
100
+ # lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
101
+
102
+ return n_iter, cum_loss_sum / cum_iter_count
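
A standalone sketch of the masked loss computed in the loop above, assuming a BCE-with-logits loss with reduction='none' for classification (the actual loss comes from get_loss_func): positions with a missing label are zeroed by the mask, and the loss is normalized by the number of observed labels.

import torch
import torch.nn as nn

loss_func = nn.BCEWithLogitsLoss(reduction='none')

preds   = torch.tensor([[0.3, -1.2], [2.0, 0.1]])    # logits, shape (batch, tasks)
targets = torch.tensor([[1.0,  0.0], [0.0, 0.0]])    # missing label filled with 0
mask    = torch.tensor([[1.0,  1.0], [1.0, 0.0]])    # 0 marks a missing label
class_weights = torch.ones_like(targets)

loss = loss_func(preds, targets) * class_weights * mask
loss = loss.sum() / mask.sum()                        # average over observed labels
print(loss.item())
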
103
+
104
+
105
+ def run_training(args: Namespace, time_start, logger: Logger = None) -> List[float]:
106
+ """
107
+ Trains a model and returns test scores on the model checkpoint with the highest validation score.
108
+
109
+ :param args: Arguments.
110
+ :param logger: Logger.
111
+ :return: A list of ensemble scores for each task.
112
+ """
113
+ if logger is not None:
114
+ debug, info = logger.debug, logger.info
115
+ else:
116
+ debug = info = print
117
+
118
+
119
+ # pin GPU to local rank.
120
+ idx = args.gpu
121
+ if args.gpu is not None:
122
+ torch.cuda.set_device(idx)
123
+
124
+ features_scaler, scaler, shared_dict, test_data, train_data, val_data = load_data(args, debug, logger)
125
+
126
+ metric_func = get_metric_func(metric=args.metric)
127
+
128
+ # Set up test set evaluation
129
+ test_smiles, test_targets = test_data.smiles(), test_data.targets()
130
+ sum_test_preds = np.zeros((len(test_smiles), args.num_tasks))
131
+
132
+ # Train ensemble of models
133
+ for model_idx in range(args.ensemble_size):
134
+ # Tensorboard writer
135
+ save_dir = os.path.join(args.save_dir, f'model_{model_idx}')
136
+ makedirs(save_dir)
137
+
138
+ # Load/build model
139
+ if args.checkpoint_paths is not None:
140
+ if len(args.checkpoint_paths) == 1:
141
+ cur_model = 0
142
+ else:
143
+ cur_model = model_idx
144
+ debug(f'Loading model {cur_model} from {args.checkpoint_paths[cur_model]}')
145
+ model = load_checkpoint(args.checkpoint_paths[cur_model], current_args=args, logger=logger)
146
+ else:
147
+ debug(f'Building model {model_idx}')
148
+ model = build_model(model_idx=model_idx, args=args)
149
+
150
+ if args.fine_tune_coff != 1 and args.checkpoint_paths is not None:
151
+ debug("Fine tune fc layer with different lr")
152
+ initialize_weights(model_idx=model_idx, model=model.ffn, distinct_init=args.distinct_init)
153
+
154
+
155
+ ############### FREEZE BLOCK ###########
156
+ # for name, param in model.named_parameters():
157
+ # if name.startswith("grover."):
158
+ # param.requires_grad = False
159
+
160
+ # # Train prediction layers (readout + two FFNs)
161
+ # else:
162
+ # param.requires_grad = True
163
+
164
+ # print("TRAINABLE PARAMETERS:")
165
+ # for name, p in model.named_parameters():
166
+ # if p.requires_grad:
167
+ # print(" ", name)
168
+ ############### FREEZE BLOCK ###########
169
+
170
+ # Get loss and metric functions
171
+ loss_func = get_loss_func(args, model)
172
+
173
+ optimizer = build_optimizer(model, args)
174
+
175
+ debug(model)
176
+ debug(f'Number of parameters = {param_count(model):,}')
177
+ if args.cuda:
178
+ debug('Moving model to cuda')
179
+ model = model.cuda()
180
+
181
+ # Ensure that model is saved in correct location for evaluation if 0 epochs
182
+ save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)
183
+
184
+ # Learning rate schedulers
185
+ scheduler = build_lr_scheduler(optimizer, args)
186
+
187
+ # Bulid data_loader
188
+ shuffle = True
189
+ mol_collator = MolCollator(shared_dict={}, args=args)
190
+ train_data = DataLoader(train_data,
191
+ batch_size=args.batch_size,
192
+ shuffle=shuffle,
193
+ num_workers=10,
194
+ collate_fn=mol_collator)
195
+
196
+ # Run training
197
+ best_score = float('inf') if args.minimize_score else -float('inf')
198
+ best_epoch, n_iter = 0, 0
199
+ min_val_loss = float('inf')
200
+ for epoch in range(args.epochs):
201
+ s_time = time.time()
202
+ n_iter, train_loss = train(
203
+ epoch=epoch,
204
+ model=model,
205
+ data=train_data,
206
+ loss_func=loss_func,
207
+ optimizer=optimizer,
208
+ scheduler=scheduler,
209
+ args=args,
210
+ n_iter=n_iter,
211
+ shared_dict=shared_dict,
212
+ logger=logger
213
+ )
214
+ t_time = time.time() - s_time
215
+ s_time = time.time()
216
+ val_scores, val_loss = evaluate(
217
+ model=model,
218
+ data=val_data,
219
+ loss_func=loss_func,
220
+ num_tasks=args.num_tasks,
221
+ metric_func=metric_func,
222
+ batch_size=args.batch_size,
223
+ dataset_type=args.dataset_type,
224
+ scaler=scaler,
225
+ shared_dict=shared_dict,
226
+ logger=logger,
227
+ args=args
228
+ )
229
+ v_time = time.time() - s_time
230
+ # Average validation score
231
+ avg_val_score = np.nanmean(val_scores)
232
+ # Logged after lr step
233
+ if isinstance(scheduler, ExponentialLR):
234
+ scheduler.step()
235
+
236
+ if args.show_individual_scores:
237
+ # Individual validation scores
238
+ for task_name, val_score in zip(args.task_names, val_scores):
239
+ debug(f'Validation {task_name} {args.metric} = {val_score:.6f}')
240
+ print('Epoch: {:04d}'.format(epoch),
241
+ 'loss_train: {:.6f}'.format(train_loss),
242
+ 'loss_val: {:.6f}'.format(val_loss),
243
+ f'{args.metric}_val: {avg_val_score:.4f}',
244
+ # 'auc_val: {:.4f}'.format(avg_val_score),
245
+ 'cur_lr: {:.5f}'.format(scheduler.get_lr()[-1]),
246
+ 't_time: {:.4f}s'.format(t_time),
247
+ 'v_time: {:.4f}s'.format(v_time))
248
+
249
+ if args.tensorboard:
250
+ writer.add_scalar('loss/train', train_loss, epoch)
251
+ writer.add_scalar('loss/val', val_loss, epoch)
252
+ writer.add_scalar(f'{args.metric}_val', avg_val_score, epoch)
253
+
254
+
255
+ # Save model checkpoint if improved validation score
256
+ if args.select_by_loss:
257
+ if val_loss < min_val_loss:
258
+ min_val_loss, best_epoch = val_loss, epoch
259
+ save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)
260
+ else:
261
+ if args.minimize_score and avg_val_score < best_score or \
262
+ not args.minimize_score and avg_val_score > best_score:
263
+ best_score, best_epoch = avg_val_score, epoch
264
+ save_checkpoint(os.path.join(save_dir, 'model.pt'), model, scaler, features_scaler, args)
265
+
266
+ if epoch - best_epoch > args.early_stop_epoch:
267
+ break
268
+
269
+ ensemble_scores = 0.0
270
+
271
+ # Evaluate on test set using model with best validation score
272
+ if args.select_by_loss:
273
+ info(f'Model {model_idx} best val loss = {min_val_loss:.6f} on epoch {best_epoch}')
274
+ else:
275
+ info(f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}')
276
+ model = load_checkpoint(os.path.join(save_dir, 'model.pt'), cuda=args.cuda, logger=logger)
277
+
278
+ test_preds, _ = predict(
279
+ model=model,
280
+ data=test_data,
281
+ loss_func=loss_func,
282
+ batch_size=args.batch_size,
283
+ logger=logger,
284
+ shared_dict=shared_dict,
285
+ scaler=scaler,
286
+ args=args
287
+ )
288
+
289
+ test_scores = evaluate_predictions(
290
+ preds=test_preds,
291
+ targets=test_targets,
292
+ num_tasks=args.num_tasks,
293
+ metric_func=metric_func,
294
+ dataset_type=args.dataset_type,
295
+ logger=logger
296
+ )
297
+
298
+ if len(test_preds) != 0:
299
+ sum_test_preds += np.array(test_preds, dtype=float)
300
+
301
+ # Average test score
302
+ avg_test_score = np.nanmean(test_scores)
303
+ info(f'Model {model_idx} test {args.metric} = {avg_test_score:.6f}')
304
+
305
+ if args.show_individual_scores:
306
+ # Individual test scores
307
+ for task_name, test_score in zip(args.task_names, test_scores):
308
+ info(f'Model {model_idx} test {task_name} {args.metric} = {test_score:.6f}')
309
+
310
+ # Evaluate ensemble on test set
311
+ avg_test_preds = (sum_test_preds / args.ensemble_size).tolist()
312
+
313
+ ensemble_scores = evaluate_predictions(
314
+ preds=avg_test_preds,
315
+ targets=test_targets,
316
+ num_tasks=args.num_tasks,
317
+ metric_func=metric_func,
318
+ dataset_type=args.dataset_type,
319
+ logger=logger
320
+ )
321
+
322
+ ind = [['preds'] * args.num_tasks + ['targets'] * args.num_tasks, args.task_names * 2]
323
+ ind = pd.MultiIndex.from_tuples(list(zip(*ind)))
324
+ data = np.concatenate([np.array(avg_test_preds), np.array(test_targets)], 1)
325
+ test_result = pd.DataFrame(data, index=test_smiles, columns=ind)
326
+ test_result.to_csv(os.path.join(args.save_dir, 'test_result.csv'))
327
+
328
+ # Average ensemble score
329
+ avg_ensemble_test_score = np.nanmean(ensemble_scores)
330
+ info(f'Ensemble test {args.metric} = {avg_ensemble_test_score:.6f}')
331
+
332
+ # Individual ensemble scores
333
+ if args.show_individual_scores:
334
+ for task_name, ensemble_score in zip(args.task_names, ensemble_scores):
335
+ info(f'Ensemble test {task_name} {args.metric} = {ensemble_score:.6f}')
336
+
337
+ return ensemble_scores
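
For reference, a hedged sketch of what the commented-out FREEZE BLOCK earlier in run_training would do if enabled: parameters of the pretrained encoder (names starting with "grover.") are frozen and only the readout/FFN head stays trainable. `model` here is any torch module following that naming convention; this is an illustration, not the repository's code path.

import torch.nn as nn

def freeze_grover_encoder(model: nn.Module):
    # Freeze the pretrained encoder, keep the prediction head trainable.
    for name, param in model.named_parameters():
        param.requires_grad = not name.startswith("grover.")
    return [name for name, p in model.named_parameters() if p.requires_grad]

# Usage: call freeze_grover_encoder(model) before building the optimizer, then
# pass only parameters with requires_grad=True to the optimizer.
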
338
+
339
+
340
+ def load_data(args, debug, logger):
341
+ """
342
+ load the training data.
343
+ :param args:
344
+ :param debug:
345
+ :param logger:
346
+ :return:
347
+ """
348
+ # Get data
349
+ debug('Loading data')
350
+ args.task_names = get_task_names(args.data_path)
351
+ data = get_data(path=args.data_path, args=args, logger=logger)
352
+ if data.data[0].features is not None:
353
+ args.features_dim = len(data.data[0].features)
354
+ else:
355
+ args.features_dim = 0
356
+ shared_dict = {}
357
+ args.num_tasks = data.num_tasks()
358
+ args.features_size = data.features_size()
359
+ debug(f'Number of tasks = {args.num_tasks}')
360
+ # Split data
361
+ debug(f'Splitting data with seed {args.seed}')
362
+ if args.separate_test_path:
363
+ test_data = get_data(path=args.separate_test_path, args=args,
364
+ features_path=args.separate_test_features_path, logger=logger)
365
+ if args.separate_val_path:
366
+ val_data = get_data(path=args.separate_val_path, args=args,
367
+ features_path=args.separate_val_features_path, logger=logger)
368
+ if args.separate_val_path and args.separate_test_path:
369
+ train_data = data
370
+ elif args.separate_val_path:
371
+ train_data, _, test_data = split_data(data=data, split_type=args.split_type,
372
+ sizes=(0.8, 0.2, 0.0), seed=args.seed, args=args, logger=logger)
373
+ elif args.separate_test_path:
374
+ train_data, val_data, _ = split_data(data=data, split_type=args.split_type,
375
+ sizes=(0.8, 0.2, 0.0), seed=args.seed, args=args, logger=logger)
376
+ else:
377
+ train_data, val_data, test_data = split_data(data=data, split_type=args.split_type,
378
+ sizes=args.split_sizes, seed=args.seed, args=args, logger=logger)
379
+ if args.dataset_type == 'classification':
380
+ class_sizes = get_class_sizes(data)
381
+ debug('Class sizes')
382
+ for i, task_class_sizes in enumerate(class_sizes):
383
+ debug(f'{args.task_names[i]} '
384
+ f'{", ".join(f"{cls}: {size * 100:.2f}%" for cls, size in enumerate(task_class_sizes))}')
385
+
386
+ #if args.save_smiles_splits:
387
+ # save_splits(args, test_data, train_data, val_data)
388
+
389
+ if args.features_scaling:
390
+ features_scaler = train_data.normalize_features(replace_nan_token=0)
391
+ val_data.normalize_features(features_scaler)
392
+ test_data.normalize_features(features_scaler)
393
+ else:
394
+ features_scaler = None
395
+ args.train_data_size = len(train_data)
396
+ debug(f'Total size = {len(data):,} | '
397
+ f'train size = {len(train_data):,} | val size = {len(val_data):,} | test size = {len(test_data):,}')
398
+
399
+ # Initialize scaler and scale training targets by subtracting mean and dividing standard deviation (regression only)
400
+ if args.dataset_type == 'regression':
401
+ debug('Fitting scaler')
402
+ _, train_targets = train_data.smiles(), train_data.targets()
403
+ scaler = StandardScaler().fit(train_targets)
404
+ scaled_targets = scaler.transform(train_targets).tolist()
405
+ train_data.set_targets(scaled_targets)
406
+
407
+ val_targets = val_data.targets()
408
+ scaled_val_targets = scaler.transform(val_targets).tolist()
409
+ val_data.set_targets(scaled_val_targets)
410
+ else:
411
+ scaler = None
412
+ return features_scaler, scaler, shared_dict, test_data, train_data, val_data
413
+
414
+
415
+ def save_splits(args, test_data, train_data, val_data):
416
+ """
417
+ Save the splits.
418
+ :param args:
419
+ :param test_data:
420
+ :param train_data:
421
+ :param val_data:
422
+ :return:
423
+ """
424
+ with open(args.data_path, 'r') as f:
425
+ reader = csv.reader(f)
426
+ header = next(reader)
427
+
428
+ lines_by_smiles = {}
429
+ indices_by_smiles = {}
430
+ for i, line in enumerate(reader):
431
+ smiles = line[0]
432
+ lines_by_smiles[smiles] = line
433
+ indices_by_smiles[smiles] = i
434
+
435
+ all_split_indices = []
436
+ for dataset, name in [(train_data, 'train'), (val_data, 'val'), (test_data, 'test')]:
437
+ with open(os.path.join(args.save_dir, name + '_smiles.csv'), 'w') as f:
438
+ writer = csv.writer(f)
439
+ writer.writerow(['smiles'])
440
+ for smiles in dataset.smiles():
441
+ writer.writerow([smiles])
442
+ with open(os.path.join(args.save_dir, name + '_full.csv'), 'w') as f:
443
+ writer = csv.writer(f)
444
+ writer.writerow(header)
445
+ for smiles in dataset.smiles():
446
+ writer.writerow(lines_by_smiles[smiles])
447
+ split_indices = []
448
+ for smiles in dataset.smiles():
449
+ split_indices.append(indices_by_smiles[smiles])
450
+ split_indices = sorted(split_indices)
451
+ all_split_indices.append(split_indices)
452
+ with open(os.path.join(args.save_dir, 'split_indices.pckl'), 'wb') as f:
453
+ pickle.dump(all_split_indices, f)
454
+ return writer
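
A short usage sketch for the artifact written at the end of save_splits (assuming the default file name under args.save_dir; the directory shown here is hypothetical): the pickle holds one sorted index list per split, in (train, val, test) order.

import os
import pickle

save_dir = "./save"   # hypothetical save_dir used when training was run
with open(os.path.join(save_dir, "split_indices.pckl"), "rb") as f:
    train_idx, val_idx, test_idx = pickle.load(f)
print(len(train_idx), len(val_idx), len(test_idx))
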