Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the error handling in the backward and forward convergence #358

Merged
merged 23 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Changes
`None` (start from all zeros) as this change provides a sizable speedup (PR #357)

Enhancements
- `forward_backward_convergence` uses the bootstrap error when the statistical error
is too large. (PR #358)
- `BAR` result is used as initial guess for `MBAR` estimator. (PR #357)
- `forward_backward_convergence` uses the result from the previous step as the initial guess for the next step. (PR #357)

Expand Down
107 changes: 72 additions & 35 deletions src/alchemlyb/convergence/convergence.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Functions for assessing convergence of free energy estimates and raw data."""

from typing import Any, List, Tuple
from warnings import warn

import numpy as np
import pandas as pd
from loguru import logger
from sklearn.base import BaseEstimator

from .. import concat
from ..estimators import BAR, TI, MBAR, FEP_ESTIMATORS, TI_ESTIMATORS
Expand All @@ -13,7 +15,9 @@
estimators_dispatch = {"BAR": BAR, "TI": TI, "MBAR": MBAR}


def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
def forward_backward_convergence(
df_list, estimator="MBAR", num=10, error_tol: float = 3, **kwargs
):
"""Forward and backward convergence of the free energy estimate.

Generate the free energy estimate as a function of time in both directions,
Expand All @@ -35,6 +39,12 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
Lower case input is also accepted until release 2.0.0.
num : int
The number of time points.
error_tol : float
        The maximum tolerated analytic error. If the analytic error exceeds
        the error tolerance, the bootstrap error will be used instead.
xiki-tempula marked this conversation as resolved.
Show resolved Hide resolved

.. versionadded:: 2.3.0

kwargs : dict
Keyword arguments to be passed to the estimator.

Expand Down Expand Up @@ -93,23 +103,11 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
sample = []
for data in df_list:
sample.append(data[: len(data) // num * i])
sample = concat(sample)
result = my_estimator.fit(sample)
if estimator == "MBAR":
my_estimator.initial_f_k = result.delta_f_.iloc[0, :]
orbeckst marked this conversation as resolved.
Show resolved Hide resolved
forward_list.append(result.delta_f_.iloc[0, -1])
if estimator.lower() == "bar":
error = np.sqrt(
sum(
[
result.d_delta_f_.iloc[i, i + 1] ** 2
for i in range(len(result.d_delta_f_) - 1)
]
)
)
forward_error_list.append(error)
else:
forward_error_list.append(result.d_delta_f_.iloc[0, -1])
mean, error = _forward_backward_convergence_estimate(
sample, estimator, my_estimator, error_tol, **kwargs
)
forward_list.append(mean)
forward_error_list.append(error)
logger.info(
"{:.2f} +/- {:.2f} kT".format(forward_list[-1], forward_error_list[-1])
)
Expand All @@ -122,23 +120,11 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
sample = []
for data in df_list:
sample.append(data[-len(data) // num * i :])
sample = concat(sample)
result = my_estimator.fit(sample)
if estimator == "MBAR":
my_estimator.initial_f_k = result.delta_f_.iloc[0, :]
orbeckst marked this conversation as resolved.
Show resolved Hide resolved
backward_list.append(result.delta_f_.iloc[0, -1])
if estimator.lower() == "bar":
error = np.sqrt(
sum(
[
result.d_delta_f_.iloc[i, i + 1] ** 2
for i in range(len(result.d_delta_f_) - 1)
]
)
)
backward_error_list.append(error)
else:
backward_error_list.append(result.d_delta_f_.iloc[0, -1])
mean, error = _forward_backward_convergence_estimate(
sample, estimator, my_estimator, error_tol, **kwargs
)
backward_list.append(mean)
backward_error_list.append(error)
logger.info(
"{:.2f} +/- {:.2f} kT".format(backward_list[-1], backward_error_list[-1])
)
Expand All @@ -156,6 +142,57 @@ def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
return convergence


def _forward_backward_convergence_estimate(
    sample_list: List[pd.DataFrame],
    estimator: str,
    my_estimator: BaseEstimator,
    error_tol: float,
    **kwargs: Any,
) -> Tuple[float, float]:
    """Fit the estimator on the concatenated samples and return mean and error.

    Parameters
    ----------
    sample_list : List[pd.DataFrame]
        List of samples (u_nk or dHdl) to be concatenated and fitted.
    estimator : str
        Name of the estimator ("MBAR", "BAR" or "TI"); expected to be
        upper-cased by the caller, but compared case-insensitively here.
    my_estimator : BaseEstimator
        The estimator instance to fit.
    error_tol : float
        Maximum tolerated analytic error. If the analytic MBAR error exceeds
        this value, a bootstrap error estimate is computed instead.
    kwargs : dict
        Keyword arguments passed to a freshly constructed bootstrap estimator
        when the error tolerance is exceeded.

    Returns
    -------
    mean : float
        The free energy difference delta_f between state 0 and state 1.
    error : float
        The uncertainty d_delta_f between state 0 and state 1.
    """
    sample = concat(sample_list)
    result = my_estimator.fit(sample)
    # NOTE: compare case-insensitively for consistency with the other
    # estimator-name checks below (callers upper-case the name anyway).
    if estimator.lower() == "mbar":
        # Seed the next fit with the converged free energies — a much better
        # initial guess than all zeros, giving a sizable speedup.
        my_estimator.initial_f_k = result.delta_f_.iloc[0, :]
    mean = result.delta_f_.iloc[0, -1]
    if estimator.lower() == "bar":
        # BAR only provides adjacent-state uncertainties; propagate them in
        # quadrature to obtain the end-to-end error.
        error = np.sqrt(
            sum(
                result.d_delta_f_.iloc[i, i + 1] ** 2
                for i in range(len(result.d_delta_f_) - 1)
            )
        )
    else:
        error = result.d_delta_f_.iloc[0, -1]
    if estimator.lower() == "mbar" and error > error_tol:
        # The analytic MBAR error can blow up for poorly overlapping data;
        # fall back to a (more expensive) bootstrap error estimate.
        logger.warning(
            f"Statistical Error ({error}) bigger than error tolerance ({error_tol}), use bootstrap error instead."
        )
        bootstraps_estimator = estimators_dispatch[estimator](
            n_bootstraps=50, initial_f_k=result.delta_f_.iloc[0, :], **kwargs
        )
        bootstraps_estimator.fit(sample)
        error = bootstraps_estimator.d_delta_f_.iloc[0, -1]

    return mean, error


def _cummean(vals, out_length):
"""The cumulative mean of an array.

Expand Down
9 changes: 9 additions & 0 deletions src/alchemlyb/tests/test_convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ def test_convergence_wrong_cases(gmx_benzene_Coulomb_u_nk):
forward_backward_convergence(gmx_benzene_Coulomb_u_nk, "mbar")


def test_convergence_bootstrap(gmx_benzene_Coulomb_u_nk, caplog):
    """Check that a tight ``error_tol`` triggers the bootstrap error path."""
    # Default tolerance: the analytic MBAR error is reported as-is.
    reference = forward_backward_convergence(gmx_benzene_Coulomb_u_nk, "mbar", num=2)
    # A tiny tolerance forces the bootstrap fallback for every time point.
    bootstrapped = forward_backward_convergence(
        gmx_benzene_Coulomb_u_nk, "mbar", error_tol=0.01, num=2
    )
    assert "use bootstrap error instead." in caplog.text
    assert (bootstrapped["Forward_Error"] != reference["Forward_Error"]).all()


def test_convergence_method(gmx_benzene_Coulomb_u_nk):
convergence = forward_backward_convergence(
gmx_benzene_Coulomb_u_nk, "MBAR", num=2, method="adaptive"
Expand Down
Loading