diff --git a/chiron/analysis.py b/chiron/analysis.py index 06c54dc..6550901 100644 --- a/chiron/analysis.py +++ b/chiron/analysis.py @@ -3,7 +3,7 @@ class MBAREstimator: def __init__(self, N_u: int) -> None: - self.mbar_f_k = np.zeros(len(N_u)) + self.mbar_f_k = np.zeros(N_u) def initialize(self, u_kn: np.ndarray, N_k: np.ndarray): """ diff --git a/chiron/multistate.py b/chiron/multistate.py index 40f5c0e..713f1d0 100644 --- a/chiron/multistate.py +++ b/chiron/multistate.py @@ -66,11 +66,6 @@ def __init__( self._last_mbar_f_k = None self._last_err_free_energy = None - self._online_estimator = None - - from chiron.analysis import MBAREstimator - self._offline_estimator = MBAREstimator() - @property def n_states(self): """The integer number of thermodynamic states (read-only).""" @@ -190,10 +185,17 @@ def create( """ # TODO: initialize reporter here # TODO: consider unsampled thermodynamic states for reweighting schemes - self.free_energy_estimator = "mbar" + self._online_estimator = None + + from chiron.analysis import MBAREstimator + + n_thermodynamic_states = len(thermodynamic_states) + n_sampler_states = len(sampler_states) + + self._offline_estimator = MBAREstimator(N_u=n_thermodynamic_states) # Ensure the number of thermodynamic states matches the number of sampler states - if len(thermodynamic_states) != len(sampler_states): + if n_thermodynamic_states != n_sampler_states: raise RuntimeError( "Number of thermodynamic states and sampler states must be equal." ) @@ -603,5 +605,9 @@ def _update_analysis(self): # Perform offline free energy estimate if requested if self._offline_estimator: log.debug("Performing offline free energy estimate...") - self._offline_estimator.initialize(self._energy_thermodynamic_states_for_each_iteration_in_run) - + N_k = [self._iteration] * self.n_states + log.debug(f"{N_k=}") + self._offline_estimator.initialize( + u_kn=self._energy_thermodynamic_states_for_each_iteration_in_run, + N_k=N_k, + )