| import warnings |
| import tempfile |
|
|
| import torch |
| import numpy as np |
| from rdkit import Chem |
| from rdkit.Chem.rdForceFieldHelpers import UFFOptimizeMolecule, UFFHasAllMoleculeParams |
| import openbabel |
|
|
| import utils |
| from constants import bonds1, bonds2, bonds3, margin1, margin2, margin3, \ |
| bond_dict |
|
|
|
|
def get_bond_order(atom1, atom2, distance):
    """
    Classify the bond between two atoms from their distance.

    The distance is multiplied by 100 before it is compared with the
    reference lengths of the module-level tables ``bonds1``, ``bonds2`` and
    ``bonds3`` (looked up as ``table[atom1][atom2]``), each enlarged by its
    own margin. NOTE(review): the factor 100 presumably converts the input
    to the unit used by the tables - confirm against ``constants``.
    Args:
        atom1: key of the first atom in the bond tables
        atom2: key of the second atom in the bond tables
        distance: distance between the two atoms
    Returns:
        3, 2 or 1 for the highest bond order whose threshold is met,
        0 if none of the tables yields a match
    """
    scaled = 100 * distance

    # Highest order first: the first threshold that is met wins
    candidates = (
        (3, bonds3, margin3),
        (2, bonds2, margin2),
        (1, bonds1, margin1),
    )
    for order, table, margin in candidates:
        if atom1 not in table:
            continue
        partners = table[atom1]
        if atom2 in partners and scaled < partners[atom2] + margin:
            return order

    return 0
|
|
|
|
def get_bond_order_batch(atoms1, atoms2, distances, dataset_info):
    """
    Batched bond-order lookup for many atom pairs at once.
    Args:
        atoms1: atom type indices of the first atoms (tensor or numpy array),
            used to index the first axis of the bond tables
        atoms2: atom type indices of the second atoms (tensor or numpy
            array), used to index the second axis of the bond tables
        distances: pair distances (tensor or numpy array); multiplied by 100
            before the comparison
        dataset_info: dict providing the reference lengths under the keys
            'bonds1', 'bonds2' and 'bonds3'
    Returns:
        Tensor shaped and typed like atoms1 with entries 0 (no bond), 1, 2
        or 3
    """
    # Numpy inputs are converted, everything else is used as given
    atoms1, atoms2, distances = (
        torch.from_numpy(arr) if isinstance(arr, np.ndarray) else arr
        for arr in (atoms1, atoms2, distances)
    )

    scaled = 100 * distances
    device = atoms1.device

    # One reference table per bond order, placed on the device of atoms1
    tables = [
        torch.tensor(dataset_info[key], device=device)
        for key in ('bonds1', 'bonds2', 'bonds3')
    ]
    margins = (margin1, margin2, margin3)

    orders = torch.zeros_like(atoms1)

    # Ascending order, so a higher bond order overwrites a lower one
    # wherever its own threshold is met as well
    for order, (table, margin) in enumerate(zip(tables, margins), start=1):
        orders[scaled < table[atoms1, atoms2] + margin] = order

    return orders
|
|
|
|
def make_mol_openbabel(positions, atom_types, atom_decoder):
    """
    Create an RDKit molecule whose bonds are determined by openbabel.

    The atoms are written to a temporary xyz file, converted to sdf with
    openbabel and read back with RDKit (without sanitization). The parsed
    molecule is then copied into a fresh RWMol, taking over only the element
    symbols, the first conformer and the bonds with their types.
    Args:
        positions: N x 3
        atom_types: N
        atom_decoder: maps indices to atom types
    Returns:
        rdkit molecule
    """
    symbols = [atom_decoder[idx] for idx in atom_types]

    with tempfile.NamedTemporaryFile() as handle:
        path = handle.name

        # Dump the point cloud as xyz
        utils.write_xyz_file(positions, symbols, path)

        # xyz -> sdf conversion through openbabel
        converter = openbabel.OBConversion()
        converter.SetInAndOutFormats("xyz", "sdf")
        parsed = openbabel.OBMol()
        converter.ReadFile(parsed, path)

        # The sdf output overwrites the same temporary file
        converter.WriteFile(parsed, path)

        # Load the first record while the temporary file still exists
        source = Chem.SDMolSupplier(path, sanitize=False)[0]

    # Copy atoms (by element symbol), conformer and bonds into a new molecule
    rebuilt = Chem.RWMol()
    for atom in source.GetAtoms():
        rebuilt.AddAtom(Chem.Atom(atom.GetSymbol()))
    rebuilt.AddConformer(source.GetConformer(0))

    for bond in source.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        rebuilt.AddBond(begin, end, bond.GetBondType())

    return rebuilt
