Skip to content

Commit

Permalink
adding docstring, refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Dec 1, 2023
1 parent 25d85fe commit 267b1a8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 81 deletions.
22 changes: 19 additions & 3 deletions chiron/tests/test_mcmc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
def test_sample_from_harmonic_osciallator():
# use local moves to sample from the harmonic oscillator
"""
Test sampling from a harmonic oscillator using local moves.
This test initializes a harmonic oscillator from openmmtools.testsystems,
sets up a harmonic potential, and uses a Langevin integrator to sample
from the oscillator's state space.
"""
from openmm import unit

# initialize openmmtestsystem
Expand Down Expand Up @@ -41,7 +47,12 @@ def test_sample_from_harmonic_osciallator():


def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_LangevinDynamics():
# use local moves to sample from the HO, but use the MCMC classes
"""
Test sampling from a harmonic oscillator using MCMC classes and Langevin dynamics.
This test initializes a harmonic oscillator, sets up the thermodynamic and
sampler states, and uses the Langevin dynamics move in an MCMC sampling scheme.
"""
from openmm import unit
from chiron.potential import HarmonicOscillatorPotential
from chiron.mcmc import LangevinDynamicsMove, MoveSet, GibbsSampler
Expand Down Expand Up @@ -77,7 +88,12 @@ def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_LangevinDynamics


def test_sample_from_harmonic_osciallator_with_MCMC_classes_and_MetropolisDisplacementMove():
# use local moves to sample from the HO, but use the MCMC classes
"""
Test sampling from a harmonic oscillator using MCMC classes and Metropolis displacement move.
This test initializes a harmonic oscillator, sets up the thermodynamic and
sampler states, and uses the Metropolis displacement move in an MCMC sampling scheme.
"""
from openmm import unit
from chiron.potential import HarmonicOscillatorPotential
from chiron.mcmc import MetropolisDisplacementMove, MoveSet, GibbsSampler
Expand Down
113 changes: 35 additions & 78 deletions chiron/tests/test_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,29 @@

# Test NeuralNetworkPotential
def test_neural_network_pairlist():
# Create a topology object
"""
Test the pairlist computation for a NeuralNetworkPotential.
This function tests the compute_pairlist method for different cutoff distances
using a simple two-particle system and ethanol molecule.
"""

pdb_file = get_data_file_path("two_particles_1.pdb")
pdb = app.PDBFile(pdb_file)
topology = pdb.getTopology()
positions = pdb.getPositions(asNumpy=True).value_in_unit_system(unit.md_unit_system)

# Create a neural network potential object
nn_potential = NeuralNetworkPotential(model=None, topology=topology)

# Test compute_pairlist method
cutoff = 0.2
pairlist = nn_potential.compute_pairlist(positions, cutoff)
assert (
pairlist[0].size == 1 and pairlist[1].size == 1
) # there is one pair that is within the cutoff distance
# Test with different cutoffs
cutoffs = [0.2, 0.1]
expected_pairs = [(1, 1), (0, 0)]
for cutoff, expected in zip(cutoffs, expected_pairs):
pairlist = nn_potential.compute_pairlist(positions, cutoff)
assert pairlist[0].size == expected[0] and pairlist[1].size == expected[1]

# Test compute_pairlist method
cutoff = 0.1
pairlist = nn_potential.compute_pairlist(positions, cutoff)
assert (
pairlist[0].size == 0 and pairlist[1].size == 0
) # there is no pair that is within the cutoff distance

# try this with ethanol
# Test with ethanol molecule
pdb_file = get_data_file_path("ethanol.pdb")
pdb = app.PDBFile(pdb_file)
topology = pdb.getTopology()
Expand All @@ -53,76 +53,33 @@ def test_neural_network_pairlist():

