Skip to content

Commit

Permalink
Fix bug when rendering xml files
Browse files Browse the repository at this point in the history
  • Loading branch information
WangXinyan940 committed Nov 9, 2023
1 parent 2316e28 commit 4cb4de2
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 62 deletions.
5 changes: 3 additions & 2 deletions dmff/common/nblist.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _do_cov_map(self, pairs):
pairs = jnp.concatenate([pairs, nbond[:, None]], axis=1)
return pairs

def allocate(self, coords):
def allocate(self, coords, box=None):
self._positions = coords # cache it
natoms = coords.shape[0]
nblist = np.fromiter(permutations(range(natoms), 2), dtype=np.dtype(int, 2))
Expand All @@ -111,7 +111,7 @@ def allocate(self, coords):
raise ValueError("padding width < 0")

def update(self, positions, box=None):
self.allocate(positions, box)
self.allocate(positions)

@property
def pairs(self):
Expand All @@ -130,6 +130,7 @@ class NoPeriodicNeighborList(NoCutoffNeighborList):

def __init__(self, rcut, cov_map, padding=True):
super().__init__(cov_map, padding)
self.rcut = rcut

def allocate(self, coords):
self._positions = coords # cache it
Expand Down
88 changes: 36 additions & 52 deletions dmff/generators/admp.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def getName(self):
return self.name

def overwrite(self, paramset):
atom_mask = paramset.mask[self.name]["sigma"]
atom_mask = paramset.mask[self.name]["A"]
A = paramset[self.name]["A"]
B = paramset[self.name]["B"]
Q = paramset[self.name]["Q"]
Expand Down Expand Up @@ -452,10 +452,10 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
def getName(self) -> str:
return self.name

def overwrite(self):
B = self.paramtree[self.name]["B"]
Q = self.paramtree[self.name]["Q"]
atom_mask = self.paramtree.mask[self.name]["B"]
def overwrite(self, paramset):
B = paramset[self.name]["B"]
Q = paramset[self.name]["Q"]
atom_mask = paramset.mask[self.name]["B"]

nnode = 0
for node in self.ffinfo["Forces"][self.name]["node"]:
Expand Down Expand Up @@ -520,8 +520,6 @@ def potential_fn(positions, box, pairs, params, aux = None):
def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


# register all parsers
Expand Down Expand Up @@ -589,12 +587,12 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
def getName(self) -> str:
return self.name

def overwrite(self):
B = self.paramtree[self.name]["B"]
C6 = self.paramtree[self.name]["C6"]
C8 = self.paramtree[self.name]["C8"]
C10 = self.paramtree[self.name]["C10"]
atom_mask = self.paramtree.mask[self.name]["B"]
def overwrite(self, paramset):
B = paramset[self.name]["B"]
C6 = paramset[self.name]["C6"]
C8 = paramset[self.name]["C8"]
C10 = paramset[self.name]["C10"]
atom_mask = paramset.mask[self.name]["B"]

nnode = 0
for node in self.ffinfo["Forces"][self.name]["node"]:
Expand All @@ -605,18 +603,10 @@ def overwrite(self):
C8_new = C8[nnode]
C10_new = C10[nnode]
mask = atom_mask[nnode]
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["B"] = str(
B_new
)
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["C6"] = str(
C6_new
)
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["C8"] = str(
C8_new
)
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["C10"] = str(
C10_new
)
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["B"] = B_new
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["C6"] = C6_new
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["C8"] = C8_new
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["C10"] = C10_new
if mask < 0.999:
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"][
"mask"
Expand Down Expand Up @@ -736,10 +726,10 @@ def __init__(self, ffinfo: dict, paramset: ParamSet, default_name=None):
def getName(self) -> str:
return self.name

def overwrite(self):
A = self.paramtree[self.name]["A"]
B = self.paramtree[self.name]["B"]
atom_mask = self.paramtree.mask[self.name]["B"]
def overwrite(self, paramset):
A = paramset[self.name]["A"]
B = paramset[self.name]["B"]
atom_mask = paramset.mask[self.name]["B"]