|
|
|
|
def make_mol_edm(positions, atom_types, dataset_info, add_coords):
    """
    Build an RDKit molecule with distance-based bonds (EDM's approach).
    Args:
        positions: tensor of atom coordinates, one row per atom
        atom_types: tensor of atom type indices, one entry per atom
        dataset_info: dict with the key "atom_decoder" and the bond tables
            read by get_bond_order_batch
        add_coords: if True, attach a conformer holding the coordinates
    Returns:
        rdkit molecule (RWMol)
    """
    num_atoms = len(positions)

    # Distances of all ordered atom pairs, flattened row by row so that they
    # line up with the pairs produced by cartesian_prod
    batched = positions.unsqueeze(0)
    pair_dists = torch.cdist(batched, batched, p=2).squeeze(0).view(-1)
    first, second = torch.cartesian_prod(atom_types, atom_types).T
    orders = get_bond_order_batch(first, second, pair_dists, dataset_info)

    # Strictly lower triangle: every pair once, no self-pairs
    E = torch.tril(orders.view(num_atoms, num_atoms), diagonal=-1)

    mol = Chem.RWMol()
    for type_idx in atom_types:
        symbol = dataset_info["atom_decoder"][type_idx.item()]
        mol.AddAtom(Chem.Atom(symbol))

    # Every non-zero entry of E is a bond of that order
    for pair in torch.nonzero(E.bool()):
        begin, end = pair[0].item(), pair[1].item()
        mol.AddBond(begin, end, bond_dict[E[begin, end].item()])

    if add_coords:
        num_mol_atoms = mol.GetNumAtoms()
        conformer = Chem.Conformer(num_mol_atoms)
        for idx in range(num_mol_atoms):
            x, y, z = (positions[idx, axis].item() for axis in range(3))
            conformer.SetAtomPosition(idx, (x, y, z))
        mol.AddConformer(conformer)

    return mol
|
|
|
|
def build_molecule(positions, atom_types, dataset_info, add_coords=False,
                   use_openbabel=True):
    """
    Build an RDKit molecule with one of the two available bond builders.
    Args:
        positions: N x 3
        atom_types: N
        dataset_info: dict
        add_coords: Add conformer to mol (only forwarded to the distance-based
            builder; the openbabel builder always adds one)
        use_openbabel: use OpenBabel to create bonds, otherwise use the
            distance-based builder
    Returns:
        RDKit molecule
    """
    if not use_openbabel:
        return make_mol_edm(positions, atom_types, dataset_info, add_coords)

    atom_decoder = dataset_info["atom_decoder"]
    return make_mol_openbabel(positions, atom_types, atom_decoder)
|
|
|
|
def process_molecule(rdmol, add_hydrogens=False, sanitize=False, relax_iter=0,
                     largest_frag=False):
    """
    Run the requested post-processing steps on a copy of an RDKit molecule.

    The steps are applied in this order: sanitization, hydrogen addition,
    largest-fragment selection, UFF relaxation.
    Args:
        rdmol: rdkit molecule (left untouched, a copy is processed)
        add_hydrogens: add hydrogens; coordinates are generated for them if
            the molecule has at least one conformer
        sanitize: sanitize the molecule, and again after fragment selection
            and after relaxation if those steps are enabled
        relax_iter: maximum number of UFF optimization iterations
            (no relaxation unless greater than 0)
        largest_frag: keep only the fragment with the most atoms
    Returns:
        RDKit molecule or None if it does not pass the filters
    """
    # Work on a copy so the caller's molecule is never modified
    mol = Chem.Mol(rdmol)

    if sanitize:
        try:
            Chem.SanitizeMol(mol)
        except ValueError:
            warnings.warn('Sanitization failed. Returning None.')
            return None

    if add_hydrogens:
        has_conformer = len(mol.GetConformers()) > 0
        mol = Chem.AddHs(mol, addCoords=has_conformer)

    if largest_frag:
        fragments = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
        mol = max(fragments, default=mol, key=lambda frag: frag.GetNumAtoms())
        if sanitize:
            # The selected fragment is sanitized again; a failure here
            # rejects the molecule without a warning
            try:
                Chem.SanitizeMol(mol)
            except ValueError:
                return None

    if not (relax_iter > 0):
        return mol

    if not UFFHasAllMoleculeParams(mol):
        warnings.warn('UFF parameters not available for all atoms. '
                      'Returning None.')
        return None

    try:
        uff_relax(mol, relax_iter)
        if sanitize:
            # Sanitize once more after the relaxation step
            Chem.SanitizeMol(mol)
    except (RuntimeError, ValueError):
        return None

    return mol
|
|
|
|
def uff_relax(mol, max_iter=200):
    """
    Uses RDKit's universal force field (UFF) implementation to optimize a
    molecule.
    Args:
        mol: RDKit molecule, handed to UFFOptimizeMolecule
        max_iter: maximum number of optimization iterations
    Returns:
        The status code of UFFOptimizeMolecule, unchanged. According to the
        RDKit documentation: 0 if the optimization converged, 1 if more
        iterations are required, -1 if the force field could not be set up.
    """
    status = UFFOptimizeMolecule(mol, maxIters=max_iter)

    if status == -1:
        # A failed force-field setup is also truthy; report it separately
        # instead of claiming that the iteration limit was hit
        warnings.warn('UFF force field could not be set up. '
                      'Returning molecule without relaxation.')
    elif status:
        warnings.warn('Maximum number of FF iterations reached. '
                      f'Returning molecule after {max_iter} relaxation steps.')

    return status
|
|
|
|
def filter_rd_mol(rdmol):
    """
    Filter out RDMols if they have a 3-3 ring intersection
    adapted from:
    https://github.com/luost26/3D-Generative-SBDD/blob/main/utils/chem.py
    Args:
        rdmol: molecule exposing GetRingInfo().AtomRings()
    Returns:
        False if two different 3-membered rings share at least one atom,
        True otherwise
    """
    atom_rings = rdmol.GetRingInfo().AtomRings()

    # Only 3-membered rings can be part of a 3-3 intersection
    small_rings = [ring for ring in map(set, atom_rings) if len(ring) == 3]

    # Compare every unordered pair of 3-membered rings exactly once
    for idx, ring in enumerate(small_rings):
        for other in small_rings[idx + 1:]:
            if not ring.isdisjoint(other):
                return False

    return True
|
|