Skip to content

Commit

Permalink
Fix some bugs. Still in debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
dingye18 committed Mar 7, 2024
1 parent 25468e7 commit 511a721
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 29 deletions.
12 changes: 10 additions & 2 deletions dmff/admp/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,17 +1336,25 @@ def pme_self(Q_h, kappa, lmax=2):
2,
]
* 5
+ [
3,
]
* 7
)[:n_harms]
l_fac2 = np.array(
[1]
+ [
3,
3 * 1,
]
* 3
+ [
15,
5 * 3 * 1,
]
* 5
+ [
7 * 5 * 3 * 1,
]
* 7
)[:n_harms]
factor = kappa / np.sqrt(np.pi) * (2 * kappa**2) ** l_list / l_fac2
return -jnp.sum(factor[np.newaxis] * Q_h**2) * DIELECTRIC
Expand Down
55 changes: 28 additions & 27 deletions dmff/admp/recip.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,24 +327,25 @@ def theta3prime_eval(u, Nj_Aji_star, M_u, Mprime_u, M2prime_u, M3prime_u):
div_222 = M_u[:, 0] * M_u[:, 1] * M3prime_u[:, 2]

div = jnp.array([
[
[div_000, div_001, div_002],
[div_010, div_011, div_012],
[div_020, div_021, div_022]
],
[
[div_100, div_101, div_102],
[div_110, div_111, div_112],
[div_120, div_121, div_122]
],
[
[div_200, div_201, div_202],
[div_210, div_211, div_212],
[div_220, div_221, div_222]
]]
[
[div_000, div_001, div_002],
[div_010, div_011, div_012],
[div_020, div_021, div_022]
],
[
[div_100, div_101, div_102],
[div_110, div_111, div_112],
[div_120, div_121, div_122]
],
[
[div_200, div_201, div_202],
[div_210, div_211, div_212],
[div_220, div_221, div_222]
]
]
).swapaxes(0, 3).swapaxes(1, 2).swapaxes(2, 3)

return jnp.einsum("im,jn,ko,mn,op->kij", -Nj_Aji_star, -Nj_Aji_star, -Nj_Aji_star, div, -Nj_Aji_star)
return jnp.einsum("im,jn,ko,abcd->aijk", -Nj_Aji_star, -Nj_Aji_star, -Nj_Aji_star, div)


def sph_harmonics_GO(u0, Nj_Aji_star):
Expand Down Expand Up @@ -411,22 +412,22 @@ def sph_harmonics_GO(u0, Nj_Aji_star):

# Octupole
M3prime_u = bspline_prime3(u)
theta3prime_eval(u, Nj_Aji_star, M_u, Mprime_u, M2prime_u, M3prime_u)
theta3prime = theta3prime_eval(u, Nj_Aji_star, M_u, Mprime_u, M2prime_u, M3prime_u)
rt6 = jnp.sqrt(6)
rt15 = jnp.sqrt(15)
rt10 = jnp.sqrt(10)

harmonics_3 = jnp.hstack(
[harmonics_2,
jnp.stack([
(5 * theta3prime_eval[:, 2, 2, 2] - 3 * jnp.trace(theta3prime_eval[:, 2], axis1 = 0, axis2 = 1)) / 2,
rt6 * (5 * theta3prime_eval[:, 0, 2, 2] - jnp.trace(theta3prime_eval[:, 0], axis1 = 0, axis2 = 1)) / 4,
rt6 * (5 * theta3prime_eval[:, 1, 2, 2] - jnp.trace(theta3prime_eval[:, 1], axis1 = 0, axis2 = 1)) / 4,
rt15 * (theta3prime_eval[:, 2, 0, 0] - theta3prime_eval[:, 2, 1,1]) / 2,
rt15 * theta3prime_eval[:, 0, 1, 2],
rt10 * (theta3prime_eval[:, 0, 0, 0] - 3 * theta3prime_eval[:, 0, 1, 1]) / 4,
rt10 * (3 * theta3prime_eval[:, 0, 0, 1] - theta3prime_eval[:, 1, 1, 1]) / 4
])
(5 * theta3prime[:, 2, 2, 2] - 3 * jnp.trace(theta3prime[:, 2], axis1 = 1, axis2 = 2)) / 2,
rt6 * (5 * theta3prime[:, 0, 2, 2] - jnp.trace(theta3prime[:, 0], axis1 = 1, axis2 = 2)) / 4,
rt6 * (5 * theta3prime[:, 1, 2, 2] - jnp.trace(theta3prime[:, 1], axis1 = 1, axis2 = 2)) / 4,
rt15 * (theta3prime[:, 2, 0, 0] - theta3prime[:, 2, 1,1]) / 2,
rt15 * theta3prime[:, 0, 1, 2],
rt10 * (theta3prime[:, 0, 0, 0] - 3 * theta3prime[:, 0, 1, 1]) / 4,
rt10 * (3 * theta3prime[:, 0, 0, 1] - theta3prime[:, 1, 1, 1]) / 4
], axis=1)
]
)

