Skip to content

Commit

Permalink
Fix initialization of MBAREstimator and update offline free energy es…
Browse files Browse the repository at this point in the history
…timation
  • Loading branch information
wiederm committed Jan 9, 2024
1 parent 6ce2d70 commit 1006be5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
2 changes: 1 addition & 1 deletion chiron/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class MBAREstimator:
def __init__(self, N_u: int) -> None:
self.mbar_f_k = np.zeros(len(N_u))
self.mbar_f_k = np.zeros(N_u)

def initialize(self, u_kn: np.ndarray, N_k: np.ndarray):
"""
Expand Down
24 changes: 15 additions & 9 deletions chiron/multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,6 @@ def __init__(
self._last_mbar_f_k = None
self._last_err_free_energy = None

self._online_estimator = None

from chiron.analysis import MBAREstimator
self._offline_estimator = MBAREstimator()

@property
def n_states(self):
"""The integer number of thermodynamic states (read-only)."""
Expand Down Expand Up @@ -190,10 +185,17 @@ def create(
"""
# TODO: initialize reporter here
# TODO: consider unsampled thermodynamic states for reweighting schemes
self.free_energy_estimator = "mbar"
self._online_estimator = None

from chiron.analysis import MBAREstimator

n_thermodynamic_states = len(thermodynamic_states)
n_sampler_states = len(sampler_states)

self._offline_estimator = MBAREstimator(N_u=n_thermodynamic_states)

# Ensure the number of thermodynamic states matches the number of sampler states
if len(thermodynamic_states) != len(sampler_states):
if n_thermodynamic_states != n_sampler_states:
raise RuntimeError(
"Number of thermodynamic states and sampler states must be equal."
)
Expand Down Expand Up @@ -603,5 +605,9 @@ def _update_analysis(self):
# Perform offline free energy estimate if requested
if self._offline_estimator:
log.debug("Performing offline free energy estimate...")
self._offline_estimator.initialize(self._energy_thermodynamic_states_for_each_iteration_in_run)

N_k = [self._iteration] * self.n_states
log.debug(f"{N_k=}")
self._offline_estimator.initialize(
u_kn=self._energy_thermodynamic_states_for_each_iteration_in_run,
N_k=N_k,
)

0 comments on commit 1006be5

Please sign in to comment.