From e5bf78b68d86f2157d883f99f61944b76d4ba0a0 Mon Sep 17 00:00:00 2001 From: chrisiacovella Date: Thu, 29 Feb 2024 20:19:45 -0800 Subject: [PATCH] Fixed multistate test of harmonic oscillator array. --- chiron/states.py | 11 ++++++++--- chiron/tests/test_multistate.py | 16 ++++++++++------ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/chiron/states.py b/chiron/states.py index 6ebdca1..805d02f 100644 --- a/chiron/states.py +++ b/chiron/states.py @@ -329,10 +329,13 @@ def kT_to_kJ_per_mol(self, energy): return energy / self.beta +from chiron.neighbors import PairsBase + + def calculate_reduced_potential_at_states( sampler_state: SamplerState, thermodynamic_states: List[ThermodynamicState], - nbr_list=None, + nbr_list: Optional[PairsBase] = None, ): """ Calculate the reduced potential for a list of thermodynamic states. @@ -343,7 +346,7 @@ def calculate_reduced_potential_at_states( The sampler state for which to compute the reduced potential. thermodynamic_states : list of ThermodynamicState The thermodynamic states for which to compute the reduced potential. - nbr_list : NeighborList or PairListNsqrd, optional + nbr_list : NeighborList or PairListNsqrd, or None, optional Returns ------- list of float @@ -355,6 +358,8 @@ def calculate_reduced_potential_at_states( reduced_potentials = np.zeros(len(thermodynamic_states)) for state_idx, state in enumerate(thermodynamic_states): - reduced_potentials[state_idx] = state.get_reduced_potential(sampler_state) + reduced_potentials[state_idx] = state.get_reduced_potential( + sampler_state, nbr_list + ) log.debug(f"reduced potentials per sampler sate: {reduced_potentials}") return reduced_potentials diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index c850efb..f23854d 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -205,9 +205,9 @@ def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSa ) -@pytest.mark.skip( - reason="Multistate code still needs to be modified in the multistage branch" -) +# @pytest.mark.skip( +# reason="Multistate code still needs to be modified in the multistage branch" +# ) def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): """ Test function for running the multistate sampler. @@ -227,8 +227,8 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): print(f"Analytical free energy difference: {ho_sampler.delta_f_ij_analytical[0]}") - n_iteratinos = 25 - ho_sampler.run(n_iteratinos) + n_iterations = 25 + ho_sampler.run(n_iterations) # check that we have the correct number of iterations, replicas and states assert ho_sampler.iteration == n_iterations @@ -238,7 +238,11 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): u_kn = ho_sampler._reporter.get_property("u_kn") - assert u_kn.shape == (n_iteratinos, 4, 4) + # the u_kn array is transposed to be _states, n_replicas, n_iterations + # SHOULD THIS BE TRANSPOSED IN THE REPORTER? I feel safer to have it + # be transposed when used (if we want it in such a form). + # note n_iterations+1 because it logs time = 0 as well + assert u_kn.shape == (4, 4, n_iterations + 1) # check that the free energies are correct print(ho_sampler.analytical_f_i) # [ 0. , -0.28593054, -0.54696467, -0.78709279]