Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save covalent map to Potential object & make energy function generator #74

Merged
merged 9 commits into from
Dec 1, 2022
9 changes: 7 additions & 2 deletions dmff/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,13 @@ class Potential:
def __init__(self):
self.dmff_potentials = {}
self.omm_system = None
self.meta = {}

def addDmffPotential(self, name, potential):
def addDmffPotential(self, name, potential, meta={}):
self.dmff_potentials[name] = potential
if len(meta):
for key in meta.keys():
self.meta[key] = meta[key]

def addOmmSystem(self, system):
self.omm_system = system
Expand Down Expand Up @@ -183,7 +187,8 @@ def createPotential(self,
continue
try:
potentialImpl = generator.getJaxPotential()
potObj.addDmffPotential(generator.name, potentialImpl)
meta = generator.getMetaData()
potObj.addDmffPotential(generator.name, potentialImpl, meta=meta)
except Exception as e:
print(e)
pass
Expand Down
27 changes: 26 additions & 1 deletion dmff/generators/admp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, ff):
self.types = []
self.ethresh = 5e-4
self.pmax = 10
self._meta = {}

def extract(self):

Expand Down Expand Up @@ -160,6 +161,9 @@ def overwrite(self):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators['ADMPDispForce'] = ADMPDispGenerator
Expand All @@ -181,6 +185,7 @@ def __init__(self, ff):
self.ethresh = 5e-4
self.pmax = 10
self.name = "ADMPDispPmeForce"
self._meta = {}

def extract(self):

Expand Down Expand Up @@ -285,6 +290,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators['ADMPDispPmeForce'] = ADMPDispPmeGenerator
Expand All @@ -302,6 +310,7 @@ def __init__(self, ff):
self.paramtree = ff.paramtree
self._jaxPotnetial = None
self.name = "QqTtDampingForce"
self._meta = {}

def extract(self):
# get mscales
Expand Down Expand Up @@ -373,6 +382,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


# register all parsers
Expand All @@ -392,6 +404,7 @@ def __init__(self, ff):
self.fftree = ff.fftree
self.paramtree = ff.paramtree
self._jaxPotential = None
self._meta = {}

def extract(self):
# get mscales
Expand Down Expand Up @@ -474,6 +487,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators['SlaterDampingForce'] = SlaterDampingGenerator
Expand All @@ -490,6 +506,7 @@ def __init__(self, ff):
self.fftree = ff.fftree
self.paramtree = ff.paramtree
self._jaxPotential = None
self._meta = {}

def extract(self):
# get mscales
Expand Down Expand Up @@ -559,6 +576,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators["SlaterExForce"] = SlaterExGenerator
Expand Down Expand Up @@ -613,6 +633,8 @@ def __init__(self, ff):
self.lpol = False
self.ref_dip = ""

self._meta = {}

def extract(self):

self.lmax = self.fftree.get_attribs(f'{self.name}',
Expand Down Expand Up @@ -850,7 +872,7 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,

# build covalent map
self.covalent_map = covalent_map = build_covalent_map(data, 6)

self._meta["cov_map"] = self.covalent_map
# build intra-molecule axis
# the following code is the direct transplant of forcefield.py in openmm 7.4.0

Expand Down Expand Up @@ -1098,5 +1120,8 @@ def potential_fn(positions, box, pairs, params):
def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators["ADMPPmeForce"] = ADMPPmeGenerator
54 changes: 50 additions & 4 deletions dmff/generators/classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, ff: Hamiltonian):
self.ff: Hamiltonian = ff
self.fftree: ForcefieldTree = ff.fftree
self.paramtree: Dict = ff.paramtree
self._meta = {}

def extract(self):
"""
Expand Down Expand Up @@ -74,6 +75,7 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
Args:
Those args are the same as those in createSystem.
"""
self._meta = {}

# initialize typemap
matcher = TypeMatcher(self.fftree, "HarmonicBondForce/Bond")
Expand Down Expand Up @@ -114,7 +116,10 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):

map_atom1 = np.array(map_atom1, dtype=int)
map_atom2 = np.array(map_atom2, dtype=int)
map_param = np.array(map_param, dtype=int)
map_param = np.array(map_param, dtype=int)
self._meta["HarmonicBondForce_atom1"] = map_atom1
self._meta["HarmonicBondForce_atom2"] = map_atom2
self._meta["HarmonicBondForce_param"] = map_param

bforce = HarmonicBondJaxForce(map_atom1, map_atom2, map_param)
self._force_latest = bforce
Expand All @@ -129,6 +134,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators["HarmonicBondForce"] = HarmonicBondJaxGenerator
Expand All @@ -140,6 +148,7 @@ def __init__(self, ff):
self.ff = ff
self.fftree = ff.fftree
self.paramtree = ff.paramtree
self._meta = {}