nnode = 0
for node in self.ffinfo["Forces"][self.name]["node"]:
Expand All @@ -748,12 +738,8 @@ def overwrite(self):
A_new = A[nnode]
B_new = B[nnode]
mask = atom_mask[nnode]
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["A"] = str(
A_new
)
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["B"] = str(
B_new
)
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["A"] = A_new
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"]["B"] = B_new
if mask < 0.999:
self.ffinfo["Forces"][self.name]["node"][nnode]["attrib"][
"mask"
Expand Down Expand Up @@ -799,8 +785,6 @@ def potential_fn(positions, box, pairs, params, aux=None):
def getJaxPotential(self):
return self._jaxPotential

def getMetaData(self):
return self._meta


_DMFFGenerators["SlaterExForce"] = SlaterExGenerator
Expand Down Expand Up @@ -1094,28 +1078,28 @@ def overwrite(self, paramset):
for nnode in range(len(self.ffinfo["Forces"][self.name]["node"])):
node = self.ffinfo["Forces"][self.name]["node"][nnode]
if node["name"] in ["Atom", "Multipole"]:
node["c0"] = Q_global[n_multipole, 0]
node["attrib"]["c0"] = Q_global[n_multipole, 0]
if self.lmax >= 1:
node["dX"] = Q_global[n_multipole, 1] * 0.1
node["dY"] = Q_global[n_multipole, 2] * 0.1
node["dZ"] = Q_global[n_multipole, 3] * 0.1
node["attrib"]["dX"] = Q_global[n_multipole, 1] * 0.1
node["attrib"]["dY"] = Q_global[n_multipole, 2] * 0.1
node["attrib"]["dZ"] = Q_global[n_multipole, 3] * 0.1
if self.lmax >= 2:
node["qXX"] = Q_global[n_multipole, 4] / 300.0
node["qYY"] = Q_global[n_multipole, 5] / 300.0
node["qZZ"] = Q_global[n_multipole, 6] / 300.0
node["qXY"] = Q_global[n_multipole, 7] / 300.0
node["qXZ"] = Q_global[n_multipole, 8] / 300.0
node["qYZ"] = Q_global[n_multipole, 9] / 300.0
node["attrib"]["qXX"] = Q_global[n_multipole, 4] / 300.0
node["attrib"]["qYY"] = Q_global[n_multipole, 5] / 300.0
node["attrib"]["qZZ"] = Q_global[n_multipole, 6] / 300.0
node["attrib"]["qXY"] = Q_global[n_multipole, 7] / 300.0
node["attrib"]["qXZ"] = Q_global[n_multipole, 8] / 300.0
node["attrib"]["qYZ"] = Q_global[n_multipole, 9] / 300.0
if q_local_masks[n_multipole] < 0.999:
node["mask"] = "true"
n_multipole += 1
elif node["name"] == "Polarize":
node["polarizabilityXX"] = paramset[self.name]["pol"][n_pol] * 0.001
node["polarizabilityYY"] = paramset[self.name]["pol"][n_pol] * 0.001
node["polarizabilityZZ"] = paramset[self.name]["pol"][n_pol] * 0.001
node["thole"] = paramset[self.name]["thole"][n_pol]
node["attrib"]["polarizabilityXX"] = paramset[self.name]["pol"][n_pol] * 0.001
node["attrib"]["polarizabilityYY"] = paramset[self.name]["pol"][n_pol] * 0.001
node["attrib"]["polarizabilityZZ"] = paramset[self.name]["pol"][n_pol] * 0.001
node["attrib"]["thole"] = paramset[self.name]["thole"][n_pol]
if polar_masks[n_pol] < 0.999:
node["mask"] = "true"
node["attrib"]["mask"] = "true"
n_pol += 1

def _find_multipole_key_index(self, atype: str):
Expand Down
2 changes: 1 addition & 1 deletion dmff/generators/qeq.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def overwrite(self, paramset: ParamSet) -> None:
chi = paramset[self.name]["chi"]
J = paramset[self.name]["J"]
eta = paramset[self.name]["eta"]
atom_mask = paramset[self.name]["mask"]
atom_mask = paramset.mask[self.name]["chi"]
for nidx, idx in enumerate(node_indices):
chi0 = chi[nidx]
J0 = J[nidx]
Expand Down
14 changes: 7 additions & 7 deletions tests/data/qeq2.xml
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,12 @@
</Residue>
</Residues>
<ADMPQeqForce coulomb14scale="1.0" DampMod="3" >
<QeqAtom type="1" chi="1.39198761e+02" J="-6.51683561e-01" eta="0.08360402"/>
<QeqAtom type="2" chi="1.47411181e+02" J="-3.42450470e-02" eta="0.08360402"/>
<QeqAtom type="3" chi="7.58342711e+01" J="-5.62083783e-01" eta="0.08360402"/>
<QeqAtom type="4" chi="1.40533063e+02" J="-6.33544439e-01" eta="0.08360402"/>
<QeqAtom type="5" chi="1.37525951e+02" J="-6.63409032e-01" eta="0.08360402"/>
<QeqAtom type="6" chi="-1.157e+03" J="-3.873e-01" eta="0.0" />
<QeqAtom type="7" chi="1.365e+03" J="-5.378e-01" eta="0.0" />
<Atom type="1" chi="1.39198761e+02" J="-6.51683561e-01" eta="0.08360402"/>
<Atom type="2" chi="1.47411181e+02" J="-3.42450470e-02" eta="0.08360402"/>
<Atom type="3" chi="7.58342711e+01" J="-5.62083783e-01" eta="0.08360402"/>
<Atom type="4" chi="1.40533063e+02" J="-6.33544439e-01" eta="0.08360402"/>
<Atom type="5" chi="1.37525951e+02" J="-6.63409032e-01" eta="0.08360402"/>
<Atom type="6" chi="-1.157e+03" J="-3.873e-01" eta="0.0" />
<Atom type="7" chi="1.365e+03" J="-5.378e-01" eta="0.0" />
</ADMPQeqForce>
</ForceField>

0 comments on commit 4cb4de2

Please sign in to comment.