Skip to content

Commit

Permalink
Add MBAR class and update free energy estimators
Browse files Browse the repository at this point in the history
  • Loading branch information
wiederm committed Jan 10, 2024
1 parent 1006be5 commit 0d7c944
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
19 changes: 16 additions & 3 deletions chiron/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
class MBAREstimator:
def __init__(self, N_u: int) -> None:
self.mbar_f_k = np.zeros(N_u)
self.mbar = None

def initialize(self, u_kn: np.ndarray, N_k: np.ndarray):
"""
Expand All @@ -13,6 +14,18 @@ def initialize(self, u_kn: np.ndarray, N_k: np.ndarray):
from loguru import logger as log

log.debug(f"{N_k=}")
mbar = MBAR(u_kn=u_kn, N_k=N_k)
log.debug(mbar.f_k)
self.mbar_f_k = mbar.f_k
self.mbar = MBAR(u_kn=u_kn, N_k=N_k)

@property
def f_k(self):
from loguru import logger as log

log.debug(self.mbar.f_k)
return self.mbar.f_k

def get_free_energy_difference(self):
from loguru import logger as log

log.debug(self.mbar.f_k[-1])
self.f_k = self.mbar.f_k
return self.mbar_f_k[-1]
19 changes: 18 additions & 1 deletion chiron/multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,8 +606,25 @@ def _update_analysis(self):
if self._offline_estimator:
log.debug("Performing offline free energy estimate...")
N_k = [self._iteration] * self.n_states
log.debug(f"{N_k=}")
self._offline_estimator.initialize(
u_kn=self._energy_thermodynamic_states_for_each_iteration_in_run,
N_k=N_k,
)
elif self._online_estimator:
log.debug("Performing online free energy estimate...")
self._online_estimator.update(
u_kn=self._energy_thermodynamic_states_for_each_iteration_in_run[
:, :, self._iteration
]
)
else:
raise RuntimeError("No free energy estimator provided.")

@property
def f_k(self):
if self._offline_estimator:
return self._offline_estimator.f_k
elif self._online_estimator:
return self._online_estimator.f_k
else:
raise RuntimeError("No free energy estimator found.")
6 changes: 2 additions & 4 deletions chiron/tests/test_multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,6 @@ def test_multistate_run(ho_multistate_sampler_multiple_ks: MultiStateSampler):
# check that the free energies are correct
print(ho_sampler.analytical_f_i)
print(ho_sampler.delta_f_ij_analytical)
print(ho_sampler._last_mbar_f_k_offline)
print(ho_sampler.f_k)

assert np.allclose(
ho_sampler.delta_f_ij_analytical[0], ho_sampler._last_mbar_f_k_offline, atol=0.1
)
assert np.allclose(ho_sampler.delta_f_ij_analytical[0], ho_sampler.f_k, atol=0.1)

0 comments on commit 0d7c944

Please sign in to comment.