Skip to content

Commit

Permalink
Merge pull request #5 from hits-mbm-dev/compatiblity/grappa_refactor
Browse files Browse the repository at this point in the history
Compatiblity/grappa refactor
  • Loading branch information
KRiedmiller authored Jan 31, 2024
2 parents 902398b + faf7d0c commit 5bc8ec2
Show file tree
Hide file tree
Showing 9 changed files with 5,037 additions and 192 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ example/**/pycallgraph.png

# tutorials and example output files
guide/kimmdy-tutorial

amber*

# excludes
!tests/test_files/test_coordinates/pull.trr
Expand Down
261 changes: 150 additions & 111 deletions src/grappa_interface.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,37 @@
import logging
import numpy as np

np.set_printoptions(suppress=True, precision=6) # prevent numpy exponential
import math
from typing import Union

from kimmdy.topology.topology import Topology
from kimmdy.topology.atomic import Bond, Angle, Dihedral, MultipleDihedrals
from kimmdy.topology.atomic import Atom, Bond, Angle, Dihedral, MultipleDihedrals
from kimmdy.topology.utils import get_by_permutations
from kimmdy.plugins import Parameterizer
from kimmdy.parsing import write_json

import grappa.ff
import openmm.unit
from grappa.data import Molecule
from grappa.data import Parameters

from grappa.utils.loading_utils import model_from_url
from grappa.grappa import Grappa

from openmm import unit as openmm_unit
from grappa.units import convert

logger = logging.getLogger("kimmdy.grappa_interface")


# helper functions
def check_equal_length(d: dict, name: str):
    """Assert that every value stored in ``d`` has the same length.

    ``name`` is only used to label the parameter group in the failure message.
    An AssertionError is raised when the lengths differ, and also when ``d`` is
    empty, because then there is not exactly one common length.
    """
    distinct_lengths = {len(values) for values in d.values()}
    assert (
        len(distinct_lengths) == 1
    ), f"Different length of {name} parameters: { {k:len(v) for k, v in d.items()} }"


# unused?
def convert_to_python_types(array: Union[list, np.ndarray]) -> list:
    """Return ``array`` as plain Python data.

    Objects that expose a ``tolist`` method (such as numpy arrays) are
    converted through it; anything else is handed back unchanged.
    """
    if hasattr(array, "tolist"):
        return array.tolist()
    return array

Expand All @@ -36,6 +45,7 @@ def order_proper(idxs: np.ndarray) -> np.ndarray:
return np.flip(idxs)


# unused?
def elements_to_string(l: list):
# Walks `l` by index; the visible assignment overwrites an entry in place with its str() form.
# NOTE(review): the body of the `isinstance(e, list)` branch sits behind the collapsed
# diff hunk marker below, so how nested lists are handled cannot be seen in this view.
for i, e in enumerate(l):
if isinstance(e, list):
Expand All @@ -44,129 +54,132 @@ def elements_to_string(l: list):
l[i] = str(e)


