From b92c89f71a61a9b4539b9cbe7a9e00c29cff7ee1 Mon Sep 17 00:00:00 2001 From: wiederm Date: Tue, 9 Jan 2024 12:10:52 +0100 Subject: [PATCH] Refactor multistate sampler setup and add test cases --- chiron/tests/test_multistate.py | 74 +++++++++++++++++++++++++-------- 1 file changed, 57 insertions(+), 17 deletions(-) diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index fc3e13f..d38f357 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -1,15 +1,20 @@ from chiron.multistate import MultiStateSampler +from chiron.neighbors import NeighborListNsqrd import pytest +from typing import Tuple -def setup_sampler(): +def setup_sampler() -> Tuple[NeighborListNsqrd, MultiStateSampler]: + """ + Set up the neighbor list and multistate sampler for the simulation. + + Returns: + Tuple: A tuple containing the neighbor list and multistate sampler objects. + """ from openmm import unit from chiron.mcmc import LangevinDynamicsMove from chiron.neighbors import NeighborListNsqrd, OrthogonalPeriodicSpace - # Initialize simulation object with options. Run with a langevin integrator. - # initialize the LennardJones potential in chiron - # sigma = 0.34 * unit.nanometer cutoff = 3.0 * sigma skin = 0.5 * unit.nanometer @@ -18,7 +23,7 @@ def setup_sampler(): OrthogonalPeriodicSpace(), cutoff=cutoff, skin=skin, n_max_neighbors=180 ) - move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=50) + move = LangevinDynamicsMove(stepsize=2.0 * unit.femtoseconds, nr_of_steps=500) multistate_sampler = MultiStateSampler(mcmc_moves=move) return nbr_list, multistate_sampler @@ -27,7 +32,7 @@ def setup_sampler(): @pytest.fixture def ho_multistate_sampler_multiple_minima() -> MultiStateSampler: """ - Create a multi-state sampler for a harmonic oscillator system. + Create a multi-state sampler for multiple harmonic oscillators with different minimum values. Returns: MultiStateSampler: The multi-state sampler object. @@ -68,10 +73,11 @@ def ho_multistate_sampler_multiple_minima() -> MultiStateSampler: @pytest.fixture def ho_multistate_sampler_multiple_ks() -> MultiStateSampler: """ - Create a multi-state sampler for a harmonic oscillator system. - - Returns: - MultiStateSampler: The multi-state sampler object. + Create a multi-state sampler for a harmonic oscillator system with different spring constants. + Returns + ------- + MultiStateSampler + The multi-state sampler object. """ from openmm import unit from chiron.states import ThermodynamicState, SamplerState @@ -120,22 +126,37 @@ def ho_multistate_sampler_multiple_ks() -> MultiStateSampler: def test_multistate_class(ho_multistate_sampler_multiple_minima: MultiStateSampler): - # test the multistate_sampler object + """ + Test initialization for the MultiStateSampler class. + + Parameters: + ------- + ho_multistate_sampler_multiple_minima: MultiStateSampler + An instance of the MultiStateSampler class. + Raises: + ------- + AssertionError: + If any of the assertions fail. + + """ assert ho_multistate_sampler_multiple_minima._iteration == 0 assert ho_multistate_sampler_multiple_minima.n_replicas == 3 assert ho_multistate_sampler_multiple_minima.n_states == 3 - assert ho_multistate_sampler_multiple_minima._energy_thermodynamic_states.shape == (3, 3) + assert ho_multistate_sampler_multiple_minima._energy_thermodynamic_states.shape == ( + 3, + 3, + ) assert ho_multistate_sampler_multiple_minima._n_proposed_matrix.shape == (3, 3) def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSampler): """ Test function for the `minimize` method of the `ho_multistate_sampler` object. - It checks if the sampler states are correctly minimized. + Check if the sampler states are correctly minimized. Parameters ---------- - ho_multistate_sampler: The `ho_multistate_sampler` object to be tested. + ho_multistate_sampler: MultiStateSampler """ import numpy as np @@ -143,7 +164,8 @@ def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSa ho_multistate_sampler_multiple_minima.minimize() assert np.allclose( - ho_multistate_sampler_multiple_minima.sampler_states[0].x0, np.array([[0.0, 0.0, 0.0]]) + ho_multistate_sampler_multiple_minima.sampler_states[0].x0, + np.array([[0.0, 0.0, 0.0]]), ) assert np.allclose( ho_multistate_sampler_multiple_minima.sampler_states[1].x0, @@ -158,10 +180,25 @@ def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSa def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): + """ + Test function for running the multistate sampler. + + Parameters + ---------- + ho_multistate_sampler_multiple_ks: MultiStateSampler + The multistate sampler object. + Raises + ------- + AssertionError: If free energy does not converge to the analytical free energy difference. + + """ + ho_sampler = ho_multistate_sampler_multiple_ks import numpy as np - n_iteratinos = 100 + print(f"Analytical free energy difference: {ho_sampler.delta_f_ij_analytical[0]}") + + n_iteratinos = 25 ho_sampler.run(n_iteratinos) # check that we have the correct number of iterations, replicas and states @@ -174,4 +211,7 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): print(ho_sampler.analytical_f_i) print(ho_sampler.delta_f_ij_analytical) print(ho_sampler._last_mbar_f_k_offline) - a = 7 + + assert np.isclose( + ho_sampler.delta_f_ij_analytical[0], ho_sampler._last_mbar_f_k_offline + )