Skip to content

Commit

Permalink
Fix bug and add unit test for octupole implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
dingye18 committed Mar 7, 2024
1 parent 31e6e36 commit e98db17
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 202 deletions.
14 changes: 8 additions & 6 deletions dmff/admp/recip.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,8 @@ def sph_harmonics_GO(u0, Nj_Aji_star):
rt3/2 * (theta2prime[:, 0, 0] - theta2prime[:, 1, 1]),
rt3 * theta2prime[:, 0, 1]], axis = 1)]
)
#if lmax == 2:
# return harmonics_2.reshape(N_a, n_mesh, n_harm)
if lmax == 2:
return harmonics_2.reshape(N_a, n_mesh, n_harm)
#else:
# raise NotImplementedError('l > 2 (beyond quadrupole) not supported')

Expand All @@ -419,9 +419,9 @@ def sph_harmonics_GO(u0, Nj_Aji_star):
harmonics_3 = jnp.hstack(
[harmonics_2,
jnp.stack([
(5 * theta3prime_eval[:, 2, 2, 2] - 3 * jnp.trace(theta3prime_eval[:, 2], axis = 0, axis = 1)) / 2,
rt6 * (5 * theta3prime_eval[:, 0, 2, 2] - jnp.trace(theta3prime_eval[:, 0], axis = 0, axis = 1)) / 4,
rt6 * (5 * theta3prime_eval[:, 1, 2, 2] - jnp.trace(theta3prime_eval[:, 1], axis = 0, axis = 1)) / 4,
(5 * theta3prime_eval[:, 2, 2, 2] - 3 * jnp.trace(theta3prime_eval[:, 2], axis1 = 0, axis2 = 1)) / 2,
rt6 * (5 * theta3prime_eval[:, 0, 2, 2] - jnp.trace(theta3prime_eval[:, 0], axis1 = 0, axis2 = 1)) / 4,
rt6 * (5 * theta3prime_eval[:, 1, 2, 2] - jnp.trace(theta3prime_eval[:, 1], axis1 = 0, axis2 = 1)) / 4,
rt15 * (theta3prime_eval[:, 2, 0, 0] - theta3prime_eval[:, 2, 1,1]) / 2,
rt15 * theta3prime_eval[:, 0, 1, 2],
rt10 * (theta3prime_eval[:, 0, 0, 0] - 3 * theta3prime_eval[:, 0, 1, 1]) / 4,
Expand All @@ -430,7 +430,7 @@ def sph_harmonics_GO(u0, Nj_Aji_star):
]
)

if lmax == 2:
if lmax == 3:
return harmonics_3.reshape(N_a, n_mesh, n_harm)


Expand Down Expand Up @@ -462,6 +462,8 @@ def Q_m_peratom(Q, sph_harms):
Q_dbf = jnp.hstack([Q_dbf, Q[:,1:4]])
if lmax >= 2:
Q_dbf = jnp.hstack([Q_dbf, Q[:,4:9]/3])
if lmax >= 3:
Q_dbf = jnp.hstack([Q_dbf, Q[:,9:16]/15])

Q_m_pera = jnp.sum(Q_dbf[:,jnp.newaxis,:]* sph_harms, axis=2)

Expand Down
2 changes: 2 additions & 0 deletions dmff/generators/admp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,8 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
n_mtps = 4
elif self.lmax == 2:
n_mtps = 10
elif self.lmax == 3:
n_mtps = 20
Q = np.zeros((n_atoms, n_mtps))

# TDDO: unit conversion
Expand Down
2 changes: 1 addition & 1 deletion tests/data/admp_octupole.xml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
<Atom type="381" A="83.2283563" B="37.78544799" Q="0.370853" C6="5.7929e-05" C8="1.416624e-06" C10="2.26525e-08"/>
</ADMPDispForce>

<ADMPPmeForce lmax="2" mScale12="0.00" mScale13="0.00" mScale14="0.00" mScale15="1.00" mScale16="1.00" pScale12="0.00" pScale13="0.00" pScale14="0.00" pScale15="1.00" pScale16="1.00" dScale12="0.00" dScale13="0.00" dScale14="0.00" dScale15="1.00" dScale16="1.00">
<ADMPPmeForce lmax="3" mScale12="0.00" mScale13="0.00" mScale14="0.00" mScale15="1.00" mScale16="1.00" pScale12="0.00" pScale13="0.00" pScale14="0.00" pScale15="1.00" pScale16="1.00" dScale12="0.00" dScale13="0.00" dScale14="0.00" dScale15="1.00" dScale16="1.00">
<Atom type="380" kz="-381" kx="-381"
c0="-1.0614"
dX="0.0" dY="0.0" dZ="-0.023671684"
Expand Down
40 changes: 34 additions & 6 deletions tests/test_admp/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,18 @@ def test_init(self):
H = Hamiltonian('tests/data/admp.xml')
H1 = Hamiltonian('tests/data/admp_mono.xml')
H2 = Hamiltonian('tests/data/admp_nonpol.xml')
H3 = Hamiltonian('tests/data/admp_octupole.xml')
pdb = app.PDBFile('tests/data/water_dimer.pdb')
potential = H.createPotential(pdb.topology, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
potential_aux = H.createPotential(pdb.topology, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5, has_aux=True)
potential1 = H1.createPotential(pdb.topology, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
potential2 = H2.createPotential(pdb.topology, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
potential3 = H3.createPotential(pdb.topology, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5, has_aux=True)

yield potential, potential_aux, potential1, potential2, H.paramset, H1.paramset, H2.paramset
yield potential, potential_aux, potential1, potential2, potential3, H.paramset, H1.paramset, H2.paramset, H3.paramset

def test_ADMPPmeForce(self, pot_prm):
potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
potential, potential_aux, potential1, potential2, potential3, paramset, paramset1, paramset2, paramset3 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
Expand All @@ -56,7 +58,7 @@ def test_ADMPPmeForce(self, pot_prm):


def test_ADMPPmeForce_jit(self, pot_prm):
potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
potential, potential_aux, potential1, potential2, potential3, paramset, paramset1, paramset2, paramset3 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
Expand All @@ -76,7 +78,7 @@ def test_ADMPPmeForce_jit(self, pot_prm):
np.testing.assert_almost_equal(energy, -35.71585296268245, decimal=1)

def test_ADMPPmeForce_aux(self, pot_prm):
potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
potential, potential_aux, potential1, potential2, potential3, paramset, paramset1, paramset2, paramset3 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
Expand Down Expand Up @@ -104,7 +106,7 @@ def test_ADMPPmeForce_aux(self, pot_prm):


def test_ADMPPmeForce_mono(self, pot_prm):
potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
potential, potential_aux, potential1, potential2, potential3, paramset, paramset1, paramset2, paramset3 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
Expand All @@ -125,7 +127,7 @@ def test_ADMPPmeForce_mono(self, pot_prm):


def test_ADMPPmeForce_nonpol(self, pot_prm):
potential, potential_aux, potential1, potential2, paramset, paramset1, paramset2 = pot_prm
potential, potential_aux, potential1, potential2, potential3, paramset, paramset1, paramset2, paramset3 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
Expand All @@ -143,3 +145,29 @@ def test_ADMPPmeForce_nonpol(self, pot_prm):
energy = pot(positions, box, pairs, paramset2)
print(energy)
np.testing.assert_almost_equal(energy, -31.65932348, decimal=2)

def test_ADMPPmeForce_octupole(self, pot_prm):
potential, potential_aux, potential1, potential2, potential3, paramset, paramset1, paramset2, paramset3 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
positions = jnp.array(positions)
a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer)
box = jnp.array([a, b, c])
# neighbor list

covalent_map = potential3.meta["cov_map"]

nblist = NeighborList(box, rc, covalent_map)
nblist.allocate(positions)
pairs = nblist.pairs
pot = potential3.getPotentialFunc(names=["ADMPPmeForce"])

aux = dict()
U_ind = jnp.zeros((6, 3))
aux["U_ind"] = U_ind

energy_and_aux = pot(positions, box, pairs, paramset3, aux)
energy = energy_and_aux[0]
print("Octupole Included Energy: ", energy)
np.testing.assert_almost_equal(energy, -36.32748562120901, decimal=1)
189 changes: 0 additions & 189 deletions tests/test_admp/test_octopole.py

This file was deleted.

0 comments on commit e98db17

Please sign in to comment.