def extract(self):
angles = self.fftree.get_attribs(f"{self.name}/Angle", "angle")
Expand All @@ -155,6 +164,8 @@ def overwrite(self):
self.paramtree[self.name]["k"])

def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
self._meta = {}

matcher = TypeMatcher(self.fftree, "HarmonicAngleForce/Angle")

map_atom1, map_atom2, map_atom3, map_param = [], [], [], []
Expand Down Expand Up @@ -202,6 +213,10 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
map_atom2 = np.array(map_atom2, dtype=int)
map_atom3 = np.array(map_atom3, dtype=int)
map_param = np.array(map_param, dtype=int)
self._meta["HarmonicAngleForce_atom1"] = map_atom1
self._meta["HarmonicAngleForce_atom2"] = map_atom2
self._meta["HarmonicAngleForce_atom3"] = map_atom3
self._meta["HarmonicAngleForce_param"] = map_param

aforce = HarmonicAngleJaxForce(map_atom1, map_atom2, map_atom3,
map_param)
Expand All @@ -217,6 +232,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators["HarmonicAngleForce"] = HarmonicAngleJaxGenerator
Expand All @@ -229,7 +247,7 @@ def __init__(self, ff):
self.fftree = ff.fftree
self.paramtree = ff.paramtree
self.meta = {}

self._meta = {}
self.meta["prop_order"] = defaultdict(list)
self.meta["prop_nodeidx"] = defaultdict(list)

Expand Down Expand Up @@ -340,6 +358,7 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
"""
Create force for torsions
"""

# Proper Torsions
proper_matcher = TypeMatcher(self.fftree,
"PeriodicTorsionForce/Proper")
Expand Down Expand Up @@ -487,6 +506,19 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
self._props_latest = props
self._imprs_latest = imprs

self._meta["PeriodicTorsionForce_prop_atom1"] = map_prop_atom1
self._meta["PeriodicTorsionForce_prop_atom2"] = map_prop_atom2
self._meta["PeriodicTorsionForce_prop_atom3"] = map_prop_atom3
self._meta["PeriodicTorsionForce_prop_atom4"] = map_prop_atom4
self._meta["PeriodicTorsionForce_prop_param"] = map_prop_param

self._meta["PeriodicTorsionForce_impr_atom1"] = map_impr_atom1
self._meta["PeriodicTorsionForce_impr_atom2"] = map_impr_atom2
self._meta["PeriodicTorsionForce_impr_atom3"] = map_impr_atom3
self._meta["PeriodicTorsionForce_impr_atom4"] = map_impr_atom4
self._meta["PeriodicTorsionForce_impr_param"] = map_impr_param


def potential_fn(positions, box, pairs, params):
prop_sum = sum([
props[i].get_energy(
Expand All @@ -509,6 +541,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators["PeriodicTorsionForce"] = PeriodicTorsionJaxGenerator
Expand All @@ -532,6 +567,8 @@ def __init__(self, ff: Hamiltonian):
self.useBCC = False
self.useVsite = False

self._meta = {}

def extract(self):
self.from_residue = self.fftree.get_attribs(
"NonbondedForce/UseAttributeFromResidue", "name")
Expand Down Expand Up @@ -684,6 +721,8 @@ def addVsiteFunc(pos, params):
cov_map[ori_dim + i, parent_i] = 1
self.covalent_map = jnp.array(cov_map)

self._meta["cov_map"] = self.covalent_map

# Load Lennard-Jones parameters
maps = {}
if not nbmatcher.useSmirks:
Expand Down Expand Up @@ -976,6 +1015,9 @@ def potential_fn(positions, box, pairs, params, vdwLambda,

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta

def getAddVsiteFunc(self):
"""
Expand Down Expand Up @@ -1008,8 +1050,8 @@ def __init__(self, ff):
self.fftree = ff.fftree
self.paramtree = ff.paramtree
self.paramtree[self.name] = {}
self.paramtree[self.name]
self.paramtree[self.name]
self._meta


def extract(self):
for prm in ["sigma", "epsilon"]:
Expand Down Expand Up @@ -1109,6 +1151,7 @@ def findIdx(labels, label):
map_nbfix = jnp.array(map_nbfix)

colv_map = build_covalent_map(data, 6)
self._meta["cov_map"] = colv_map

if unit.is_quantity(nonbondedCutoff):
r_cut = nonbondedCutoff.value_in_unit(unit.nanometer)
Expand Down Expand Up @@ -1189,6 +1232,9 @@ def potential_fn(positions, box, pairs, params):

def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


dmff.api.jaxGenerators["LennardJonesForce"] = LennardJonesGenerator
Loading