def clean_parameters(parameters: dict) -> dict:
harmonic_keys = {"idxs", "eq", "k"}
dihedral_keys = {"idxs", "phases", "ks", "ns"}
parameters_clean = {
"atom": {"idxs": [], "q": []},
"bond": {k: [] for k in harmonic_keys},
"angle": {k: [] for k in harmonic_keys},
"proper": {k: [] for k in dihedral_keys},
"improper": {k: [] for k in dihedral_keys},
# workflow functions
def build_molecule(top: Topology) -> Molecule:
# Builds a grappa Molecule from the topology: per-atom properties are gathered
# into lists, bonds and improper dihedrals into index tuples, and everything is
# handed to the Molecule constructor. Returns that Molecule.
# atom-type lookup (top.ff.atomtypes), indexed by atom.type below
at_map = top.ff.atomtypes
# one list per atom property, filled in the iteration order of top.atoms.values()
atom_info = {
"nr": [],
"atomic_number": [],
"partial_charges": [],
"sigma": [],
"epsilon": [],
"is_radical": [],
}
# NOTE(review): the next statements use `parameters`, a name this function never
# defines; in this diff view they are presumably removed-side lines of the former
# clean_parameters — confirm against the repository.
# convert from kJ/mol/deg-2 to kJ/mol/rad-2 because GROMACS units are inconsistent
parameters["angle_k"] = parameters["angle_k"] * (180.0**2 / math.pi**2)
parameters["proper_idxs"] = np.array(
[order_proper(x) for x in parameters["proper_idxs"]]
# topology values are cast to plain int/float before being stored
for atom in top.atoms.values():
atom_info["nr"].append(int(atom.nr))
atom_info["atomic_number"].append(int(at_map[atom.type].at_num))
atom_info["partial_charges"].append(float(atom.charge))
atom_info["sigma"].append(float(at_map[atom.type].sigma))
atom_info["epsilon"].append(float(at_map[atom.type].epsilon))
atom_info["is_radical"].append(int(atom.is_radical))

# bonds as (ai, aj) pairs and impropers as (ai, aj, ak, al) quadruples, all cast to int
bonds = [(int(bond.ai), int(bond.aj)) for bond in top.bonds.values()]
impropers = [
(int(improper.ai), int(improper.aj), int(improper.ak), int(improper.al))
for improper in top.improper_dihedrals.values()
]

# nr, atomic_number and partial_charges go in as dedicated arguments; the remaining
# atom_info entries (sigma, epsilon, is_radical) are passed as numpy arrays
mol = Molecule(
atoms=atom_info["nr"],
bonds=bonds,
impropers=impropers,
atomic_numbers=atom_info["atomic_number"],
partial_charges=atom_info["partial_charges"],
additional_features={
k: np.asarray(v)
for k, v in atom_info.items()
if k not in ["nr", "atomic_number", "partial_charges"]
},
)
return mol

try:
for atomic in parameters_clean.keys():
for parameter in parameters_clean[atomic].keys():
key = atomic + "_" + parameter
parameters_clean[atomic][parameter] = convert_to_python_types(
parameters[key]
)
elements_to_string(parameters_clean[atomic][parameter])
except KeyError:
raise KeyError(
f"GrAPPa returned parameters {list(parameters.keys())}, which do not contain the required sections to fill {parameters_clean}"
)

for name, atomic in parameters_clean.items():
check_equal_length(atomic, name)
def convert_parameters(parameters: Parameters) -> Parameters:
"""Converts parameters to gromacs units
Assumes input parameters to be in kcal/mol, Angstrom und rad
Gromac units mostly kJ/mol, nm and degree
"""
# The passed Parameters object is modified in place (attributes are reassigned)
# and the same object is returned.

# NOTE(review): the asserts and `return parameters_clean` below use `parameters_clean`,
# which this function never defines; in this diff view they are presumably removed-side
# lines of the former clean_parameters — confirm against the repository.
## sample type check
assert parameters_clean["atom"]["idxs"][
0
].isdigit(), f"atom idxs element does not look like int {parameters_clean['atom']['idxs'][0]}."
assert (
parameters_clean["bond"]["k"][0]
.strip()
.lstrip("-")
.replace(".", "", 1)
.isdigit()
), f"b k element does not look like float {parameters_clean['bond']['k'][0]}."
assert isinstance(
parameters_clean["proper"]["ns"][0], list
), f"proper ns element has wrong type {type(parameters_clean['proper']['ns'][0])}, should be list."
assert (
parameters_clean["improper"]["phases"][0][0]
.strip()
.lstrip("-")
.replace(".", "", 1)
.isdigit()
), f"improper phases element does not look like float {type(parameters_clean['improper']['phases'][0][0])}"
return parameters_clean
# scalar factors obtained by converting the value 1:
# angstrom -> nanometer, radian -> degree, kcal/mol -> kJ/mol
distance_factor = convert(1, openmm_unit.angstrom, openmm_unit.nanometer)
degree_factor = convert(1, openmm_unit.radian, openmm_unit.degree)
energy_factor = convert(
1, openmm_unit.kilocalorie_per_mole, openmm_unit.kilojoule_per_mole
)

# convert parameters
parameters.bond_eq = parameters.bond_eq * distance_factor
# energy per squared length: scale the energy and divide by the squared length factor
parameters.bond_k = parameters.bond_k * energy_factor / np.power(distance_factor, 2)
# angles are given in degrees and force constants in kJ/mol/rad**2.
parameters.angle_eq = parameters.angle_eq * degree_factor
# only the energy factor is applied here; the angular part of the unit is left untouched
parameters.angle_k = parameters.angle_k * energy_factor

# each proper index tuple is passed through order_proper (defined elsewhere in this file)
parameters.propers = np.array([order_proper(x) for x in parameters.propers])
parameters.proper_phases = parameters.proper_phases * degree_factor
parameters.proper_ks = parameters.proper_ks * energy_factor

parameters.improper_phases = parameters.improper_phases * degree_factor
parameters.improper_ks = parameters.improper_ks * energy_factor

# convert to list of strings
# Every annotated field of `parameters` that is non-empty is overwritten with a
# (possibly nested) list of strings; empty fields are only logged.
# NOTE(review): numpy float32 scalars are not instances of Python `float` (float64
# scalars are), so float32 data would skip the 4-decimal formatting and fall through
# to astype(str) — confirm the dtype grappa returns.
for k in parameters.__annotations__.keys():
v = getattr(parameters, k)
if len(v) == 0:
logger.info(f"Parameter list {k} is empty.")
else:
if isinstance(v[0], float):
v_list = [f"{i:11.4f}".strip() for i in v]
elif isinstance(v[0], np.ndarray) and isinstance(v[0, 0], float):
v_list = []
for sub_list in v:
v_list.append([f"{i:11.4f}".strip() for i in sub_list])
else:
v_list = v.astype(str).tolist()
setattr(parameters, k, v_list)

# NOTE(review): the `generate_input` definition that follows, up to the `radicals = ...`
# line, is presumably a removed-side block of the diff; the final `return parameters`
# is the return of convert_parameters — confirm against the repository.
def generate_input(top: Topology) -> dict:
at_map = top.ff.atomtypes
atoms = [
[
int(atom.nr),
atom.atom,
atom.residue,
int(atom.resnr),
[float(at_map[atom.type].sigma), float(at_map[atom.type].epsilon)],
int(at_map[atom.type].at_num),
]
for atom in top.atoms.values()
]
atoms.sort(key=lambda x: x[2])
bonds = [(int(bond.ai), int(bond.aj)) for bond in top.bonds.values()]
radicals = [int(radical) for radical in top.radicals.keys()]
return parameters

return {"atoms": atoms, "bonds": bonds, "radicals": radicals}

def apply_parameters(top: Topology, parameters: Parameters):
"""Applies parameters to topology
def apply_parameters(top: Topology, parameters: dict):
# parameter structure is defined in clean_parameters()
# assume units are according to https://manual.gromacs.org/current/reference-manual/definitions.html
# namely: length [nm], mass [kg], time [ps], energy [kJ/mol], force [kJ mol-1 nm-1], angle [deg]
parameter structure is defined in grappa.data.Parameters.Parameters
assume units are according to https://manual.gromacs.org/current/reference-manual/definitions.html
namely: length [nm], mass [kg], time [ps], energy [kJ/mol], force [kJ mol-1 nm-1], angle [deg]
"""

# Writes bond, angle, proper-dihedral and improper-dihedral entries into `top`
# in place; the visible body ends with a bare `return`.
# NOTE(review): this view interleaves removed-side and added-side diff lines without
# +/- markers (both `parameters["bond"]["idxs"]` and `parameters.bonds` loop headers
# appear). The dict-style accesses presumably belong to the old implementation —
# confirm against the repository.
## atoms
for i, idx in enumerate(parameters["atom"]["idxs"]):
if not (atom := top.atoms.get(idx)):
# raise KeyError(f"bad index {idx} in {list(top.atoms.keys())}")
logging.warning(
f"Ignored parameters with invalid ids: {idx} for atoms"
) # this can happen when removing a hydrogen in kimmdy-remove-hydrogen
continue
# can anything but charge change??
atom.charge = parameters["atom"]["q"][i]
atom.chargeB = None
# Nothing to do here because partial charges are dealt with elsewhere

## bonds
# Index tuples that are not already keys of top.bonds are skipped with a warning.
# NOTE(review): these warnings go through the root `logging` module, whereas the
# improper section uses the module-level `logger` — consider unifying.
for i, idx in enumerate(parameters["bond"]["idxs"]):
for i, idx in enumerate(parameters.bonds):
tup = tuple(idx)
if not top.bonds.get(tup):
# raise KeyError(f"bad index {tup} in {list(top.bonds.keys())}")
logging.warning(f"Ignored parameters with invalid ids: {tup} for bonds")
continue
top.bonds[tup] = Bond(
*parameters["bond"]["idxs"][i],
*tup,
funct="1",
c0=parameters["bond"]["eq"][i],
c1=parameters["bond"]["k"][i],
c0=parameters.bond_eq[i],
c1=parameters.bond_k[i],
)

## angles
# same pattern as bonds: unknown index tuples are skipped with a warning
for i, idx in enumerate(parameters["angle"]["idxs"]):
for i, idx in enumerate(parameters.angles):
tup = tuple(idx)
if not top.angles.get(tup):
# raise KeyError(f"bad index {tup} in {list(top.angles.keys())}")
logging.warning(f"Ignored parameters with invalid ids: {tup} for angles")
continue
top.angles[tup] = Angle(
*parameters["angle"]["idxs"][i],
*tup,
funct="1",
c0=parameters["angle"]["eq"][i],
c1=parameters["angle"]["k"][i],
c0=parameters.angle_eq[i],
c1=parameters.angle_k[i],
)

## proper dihedrals
for i, idx in enumerate(parameters["proper"]["idxs"]):
for i, idx in enumerate(parameters.propers):
tup = tuple(idx)
if not top.proper_dihedrals.get(tup):
# raise KeyError(f"bad index {tup} in {list(top.proper_dihedrals.keys())}")
Expand All @@ -175,12 +188,13 @@ def apply_parameters(top: Topology, parameters: dict):
)
continue
dihedral_dict = {}
# in the `range(len(...))` form the periodicity is the 1-based position within
# proper_ks[i], not a value read from the parameters
for ii, n in enumerate(parameters["proper"]["ns"][i]):
for ii in range(len(parameters.proper_ks[i])):
n = str(ii + 1)
dihedral_dict[n] = Dihedral(
*tup,
funct="9",
c0=parameters["proper"]["phases"][i][ii],
c1=parameters["proper"]["ks"][i][ii],
c0=parameters.proper_phases[i][ii],
c1=parameters.proper_ks[i][ii],
periodicity=n,
)
top.proper_dihedrals[tup] = MultipleDihedrals(
Expand All @@ -189,43 +203,68 @@ def apply_parameters(top: Topology, parameters: dict):

## improper dihedrals
# all existing improper dihedrals are discarded before the predicted ones are written
top.improper_dihedrals = {}
for i, idx in enumerate(parameters["improper"]["idxs"]):
for i, idx in enumerate(parameters.impropers):
tup = tuple(idx)
for ii, n in enumerate(parameters["improper"]["ns"][i]):
for ii in range(len(parameters.improper_ks[i])):
n = str(ii + 1)
# terms whose force constant is close to zero (math.isclose with abs_tol) are skipped
if not math.isclose(
float(parameters["improper"]["ks"][i][ii]), 0.0, abs_tol=1e-4
float(parameters.improper_ks[i][ii]), 0.0, abs_tol=1e-3
):
# NOTE(review): in the walrus form below, `:=` binds more loosely than `is`, so
# curr_improper receives the boolean result of `(...) is None` instead of the stored
# Dihedral; the else branch then reads `.c1` on a bool. Likely intended:
# `(curr_improper := top.improper_dihedrals.get(tup)) is None`.
if not top.improper_dihedrals.get(tup):
if curr_improper := (top.improper_dihedrals.get(tup)) is None:
top.improper_dihedrals[tup] = Dihedral(
*tup,
funct="4",
c0=parameters["improper"]["phases"][i][ii],
c1=parameters["improper"]["ks"][i][ii],
c0=parameters.improper_phases[i][ii],
c1=parameters.improper_ks[i][ii],
periodicity=n,
)
else:
new_improper = Dihedral(
*tup,
funct="4",
c0=parameters.improper_phases[i][ii],
c1=parameters.improper_ks[i][ii],
periodicity=n,
)
# NOTE(review): if the parameters were stringified by convert_parameters before
# reaching this point, `>` compares c1 lexicographically rather than numerically — confirm.
if new_improper.c1 > curr_improper.c1:
top.improper_dihedrals[tup] = new_improper
deserted_improper = curr_improper
else:
deserted_improper = new_improper

logger.warning(
f"There are multiple improper dihedrals for {tup} and only one can be chosen, dihedral {n} with amplitude of {parameters['proper']['ks'][i][ii]} will be ignored."
f"There are multiple improper dihedrals for {tup} and only one can be chosen, dihedral p{deserted_improper} will be ignored."
)

return


def load_model():
    """Load the grappa model from its release URL and return it."""
    # The release URL is provisional and is meant to be replaced by a more permanent tag.
    # Older model, kept for reference:
    # "https://github.com/LeifSeute/test_torchhub/releases/download/test_release_radicals/radical_model_12142023.pth"
    return model_from_url(
        "https://github.com/LeifSeute/test_torchhub/releases/download/model_release/grappa-1.0-01-26-2024.pth"
    )


class GrappaInterface(Parameterizer):
# Parameterizer implementation that predicts parameters with grappa and writes
# them into the given topology.
def parameterize_topology(
self, current_topology: Topology, focus_nr: list[str] = []
) -> Topology:
# Steps visible here: build a Molecule from the topology, load the model, predict
# parameters on CPU, convert them, apply them to the topology, and return the same
# topology object that was passed in.
# NOTE(review): focus_nr is not referenced in the visible body, and its default is a
# mutable list shared between calls — harmless while unused, but `None` would be safer.
# NOTE(review): this view interleaves removed-side and added-side diff lines without
# +/- markers; the generate_input / write_json / ForceField / clean_parameters lines
# presumably belong to the old implementation — confirm against the repository.
## get atoms, bonds, radicals in required format
input_dict = generate_input(current_topology)
write_json(input_dict, "in.json")
mol = build_molecule(current_topology)
logger.debug(mol.to_dict())

# NOTE(review): load_model() is called on every invocation of this method
model = load_model()

# initialize class that handles ML part
# NOTE(review): the local name `grappa` shadows the imported `grappa` package inside this method
grappa = Grappa(model, device="cpu")
parameters = grappa.predict(mol)

ff = grappa.ff.ForceField.from_tag("radical_latest")
ff.units["angle"] = openmm.unit.degree
# gromacs angle force constant are already in kJ/mol/rad-2]
parameters = ff.params_from_topology_dict(input_dict)
write_json(parameters, "out_raw.json")
parameters = clean_parameters(parameters)
write_json(parameters, "out_clean.json")
# convert units et cetera
parameters = convert_parameters(parameters)

# apply parameters
apply_parameters(current_topology, parameters)
return current_topology
1 change: 0 additions & 1 deletion tests/GrAPPa_input_alanine.json

This file was deleted.

Loading

0 comments on commit 5bc8ec2

Please sign in to comment.