diff --git a/chiron/analysis.py b/chiron/analysis.py index 6550901..9608a81 100644 --- a/chiron/analysis.py +++ b/chiron/analysis.py @@ -4,6 +4,7 @@ class MBAREstimator: def __init__(self, N_u: int) -> None: self.mbar_f_k = np.zeros(N_u) + self.mbar = None def initialize(self, u_kn: np.ndarray, N_k: np.ndarray): """ @@ -13,6 +14,18 @@ def initialize(self, u_kn: np.ndarray, N_k: np.ndarray): from loguru import logger as log log.debug(f"{N_k=}") - mbar = MBAR(u_kn=u_kn, N_k=N_k) - log.debug(mbar.f_k) - self.mbar_f_k = mbar.f_k + self.mbar = MBAR(u_kn=u_kn, N_k=N_k) + + @property + def f_k(self): + from loguru import logger as log + + log.debug(self.mbar.f_k) + return self.mbar.f_k + + def get_free_energy_difference(self): + from loguru import logger as log + + log.debug(self.mbar.f_k[-1]) + self.f_k = self.mbar.f_k + return self.mbar_f_k[-1] diff --git a/chiron/multistate.py b/chiron/multistate.py index 713f1d0..6a0ba75 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -606,8 +606,25 @@ def _update_analysis(self): if self._offline_estimator: log.debug("Performing offline free energy estimate...") N_k = [self._iteration] * self.n_states - log.debug(f"{N_k=}") self._offline_estimator.initialize( u_kn=self._energy_thermodynamic_states_for_each_iteration_in_run, N_k=N_k, ) + elif self._online_estimator: + log.debug("Performing online free energy estimate...") + self._online_estimator.update( + u_kn=self._energy_thermodynamic_states_for_each_iteration_in_run[ + :, :, self._iteration + ] + ) + else: + raise RuntimeError("No free energy estimator provided.") + + @property + def f_k(self): + if self._offline_estimator: + return self._offline_estimator.f_k + elif self._online_estimator: + return self._online_estimator.f_k + else: + raise RuntimeError("No free energy estimator found.") diff --git a/chiron/tests/test_multistate.py b/chiron/tests/test_multistate.py index ce74e1d..c316046 100644 --- a/chiron/tests/test_multistate.py +++ b/chiron/tests/test_multistate.py @@ -210,8 +210,6 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler): # check that the free energies are correct print(ho_sampler.analytical_f_i) print(ho_sampler.delta_f_ij_analytical) - print(ho_sampler._last_mbar_f_k_offline) + print(ho_sampler.f_k) - assert np.allclose( - ho_sampler.delta_f_ij_analytical[0], ho_sampler._last_mbar_f_k_offline, atol=0.1 - ) + assert np.allclose(ho_sampler.delta_f_ij_analytical[0], ho_sampler.f_k, atol=0.1)