Skip to content

Commit

Permalink
Refactor multistate sampler setup and add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Jan 9, 2024
1 parent d0252af commit b92c89f
Showing 1 changed file with 57 additions and 17 deletions.
74 changes: 57 additions & 17 deletions chiron/tests/test_multistate.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from chiron.multistate import MultiStateSampler
from chiron.neighbors import NeighborListNsqrd
import pytest
from typing import Tuple


def setup_sampler():
def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]:
"""
Set up the neighbor list and multistate sampler for the simulation.
Returns:
Tuple: A tuple containing the neighbor list and multistate sampler objects.
"""
from openmm import unit
from chiron.mcmc import LangevinDynamicsMove
from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace

# Initialize simulation object with options. Run with a langevin integrator.
# initialize the LennardJones potential in chiron
#
sigma = 0.34 * unit.nanometer
cutoff = 3.0 * sigma
skin = 0.5 * unit.nanometer
Expand All @@ -18,7 +23,7 @@ def setup_sampler():
OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180
)

move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=50)
move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=500)

multistate_sampler = MultiStateSampler(mcmc_moves=move)
return nbr_list, multistate_sampler
Expand All @@ -27,7 +32,7 @@ def setup_sampler():
@pytest.fixture
def ho_multistate_sampler_multiple_minima() -> MultiStateSampler:
"""
Create a multi-state sampler for a harmonic oscillator system.
Create a multi-state sampler for multiple harmonic oscillators with different minimum values.
Returns:
MultiStateSampler: The multi-state sampler object.
Expand Down Expand Up @@ -68,10 +73,11 @@ def ho_multistate_sampler_multiple_minima() -> MultiStateSampler:
@pytest.fixture
def ho_multistate_sampler_multiple_ks() -> MultiStateSampler:
"""
Create a multi-state sampler for a harmonic oscillator system.
Returns:
MultiStateSampler: The multi-state sampler object.
Create a multi-state sampler for a harmonic oscillator system with different spring constants.
Returns
-------
MultiStateSampler
The multi-state sampler object.
"""
from openmm import unit
from chiron.states import ThermodynamicState, SamplerState
Expand Down Expand Up @@ -120,30 +126,46 @@ def ho_multistate_sampler_multiple_ks() -> MultiStateSampler:


def test_multistate_class(ho_multistate_sampler_multiple_minima: MultiStateSampler):
# test the multistate_sampler object
"""
Test initialization for the MultiStateSampler class.
Parameters:
-------
ho_multistate_sampler_multiple_minima: MultiStateSampler
An instance of the MultiStateSampler class.
Raises:
-------
AssertionError:
If any of the assertions fail.
"""
assert ho_multistate_sampler_multiple_minima._iteration == 0
assert ho_multistate_sampler_multiple_minima.n_replicas == 3
assert ho_multistate_sampler_multiple_minima.n_states == 3
assert ho_multistate_sampler_multiple_minima._energy_thermodynamic_states.shape == (3, 3)
assert ho_multistate_sampler_multiple_minima._energy_thermodynamic_states.shape == (
3,
3,
)
assert ho_multistate_sampler_multiple_minima._n_proposed_matrix.shape == (3, 3)


def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSampler):
"""
Test function for the `minimize` method of the `ho_multistate_sampler` object.
It checks if the sampler states are correctly minimized.
Check if the sampler states are correctly minimized.
Parameters
----------
ho_multistate_sampler: The `ho_multistate_sampler` object to be tested.
ho_multistate_sampler: MultiStateSampler
"""

import numpy as np

ho_multistate_sampler_multiple_minima.minimize()

assert np.allclose(
ho_multistate_sampler_multiple_minima.sampler_states[0].x0, np.array([[0.0, 0.0, 0.0]])
ho_multistate_sampler_multiple_minima.sampler_states[0].x0,
np.array([[0.0, 0.0, 0.0]]),
)
assert np.allclose(
ho_multistate_sampler_multiple_minima.sampler_states[1].x0,
Expand All @@ -158,10 +180,25 @@ def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSa


def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler):
"""
Test function for running the multistate sampler.
Parameters
----------
ho_multistate_sampler_multiple_ks: MultiStateSampler
The multistate sampler object.
Raises
-------
AssertionError: If free energy does not converge to the analytical free energy difference.
"""

ho_sampler = ho_multistate_sampler_multiple_ks
import numpy as np

n_iteratinos = 100
print(f"Analytical free energy difference: {ho_sampler.delta_f_ij_analytical[0]}")

n_iteratinos = 25
ho_sampler.run(n_iteratinos)

# check that we have the correct number of iterations, replicas and states
Expand All @@ -174,4 +211,7 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler):
print(ho_sampler.analytical_f_i)
print(ho_sampler.delta_f_ij_analytical)
print(ho_sampler._last_mbar_f_k_offline)
a = 7

assert np.isclose(
ho_sampler.delta_f_ij_analytical[0], ho_sampler._last_mbar_f_k_offline
)

0 comments on commit b92c89f

Please sign in to comment.