Skip to content

Commit

Permalink
Fixed multistate test of harmonic oscillator array.
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisiacovella committed Mar 1, 2024
1 parent d38ed6d commit e5bf78b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
11 changes: 8 additions & 3 deletions chiron/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,13 @@ def kT_to_kJ_per_mol(self, energy):
return energy / self.beta


from chiron.neighbors import PairsBase


def calculate_reduced_potential_at_states(
sampler_state: SamplerState,
thermodynamic_states: List[ThermodynamicState],
nbr_list=None,
nbr_list: Optional[PairsBase] = None,
):
"""
Calculate the reduced potential for a list of thermodynamic states.
Expand All @@ -343,7 +346,7 @@ def calculate_reduced_potential_at_states(
The sampler state for which to compute the reduced potential.
thermodynamic_states : list of ThermodynamicState
The thermodynamic states for which to compute the reduced potential.
nbr_list : NeighborList or PairListNsqrd, optional
nbr_list : NeighborList or PairListNsqrd, or None, optional
Returns
-------
list of float
Expand All @@ -355,6 +358,8 @@ def calculate_reduced_potential_at_states(

reduced_potentials = np.zeros(len(thermodynamic_states))
for state_idx, state in enumerate(thermodynamic_states):
reduced_potentials[state_idx] = state.get_reduced_potential(sampler_state)
reduced_potentials[state_idx] = state.get_reduced_potential(
sampler_state, nbr_list
)
log.debug(f"reduced potentials per sampler sate: {reduced_potentials}")
return reduced_potentials
16 changes: 10 additions & 6 deletions chiron/tests/test_multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,9 @@ def test_multistate_minimize(ho_multistate_sampler_multiple_minima: MultiStateSa
)


@pytest.mark.skip(
reason="Multistate code still needs to be modified in the multistage branch"
)
# @pytest.mark.skip(
# reason="Multistate code still needs to be modified in the multistage branch"
# )
def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler):
"""
Test function for running the multistate sampler.
Expand All @@ -227,8 +227,8 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler):

print(f"Analytical free energy difference: {ho_sampler.delta_f_ij_analytical[0]}")

n_iteratinos = 25
ho_sampler.run(n_iteratinos)
n_iterations = 25
ho_sampler.run(n_iterations)

# check that we have the correct number of iterations, replicas and states
assert ho_sampler.iteration == n_iterations
Expand All @@ -238,7 +238,11 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler):

u_kn = ho_sampler._reporter.get_property("u_kn")

assert u_kn.shape == (n_iteratinos, 4, 4)
# the u_kn array is transposed to be _states, n_replicas, n_iterations
# SHOULD THIS BE TRANSPOSED IN THE REPORTER? I feel safer to have it
# be transposed when used (if we want it in such a form).
# note n_iterations+1 because it logs time = 0 as well
assert u_kn.shape == (4, 4, n_iterations + 1)
# check that the free energies are correct
print(ho_sampler.analytical_f_i)
# [ 0. , -0.28593054, -0.54696467, -0.78709279]
Expand Down

0 comments on commit e5bf78b

Please sign in to comment.