diff --git a/CHANGES b/CHANGES
index 6734d315..02940a79 100644
--- a/CHANGES
+++ b/CHANGES
@@ -24,6 +24,8 @@ Changes
     `None` (start from all zeros) as this change provides a sizable speedup (PR #357)
 
 Enhancements
+  - `forward_backward_convergence` uses the bootstrap error when the statistical error
+    is too large. (PR #358)
   - `BAR` result is used as initial guess for `MBAR` estimator. (PR #357)
   - `forward_backward_convergence` uses the result from the previous step as the initial
     guess for the next step. (PR #357)
diff --git a/src/alchemlyb/convergence/convergence.py b/src/alchemlyb/convergence/convergence.py
index 65f15b86..d109e1ee 100644
--- a/src/alchemlyb/convergence/convergence.py
+++ b/src/alchemlyb/convergence/convergence.py
@@ -1,10 +1,12 @@
 """Functions for assessing convergence of free energy estimates and raw data."""
 
+from typing import Any, List, Tuple
 from warnings import warn
 
 import numpy as np
 import pandas as pd
 from loguru import logger
+from sklearn.base import BaseEstimator
 
 from .. import concat
 from ..estimators import BAR, TI, MBAR, FEP_ESTIMATORS, TI_ESTIMATORS
@@ -13,7 +15,9 @@
 estimators_dispatch = {"BAR": BAR, "TI": TI, "MBAR": MBAR}
 
 
-def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
+def forward_backward_convergence(
+    df_list, estimator="MBAR", num=10, error_tol: float = 3, **kwargs
+):
     """Forward and backward convergence of the free energy estimate.
 
     Generate the free energy estimate as a function of time in both directions,
@@ -35,6 +39,12 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
         Lower case input is also accepted until release 2.0.0.
     num : int
         The number of time points.
+    error_tol : float
+        The maximum error tolerated for the analytic error. If the analytic error is
+        bigger than the error tolerance, the bootstrap error will be used.
+
+        .. versionadded:: 2.3.0
+
     kwargs : dict
         Keyword arguments to be passed to the estimator.
 
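For orientation only (not part of the patch): a minimal sketch of how the new `error_tol` keyword is used from the caller's side. The GROMACS file paths and the temperature are assumptions for illustration; `extract_u_nk` and `forward_backward_convergence` are the existing alchemlyb APIs.

```python
from alchemlyb.parsing.gmx import extract_u_nk
from alchemlyb.convergence import forward_backward_convergence

# Hypothetical dhdl.xvg files, one per lambda window.
files = ["lambda_00/dhdl.xvg", "lambda_01/dhdl.xvg", "lambda_02/dhdl.xvg"]
u_nk_list = [extract_u_nk(f, T=300) for f in files]

# With the default error_tol=3, the analytic MBAR error is reported as long as
# it stays below 3 kT; above that, a bootstrap error is reported instead.
convergence = forward_backward_convergence(
    u_nk_list, estimator="MBAR", num=10, error_tol=3
)
print(convergence[["Forward", "Forward_Error", "Backward", "Backward_Error"]])
```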
@@ -93,23 +103,11 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
         sample = []
         for data in df_list:
             sample.append(data[: len(data) // num * i])
-        sample = concat(sample)
-        result = my_estimator.fit(sample)
-        if estimator == "MBAR":
-            my_estimator.initial_f_k = result.delta_f_.iloc[0, :]
-        forward_list.append(result.delta_f_.iloc[0, -1])
-        if estimator.lower() == "bar":
-            error = np.sqrt(
-                sum(
-                    [
-                        result.d_delta_f_.iloc[i, i + 1] ** 2
-                        for i in range(len(result.d_delta_f_) - 1)
-                    ]
-                )
-            )
-            forward_error_list.append(error)
-        else:
-            forward_error_list.append(result.d_delta_f_.iloc[0, -1])
+        mean, error = _forward_backward_convergence_estimate(
+            sample, estimator, my_estimator, error_tol, **kwargs
+        )
+        forward_list.append(mean)
+        forward_error_list.append(error)
         logger.info(
             "{:.2f} +/- {:.2f} kT".format(forward_list[-1], forward_error_list[-1])
         )
@@ -122,23 +120,11 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
         sample = []
         for data in df_list:
             sample.append(data[-len(data) // num * i :])
-        sample = concat(sample)
-        result = my_estimator.fit(sample)
-        if estimator == "MBAR":
-            my_estimator.initial_f_k = result.delta_f_.iloc[0, :]
-        backward_list.append(result.delta_f_.iloc[0, -1])
-        if estimator.lower() == "bar":
-            error = np.sqrt(
-                sum(
-                    [
-                        result.d_delta_f_.iloc[i, i + 1] ** 2
-                        for i in range(len(result.d_delta_f_) - 1)
-                    ]
-                )
-            )
-            backward_error_list.append(error)
-        else:
-            backward_error_list.append(result.d_delta_f_.iloc[0, -1])
+        mean, error = _forward_backward_convergence_estimate(
+            sample, estimator, my_estimator, error_tol, **kwargs
+        )
+        backward_list.append(mean)
+        backward_error_list.append(error)
         logger.info(
             "{:.2f} +/- {:.2f} kT".format(backward_list[-1], backward_error_list[-1])
         )
@@ -156,6 +142,57 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
     return convergence
 
 
+def _forward_backward_convergence_estimate(
+    sample_list: List[pd.DataFrame],
+    estimator: str,
+    my_estimator: BaseEstimator,
+    error_tol: float,
+    **kwargs: Any,
+) -> Tuple[float, float]:
+    """Use the estimator to run the estimation and return the mean and error.
+
+    Parameters
+    ----------
+    sample_list: A list of samples as pandas DataFrames.
+    estimator: The name of the estimator.
+    my_estimator: The estimator object.
+    error_tol: The error tolerance.
+    kwargs
+
+    Returns
+    -------
+    mean: The delta_f between 0 and 1.
+    error: The d_delta_f between 0 and 1.
+    """
+    sample = concat(sample_list)
+    result = my_estimator.fit(sample)
+    if estimator == "MBAR":
+        my_estimator.initial_f_k = result.delta_f_.iloc[0, :]
+    mean = result.delta_f_.iloc[0, -1]
+    if estimator.lower() == "bar":
+        error = np.sqrt(
+            sum(
+                [
+                    result.d_delta_f_.iloc[i, i + 1] ** 2
+                    for i in range(len(result.d_delta_f_) - 1)
+                ]
+            )
+        )
+    else:
+        error = result.d_delta_f_.iloc[0, -1]
+    if estimator.lower() == "mbar" and error > error_tol:
+        logger.warning(
+            f"Statistical Error ({error}) bigger than error tolerance ({error_tol}), use bootstrap error instead."
+        )
+        bootstraps_estimator = estimators_dispatch[estimator](
+            n_bootstraps=50, initial_f_k=result.delta_f_.iloc[0, :], **kwargs
+        )
+        bootstraps_estimator.fit(sample)
+        error = bootstraps_estimator.d_delta_f_.iloc[0, -1]
+
+    return mean, error
+
+
 def _cummean(vals, out_length):
     """The cumulative mean of an array.
 
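As a side note (not part of the patch), the fallback above amounts to refitting with a bootstrapped `MBAR` estimator seeded by the analytic solution, exactly as the new helper does with `n_bootstraps=50` and `initial_f_k`. A minimal sketch of that idea in isolation, where `u_nk_list` (a list of u_nk DataFrames) is an assumption for illustration:

```python
import alchemlyb
from alchemlyb.estimators import MBAR

# u_nk_list is assumed to exist; concatenate it into one u_nk DataFrame.
u_nk = alchemlyb.concat(u_nk_list)

# Analytic pass first, mirroring my_estimator.fit(sample) in the patch.
analytic = MBAR().fit(u_nk)
analytic_error = analytic.d_delta_f_.iloc[0, -1]

# If the analytic error exceeds the tolerance, rebuild the estimator with
# bootstrap error estimation, reusing the converged free energies as the
# initial guess, and report the bootstrap error instead.
error_tol = 3  # kT
if analytic_error > error_tol:
    boot = MBAR(n_bootstraps=50, initial_f_k=analytic.delta_f_.iloc[0, :]).fit(u_nk)
    error = boot.d_delta_f_.iloc[0, -1]
else:
    error = analytic_error
```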
diff --git a/src/alchemlyb/tests/test_convergence.py b/src/alchemlyb/tests/test_convergence.py
index 32fa1eb9..2bde94b3 100644
--- a/src/alchemlyb/tests/test_convergence.py
+++ b/src/alchemlyb/tests/test_convergence.py
@@ -36,6 +36,15 @@ def test_convergence_wrong_cases(gmx_benzene_Coulomb_u_nk):
         forward_backward_convergence(gmx_benzene_Coulomb_u_nk, "mbar")
 
 
+def test_convergence_bootstrap(gmx_benzene_Coulomb_u_nk, caplog):
+    normal_c = forward_backward_convergence(gmx_benzene_Coulomb_u_nk, "mbar", num=2)
+    bootstrap_c = forward_backward_convergence(
+        gmx_benzene_Coulomb_u_nk, "mbar", error_tol=0.01, num=2
+    )
+    assert "use bootstrap error instead." in caplog.text
+    assert (bootstrap_c["Forward_Error"] != normal_c["Forward_Error"]).all()
+
+
 def test_convergence_method(gmx_benzene_Coulomb_u_nk):
     convergence = forward_backward_convergence(
         gmx_benzene_Coulomb_u_nk, "MBAR", num=2, method="adaptive"
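For completeness (again not part of the patch), the behaviour the new test checks can be reproduced interactively; `u_nk_list` is an assumed list of u_nk DataFrames, such as the benzene Coulomb set from alchemtest used by the `gmx_benzene_Coulomb_u_nk` fixture:

```python
from alchemlyb.convergence import forward_backward_convergence

# An intentionally tiny tolerance makes every analytic MBAR error exceed it,
# so each forward/backward slice falls back to the bootstrap error.
normal = forward_backward_convergence(u_nk_list, "MBAR", num=2)
boot = forward_backward_convergence(u_nk_list, "MBAR", num=2, error_tol=0.01)

# The bootstrap errors generally differ from the analytic ones.
print(normal["Forward_Error"].values)
print(boot["Forward_Error"].values)
```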