# Test HarmonicOscillatorPotential
def test_harmonic_oscillator_potential():
# Create a harmonic oscillator potential object
"""
Test the energy computation of a HarmonicOscillatorPotential.
This function verifies the energy computed by a HarmonicOscillatorPotential for
various positions, comparing the computed energies with expected values.
"""
k = 100.0 * unit.kilocalories_per_mole / unit.angstroms**2
U0 = 0.0 * unit.kilocalories_per_mole
x0 = 0.0 * unit.angstrom

from openmmtools.testsystems import HarmonicOscillator as ho

harmonic_potential = HarmonicOscillatorPotential(ho.topology, k, x0, U0)
positions = jnp.array([0.0, 0.0, 0.0]) * unit.angstrom
# Test compute_energy method
positions_without_unit = jnp.array(
positions.value_in_unit_system(unit.md_unit_system)
)
energy = float(harmonic_potential.compute_energy(positions_without_unit))
assert jnp.isclose(energy, 0.0)

positions = jnp.array([0.2, 0.2, 0.2]) * unit.angstrom
# Test compute_energy method
positions_without_unit = jnp.array(
positions.value_in_unit_system(unit.md_unit_system)
)
energy = float(harmonic_potential.compute_energy(positions_without_unit))
assert jnp.isclose(energy, 8.368000984191895)

positions = jnp.array([0.2, 0.0, 0.0]) * unit.angstrom
# Test compute_energy method
positions_without_unit = jnp.array(
positions.value_in_unit_system(unit.md_unit_system)
)
energy = float(harmonic_potential.compute_energy(positions_without_unit))
assert jnp.isclose(energy, 8.368000984191895)

positions = jnp.array([-0.2, 0.0, 0.0]) * unit.angstrom
# Test compute_energy method
positions_without_unit = jnp.array(
positions.value_in_unit_system(unit.md_unit_system)
)
energy = float(harmonic_potential.compute_energy(positions_without_unit))
assert jnp.isclose(energy, 8.368000984191895)

positions = jnp.array([-0.0, 0.2, 0.0]) * unit.angstrom
# Test compute_energy method
positions_without_unit = jnp.array(
positions.value_in_unit_system(unit.md_unit_system)
)
energy = float(harmonic_potential.compute_energy(positions_without_unit))
assert jnp.isclose(energy, 0.0)
test_positions = [
jnp.array([0.0, 0.0, 0.0]) * unit.angstrom,
jnp.array([0.2, 0.2, 0.2]) * unit.angstrom,
jnp.array([0.2, 0.0, 0.0]) * unit.angstrom,
jnp.array([-0.2, 0.0, 0.0]) * unit.angstrom,
jnp.array([-0.0, 0.2, 0.0]) * unit.angstrom
]
expected_energies = [0.0, 8.368000984191895, 8.368000984191895, 8.368000984191895, 0.0]

for pos, expected_energy in zip(test_positions, expected_energies):
positions_without_unit = jnp.array(pos.value_in_unit_system(unit.md_unit_system))
energy = float(harmonic_potential.compute_energy(positions_without_unit))
assert jnp.isclose(energy, expected_energy), f"Energy at {pos} is incorrect."

# Test compute_force method
forces = harmonic_potential.compute_force(positions_without_unit)
assert forces.shape == positions_without_unit.shape


# # Test LJPotential
# def test_lj_potential():
# pdb_file = get_data_file_path("two_particles_1.pdb")
# pdb = app.PDBFile(pdb_file)
# topology = pdb.getTopology()
# positions = pdb.getPositions(asNumpy=True)

# # Create an LJ potential object
# sigma = 1.0 * unit.kilocalories_per_mole
# epsilon = 3.0 * unit.angstroms
# lj_potential = LJPotential(topology, sigma, epsilon)

# # Test compute_energy method
# positions = np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]) * unit.angstrom
# energy = lj_potential.compute_energy(positions)
# assert isinstance(energy, float)

# # Test compute_force method
# forces = lj_potential.compute_force(positions)
# assert forces.shape == positions.shape
assert forces.shape == positions_without_unit.shape, "Forces shape mismatch."

0 comments on commit 267b1a8

Please sign in to comment.