Expand All @@ -453,8 +454,8 @@ def Q_m_peratom(Q, sph_harms):

N_a = sph_harms.shape[0]

if lmax > 2:
raise NotImplementedError('l > 2 (beyond quadrupole) not supported')
if lmax > 3:
raise NotImplementedError('l > 3 (beyond octupole) not supported')

Q_dbf = Q[:, 0:1]

Expand Down
63 changes: 63 additions & 0 deletions tests/test_admp/test_compute_octupole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import openmm.app as app
import openmm.unit as unit
import numpy as np
import jax.numpy as jnp
import numpy.testing as npt
import pytest
from dmff import Hamiltonian, NeighborList
from jax import jit, value_and_grad

class TestADMPAPI:

""" Test ADMP related generators
"""

@pytest.fixture(scope='class', name='pot_prm')
def test_init(self):
"""load generators from XML file
Yields:
Tuple: (
ADMPDispForce,
ADMPPmeForce, # polarized
)
"""
rc = 4.0
H = Hamiltonian('tests/data/admp.xml')
H1 = Hamiltonian('tests/data/admp_mono.xml')
H2 = Hamiltonian('tests/data/admp_nonpol.xml')
H3 = Hamiltonian('tests/data/admp_octupole.xml')
pdb = app.PDBFile('tests/data/water_dimer.pdb')
potential = H.createPotential(pdb.topology, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
potential_aux = H.createPotential(pdb.topology, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5, has_aux=True)
potential1 = H1.createPotential(pdb.topology, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
potential2 = H2.createPotential(pdb.topology, nonbondedMethod=app.CutoffPeriodic, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5)
potential3 = H3.createPotential(pdb.topology, nonbondedMethod=app.PME, nonbondedCutoff=rc*unit.angstrom, ethresh=5e-4, step_pol=5, has_aux=True)

yield potential, potential_aux, potential1, potential2, potential3, H.paramset, H1.paramset, H2.paramset, H3.paramset

def test_ADMPPmeForce_octupole(self, pot_prm):
potential, potential_aux, potential1, potential2, potential3, paramset, paramset1, paramset2, paramset3 = pot_prm
rc = 0.4
pdb = app.PDBFile('tests/data/water_dimer.pdb')
positions = pdb.getPositions(asNumpy=True).value_in_unit(unit.nanometer)
positions = jnp.array(positions)
a, b, c = pdb.topology.getPeriodicBoxVectors().value_in_unit(unit.nanometer)
box = jnp.array([a, b, c])
# neighbor list

covalent_map = potential3.meta["cov_map"]

nblist = NeighborList(box, rc, covalent_map)
nblist.allocate(positions)
pairs = nblist.pairs
pot = potential3.getPotentialFunc(names=["ADMPPmeForce"])

aux = dict()
U_ind = jnp.zeros((6, 3))
aux["U_ind"] = U_ind

energy_and_aux = pot(positions, box, pairs, paramset3, aux)
energy = energy_and_aux[0]
print("Octupole Included Energy: ", energy)
np.testing.assert_almost_equal(energy, -36.32748562120901, decimal=1)

0 comments on commit 511a721

Please sign in to comment.