diff --git a/src/alchemlyb/convergence/convergence.py b/src/alchemlyb/convergence/convergence.py
index 372bd176..654fea2d 100644
--- a/src/alchemlyb/convergence/convergence.py
+++ b/src/alchemlyb/convergence/convergence.py
@@ -3,17 +3,16 @@
import logging
from warnings import warn
-import pandas as pd
import numpy as np
+import pandas as pd
-from ..estimators import BAR, TI, FEP_ESTIMATORS, TI_ESTIMATORS
-from ..estimators import AutoMBAR as MBAR
from .. import concat
+from ..estimators import FEP_ESTIMATORS, TI_ESTIMATORS
from ..postprocessors.units import to_kT
-def forward_backward_convergence(df_list, estimator='MBAR', num=10, **kwargs):
- '''Forward and backward convergence of the free energy estimate.
+def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs):
+ """Forward and backward convergence of the free energy estimate.
Generate the free energy estimate as a function of time in both directions,
with the specified number of equally spaced points in the time
@@ -69,16 +68,17 @@ def forward_backward_convergence(df_list, estimator='MBAR', num=10, **kwargs):
The default for using ``estimator='MBAR'`` was changed from
:class:`~alchemlyb.estimators.MBAR` to :class:`~alchemlyb.estimators.AutoMBAR`.
- '''
- logger = logging.getLogger('alchemlyb.convergence.'
- 'forward_backward_convergence')
- logger.info('Start convergence analysis.')
- logger.info('Check data availability.')
+ """
+    logger = logging.getLogger("alchemlyb.convergence.forward_backward_convergence")
+ logger.info("Start convergence analysis.")
+ logger.info("Check data availability.")
if estimator.upper() != estimator:
- warn("Using lower-case strings for the 'estimator' kwarg in "
- "convergence.forward_backward_convergence() is deprecated in "
- "1.0.0 and only upper case will be accepted in 2.0.0",
- DeprecationWarning)
+ warn(
+ "Using lower-case strings for the 'estimator' kwarg in "
+ "convergence.forward_backward_convergence() is deprecated in "
+ "1.0.0 and only upper case will be accepted in 2.0.0",
+ DeprecationWarning,
+ )
estimator = estimator.upper()
if estimator not in (FEP_ESTIMATORS + TI_ESTIMATORS):
@@ -88,61 +88,77 @@ def forward_backward_convergence(df_list, estimator='MBAR', num=10, **kwargs):
else:
# select estimator class by name
estimator_fit = globals()[estimator](**kwargs).fit
- logger.info(f'Use {estimator} estimator for convergence analysis.')
+ logger.info(f"Use {estimator} estimator for convergence analysis.")
- logger.info('Begin forward analysis')
+ logger.info("Begin forward analysis")
forward_list = []
forward_error_list = []
for i in range(1, num + 1):
- logger.info('Forward analysis: {:.2f}%'.format(100 * i / num))
+ logger.info("Forward analysis: {:.2f}%".format(100 * i / num))
sample = []
for data in df_list:
- sample.append(data[:len(data) // num * i])
+ sample.append(data[: len(data) // num * i])
sample = concat(sample)
result = estimator_fit(sample)
forward_list.append(result.delta_f_.iloc[0, -1])
- if estimator.lower() == 'bar':
- error = np.sqrt(sum(
- [result.d_delta_f_.iloc[i, i + 1] ** 2
- for i in range(len(result.d_delta_f_) - 1)]))
+ if estimator.lower() == "bar":
+ error = np.sqrt(
+ sum(
+ [
+ result.d_delta_f_.iloc[i, i + 1] ** 2
+ for i in range(len(result.d_delta_f_) - 1)
+ ]
+ )
+ )
forward_error_list.append(error)
else:
forward_error_list.append(result.d_delta_f_.iloc[0, -1])
- logger.info('{:.2f} +/- {:.2f} kT'.format(forward_list[-1],
- forward_error_list[-1]))
+ logger.info(
+ "{:.2f} +/- {:.2f} kT".format(forward_list[-1], forward_error_list[-1])
+ )
- logger.info('Begin backward analysis')
+ logger.info("Begin backward analysis")
backward_list = []
backward_error_list = []
for i in range(1, num + 1):
- logger.info('Backward analysis: {:.2f}%'.format(100 * i / num))
+ logger.info("Backward analysis: {:.2f}%".format(100 * i / num))
sample = []
for data in df_list:
- sample.append(data[-len(data) // num * i:])
+ sample.append(data[-len(data) // num * i :])
sample = concat(sample)
result = estimator_fit(sample)
backward_list.append(result.delta_f_.iloc[0, -1])
- if estimator.lower() == 'bar':
- error = np.sqrt(sum(
- [result.d_delta_f_.iloc[i, i + 1] ** 2
- for i in range(len(result.d_delta_f_) - 1)]))
+ if estimator.lower() == "bar":
+ error = np.sqrt(
+ sum(
+ [
+ result.d_delta_f_.iloc[i, i + 1] ** 2
+ for i in range(len(result.d_delta_f_) - 1)
+ ]
+ )
+ )
backward_error_list.append(error)
else:
backward_error_list.append(result.d_delta_f_.iloc[0, -1])
- logger.info('{:.2f} +/- {:.2f} kT'.format(backward_list[-1],
- backward_error_list[-1]))
+ logger.info(
+ "{:.2f} +/- {:.2f} kT".format(backward_list[-1], backward_error_list[-1])
+ )
convergence = pd.DataFrame(
- {'Forward': forward_list,
- 'Forward_Error': forward_error_list,
- 'Backward': backward_list,
- 'Backward_Error': backward_error_list,
- 'data_fraction': [i / num for i in range(1, num + 1)]})
+ {
+ "Forward": forward_list,
+ "Forward_Error": forward_error_list,
+ "Backward": backward_list,
+ "Backward_Error": backward_error_list,
+ "data_fraction": [i / num for i in range(1, num + 1)],
+ }
+ )
convergence.attrs = df_list[0].attrs
return convergence
+
def _cummean(vals, out_length):
- '''The cumulative mean of an array.
+ """The cumulative mean of an array.
This function computes the cumulative mean and shapes the result to the
desired length.
@@ -167,18 +183,19 @@ def _cummean(vals, out_length):
.. versionadded:: 1.0.0
- '''
+ """
in_length = len(vals)
if in_length < out_length:
out_length = in_length
block = in_length // out_length
- reshape = vals[: block*out_length].reshape(block, out_length)
+ reshape = vals[: block * out_length].reshape(block, out_length)
mean = np.mean(reshape, axis=0)
- result = np.cumsum(mean) / np.arange(1, out_length+1)
+ result = np.cumsum(mean) / np.arange(1, out_length + 1)
return result
+
def fwdrev_cumavg_Rc(series, precision=0.01, tol=2):
- '''Generate the convergence criteria :math:`R_c` for a single simulation.
+ """Generate the convergence criteria :math:`R_c` for a single simulation.
The input will be :class:`pandas.Series` generated by
:func:`~alchemlyb.preprocessing.subsampling.decorrelate_u_nk` or
@@ -241,7 +258,7 @@ def fwdrev_cumavg_Rc(series, precision=0.01, tol=2):
.. _`equation 16`:
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8397498/#FD16
- '''
+ """
series = to_kT(series)
array = series.to_numpy()
out_length = int(1 / precision)
@@ -250,9 +267,12 @@ def fwdrev_cumavg_Rc(series, precision=0.01, tol=2):
length = len(g_forward)
convergence = pd.DataFrame(
- {'Forward': g_forward,
- 'Backward': g_backward,
- 'data_fraction': [i / length for i in range(1, length + 1)]})
+ {
+ "Forward": g_forward,
+ "Backward": g_backward,
+ "data_fraction": [i / length for i in range(1, length + 1)],
+ }
+ )
convergence.attrs = series.attrs
# Final value
@@ -270,8 +290,9 @@ def fwdrev_cumavg_Rc(series, precision=0.01, tol=2):
# the same as this branch will be triggered.
return 1.0, convergence
+
def A_c(series_list, precision=0.01, tol=2):
- '''Generate the ensemble convergence criteria :math:`A_c` for a set of simulations.
+ """Generate the ensemble convergence criteria :math:`A_c` for a set of simulations.
The input is a :class:`list` of :class:`pandas.Series` generated by
:func:`~alchemlyb.preprocessing.subsampling.decorrelate_u_nk` or
@@ -317,11 +338,11 @@ def A_c(series_list, precision=0.01, tol=2):
.. _`equation 18`:
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8397498/#FD18
- '''
- logger = logging.getLogger('alchemlyb.convergence.A_c')
+ """
+ logger = logging.getLogger("alchemlyb.convergence.A_c")
n_R_c = len(series_list)
R_c_list = [fwdrev_cumavg_Rc(series, precision, tol)[0] for series in series_list]
- logger.info(f'R_c list: {R_c_list}')
+ logger.info(f"R_c list: {R_c_list}")
# Integrate the R_c_list <= R_c over the range of 0 to 1
array_01 = np.hstack((R_c_list, [0, 1]))
sorted_array = np.sort(np.unique(array_01))
@@ -330,6 +351,6 @@ def A_c(series_list, precision=0.01, tol=2):
if i == 0:
continue
else:
- d_R_c = sorted_array[-i] - sorted_array[-i-1]
+ d_R_c = sorted_array[-i] - sorted_array[-i - 1]
result += d_R_c * sum(R_c_list <= element) / n_R_c
return result
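A note on the convergence loop reformatted above: each forward slice keeps the first i/num of every window's samples, each backward slice keeps the last i/num, and for BAR the overall uncertainty is the quadrature sum of the neighbouring-state errors on the first off-diagonal of d_delta_f_. A minimal NumPy sketch with toy data (data and d_delta_f are invented, not the alchemlyb API):

import numpy as np

# one window's samples and the i-th forward/backward slices
data = np.arange(100)
num = 10
forward_slices = [data[: len(data) // num * i] for i in range(1, num + 1)]
backward_slices = [data[-len(data) // num * i :] for i in range(1, num + 1)]

# BAR only reports uncertainties between neighbouring states, so the total
# error is their quadrature sum along the first off-diagonal.
d_delta_f = np.array([[0.0, 0.1, np.nan],
                      [0.1, 0.0, 0.2],
                      [np.nan, 0.2, 0.0]])
error = np.sqrt(sum(d_delta_f[k, k + 1] ** 2 for k in range(len(d_delta_f) - 1)))
print(len(forward_slices[0]), len(backward_slices[0]), round(float(error), 3))  # 10 10 0.224
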
diff --git a/src/alchemlyb/estimators/__init__.py b/src/alchemlyb/estimators/__init__.py
index ca48015b..4b4e7771 100644
--- a/src/alchemlyb/estimators/__init__.py
+++ b/src/alchemlyb/estimators/__init__.py
@@ -1,5 +1,5 @@
-from .mbar_ import MBAR, AutoMBAR
from .bar_ import BAR
+from .mbar_ import MBAR, AutoMBAR
from .ti_ import TI
FEP_ESTIMATORS = [MBAR.__name__, AutoMBAR.__name__, BAR.__name__]
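These registries are what convergence.forward_backward_convergence uses to validate the estimator keyword before resolving the class by name. A sketch of that lookup, assuming alchemlyb is importable and using a plain dict rather than the globals() lookup used in the convergence module:

from alchemlyb.estimators import BAR, MBAR, TI, FEP_ESTIMATORS, TI_ESTIMATORS

estimator = "BAR"  # user-supplied name
if estimator not in FEP_ESTIMATORS + TI_ESTIMATORS:
    raise ValueError(f"Estimator {estimator} is not supported.")
estimator_cls = {"BAR": BAR, "MBAR": MBAR, "TI": TI}[estimator]
fit = estimator_cls().fit  # ready to call on a concatenated u_nk / dHdl DataFrame
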
diff --git a/src/alchemlyb/estimators/bar_.py b/src/alchemlyb/estimators/bar_.py
index 3a7150b2..7bf39bc7 100644
--- a/src/alchemlyb/estimators/bar_.py
+++ b/src/alchemlyb/estimators/bar_.py
@@ -1,11 +1,11 @@
import numpy as np
import pandas as pd
-
-from sklearn.base import BaseEstimator
from pymbar import BAR as BAR_
+from sklearn.base import BaseEstimator
from .base import _EstimatorMixOut
+
class BAR(BaseEstimator, _EstimatorMixOut):
"""Bennett acceptance ratio (BAR).
@@ -57,7 +57,13 @@ class BAR(BaseEstimator, _EstimatorMixOut):
"""
- def __init__(self, maximum_iterations=10000, relative_tolerance=1.0e-7, method='false-position', verbose=False):
+ def __init__(
+ self,
+ maximum_iterations=10000,
+ relative_tolerance=1.0e-7,
+ method="false-position",
+ verbose=False,
+ ):
self.maximum_iterations = maximum_iterations
self.relative_tolerance = relative_tolerance
@@ -87,7 +93,10 @@ def fit(self, u_nk):
# group u_nk by lambda states
groups = u_nk.groupby(level=u_nk.index.names[1:])
- N_k = [(len(groups.get_group(i)) if i in groups.groups else 0) for i in u_nk.columns]
+ N_k = [
+ (len(groups.get_group(i)) if i in groups.groups else 0)
+ for i in u_nk.columns
+ ]
# Now get free energy differences and their uncertainties for each step
deltas = np.array([])
@@ -96,19 +105,22 @@ def fit(self, u_nk):
# get us from lambda step k
uk = groups.get_group(self._states_[k])
# get w_F
- w_f = uk.iloc[:, k+1] - uk.iloc[:, k]
+ w_f = uk.iloc[:, k + 1] - uk.iloc[:, k]
# get us from lambda step k+1
- uk1 = groups.get_group(self._states_[k+1])
+ uk1 = groups.get_group(self._states_[k + 1])
# get w_R
- w_r = uk1.iloc[:, k] - uk1.iloc[:, k+1]
+ w_r = uk1.iloc[:, k] - uk1.iloc[:, k + 1]
# now determine df and ddf using pymbar.BAR
- df, ddf = BAR_(w_f, w_r,
- method=self.method,
- maximum_iterations=self.maximum_iterations,
- relative_tolerance=self.relative_tolerance,
- verbose=self.verbose)
+ df, ddf = BAR_(
+ w_f,
+ w_r,
+ method=self.method,
+ maximum_iterations=self.maximum_iterations,
+ relative_tolerance=self.relative_tolerance,
+ verbose=self.verbose,
+ )
deltas = np.append(deltas, df)
d_deltas = np.append(d_deltas, ddf**2)
@@ -121,14 +133,14 @@ def fit(self, u_nk):
out = []
dout = []
for i in range(len(deltas) - j):
- out.append(deltas[i:i + j + 1].sum())
+ out.append(deltas[i : i + j + 1].sum())
# See https://github.com/alchemistry/alchemlyb/pull/60#issuecomment-430720742
# Error estimate generated by BAR ARE correlated
# Use the BAR uncertainties between two neighbour states
if j == 0:
- dout.append(d_deltas[i:i + j + 1].sum())
+ dout.append(d_deltas[i : i + j + 1].sum())
# Other uncertainties are unknown at this point
else:
dout.append(np.nan)
@@ -137,14 +149,14 @@ def fit(self, u_nk):
ad_delta += np.diagflat(np.array(dout), k=j + 1)
# yield standard delta_f_ free energies between each state
- self._delta_f_ = pd.DataFrame(adelta - adelta.T,
- columns=self._states_,
- index=self._states_)
+ self._delta_f_ = pd.DataFrame(
+ adelta - adelta.T, columns=self._states_, index=self._states_
+ )
# yield standard deviation d_delta_f_ between each state
- self._d_delta_f_ = pd.DataFrame(np.sqrt(ad_delta + ad_delta.T),
- columns=self._states_,
- index=self._states_)
+ self._d_delta_f_ = pd.DataFrame(
+ np.sqrt(ad_delta + ad_delta.T), columns=self._states_, index=self._states_
+ )
self._delta_f_.attrs = u_nk.attrs
self._d_delta_f_.attrs = u_nk.attrs
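fit() above stacks the pairwise BAR results into full matrices: cumulative sums of the per-step deltas fill the k-th super-diagonals, and the antisymmetrised matrix gives the free energy between every pair of states. A toy reconstruction of that assembly (made-up deltas, no pymbar needed):

import numpy as np

deltas = np.array([1.0, 0.5, 0.25])  # made-up per-step BAR free energy differences
n = len(deltas) + 1
adelta = np.zeros((n, n))
for j in range(len(deltas)):
    # cumulative sums spanning j+1 steps go onto the (j+1)-th super-diagonal
    out = [deltas[i : i + j + 1].sum() for i in range(len(deltas) - j)]
    adelta += np.diagflat(np.array(out), k=j + 1)
delta_f = adelta - adelta.T  # free energy from any state to any other
print(delta_f[0, -1])        # 1.75 == deltas.sum()
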
diff --git a/src/alchemlyb/estimators/base.py b/src/alchemlyb/estimators/base.py
index e6b1b8be..93f3da8a 100644
--- a/src/alchemlyb/estimators/base.py
+++ b/src/alchemlyb/estimators/base.py
@@ -1,9 +1,11 @@
-class _EstimatorMixOut():
- '''This class creates view for the d_delta_f_, delta_f_, states_ for the
- estimator class to consume.'''
+class _EstimatorMixOut:
+ """This class creates view for the d_delta_f_, delta_f_, states_ for the
+ estimator class to consume."""
+
_d_delta_f_ = None
_delta_f_ = None
_states_ = None
+
@property
def d_delta_f_(self):
return self._d_delta_f_
@@ -15,4 +17,3 @@ def delta_f_(self):
@property
def states_(self):
return self._states_
-
\ No newline at end of file
diff --git a/src/alchemlyb/estimators/mbar_.py b/src/alchemlyb/estimators/mbar_.py
index d34434e9..4759d687 100644
--- a/src/alchemlyb/estimators/mbar_.py
+++ b/src/alchemlyb/estimators/mbar_.py
@@ -1,12 +1,12 @@
-import numpy as np
-import pandas as pd
import logging
-from sklearn.base import BaseEstimator
+import pandas as pd
import pymbar
+from sklearn.base import BaseEstimator
from .base import _EstimatorMixOut
+
class MBAR(BaseEstimator, _EstimatorMixOut):
"""Multi-state Bennett acceptance ratio (MBAR).
@@ -62,14 +62,20 @@ class MBAR(BaseEstimator, _EstimatorMixOut):
`delta_f_`, `d_delta_f_`, `states_` are view of the original object.
"""
- def __init__(self, maximum_iterations=10000, relative_tolerance=1.0e-7,
- initial_f_k=None, method='hybr', verbose=False):
+ def __init__(
+ self,
+ maximum_iterations=10000,
+ relative_tolerance=1.0e-7,
+ initial_f_k=None,
+ method="hybr",
+ verbose=False,
+ ):
self.maximum_iterations = maximum_iterations
self.relative_tolerance = relative_tolerance
self.initial_f_k = initial_f_k
self.method = method
self.verbose = verbose
- self.logger = logging.getLogger('alchemlyb.estimators.MBAR')
+ self.logger = logging.getLogger("alchemlyb.estimators.MBAR")
# handle for pymbar.MBAR object
self._mbar = None
@@ -90,22 +96,24 @@ def fit(self, u_nk):
u_nk = u_nk.sort_index(level=u_nk.index.names[1:])
groups = u_nk.groupby(level=u_nk.index.names[1:])
- N_k = [(len(groups.get_group(i)) if i in groups.groups else 0) for i in
- u_nk.columns]
+ N_k = [
+ (len(groups.get_group(i)) if i in groups.groups else 0)
+ for i in u_nk.columns
+ ]
self._states_ = u_nk.columns.values.tolist()
# Prepare the solver_protocol as stated in https://github.com/choderalab/pymbar/issues/419#issuecomment-803714103
- solver_options = {"maximum_iterations": self.maximum_iterations,
- "verbose": self.verbose}
- solver_protocol = {"method": self.method,
- "options": solver_options}
+ solver_options = {
+ "maximum_iterations": self.maximum_iterations,
+ "verbose": self.verbose,
+ }
+ solver_protocol = {"method": self.method, "options": solver_options}
self._mbar, out = self._do_MBAR(u_nk, N_k, solver_protocol)
- free_energy_differences = [pd.DataFrame(i,
- columns=self._states_,
- index=self._states_) for i in
- out]
+ free_energy_differences = [
+ pd.DataFrame(i, columns=self._states_, index=self._states_) for i in out
+ ]
(self._delta_f_, self._d_delta_f_, self.theta_) = free_energy_differences
@@ -118,15 +126,20 @@ def predict(self, u_ln):
pass
def _do_MBAR(self, u_nk, N_k, solver_protocol):
- mbar = pymbar.MBAR(u_nk.T, N_k,
- relative_tolerance=self.relative_tolerance,
- initial_f_k=self.initial_f_k,
- solver_protocol=(solver_protocol,))
- self.logger.info("Solved MBAR equations with method %r and "
- "maximum_iterations=%d, relative_tolerance=%g",
- solver_protocol['method'],
- solver_protocol['options']['maximum_iterations'],
- self.relative_tolerance)
+ mbar = pymbar.MBAR(
+ u_nk.T,
+ N_k,
+ relative_tolerance=self.relative_tolerance,
+ initial_f_k=self.initial_f_k,
+ solver_protocol=(solver_protocol,),
+ )
+ self.logger.info(
+ "Solved MBAR equations with method %r and "
+ "maximum_iterations=%d, relative_tolerance=%g",
+ solver_protocol["method"],
+ solver_protocol["options"]["maximum_iterations"],
+ self.relative_tolerance,
+ )
# set attributes
out = mbar.getFreeEnergyDifferences(return_theta=True)
return mbar, out
@@ -145,7 +158,7 @@ def overlap_matrix(self):
---------
pymbar.mbar.MBAR.computeOverlap
"""
- return self._mbar.computeOverlap()['matrix']
+ return self._mbar.computeOverlap()["matrix"]
class AutoMBAR(MBAR):
@@ -188,31 +201,42 @@ class AutoMBAR(MBAR):
.. versionchanged:: 1.0.0
AutoMBAR accepts the `method` argument.
"""
- def __init__(self, maximum_iterations=10000, relative_tolerance=1.0e-7,
- initial_f_k=None, verbose=False, method=None):
- super().__init__(maximum_iterations=maximum_iterations,
- relative_tolerance=relative_tolerance,
- initial_f_k=initial_f_k,
- verbose=verbose, method=method)
- self.logger = logging.getLogger('alchemlyb.estimators.AutoMBAR')
+
+ def __init__(
+ self,
+ maximum_iterations=10000,
+ relative_tolerance=1.0e-7,
+ initial_f_k=None,
+ verbose=False,
+ method=None,
+ ):
+ super().__init__(
+ maximum_iterations=maximum_iterations,
+ relative_tolerance=relative_tolerance,
+ initial_f_k=initial_f_k,
+ verbose=verbose,
+ method=method,
+ )
+ self.logger = logging.getLogger("alchemlyb.estimators.AutoMBAR")
def _do_MBAR(self, u_nk, N_k, solver_protocol):
if solver_protocol["method"] is None:
- self.logger.info('Initialise the automatic routine of the MBAR '
- 'estimator.')
+            self.logger.info("Initialise the automatic routine of the MBAR estimator.")
# Try the fastest method first
try:
- self.logger.info('Trying the hybr method.')
- solver_protocol["method"] = 'hybr'
+ self.logger.info("Trying the hybr method.")
+ solver_protocol["method"] = "hybr"
mbar, out = super()._do_MBAR(u_nk, N_k, solver_protocol)
except pymbar.utils.ParameterError:
try:
- self.logger.info('Trying the adaptive method.')
- solver_protocol["method"] = 'adaptive'
+ self.logger.info("Trying the adaptive method.")
+ solver_protocol["method"] = "adaptive"
mbar, out = super()._do_MBAR(u_nk, N_k, solver_protocol)
except pymbar.utils.ParameterError:
- self.logger.info('Trying the BFGS method.')
- solver_protocol["method"] = 'BFGS'
+ self.logger.info("Trying the BFGS method.")
+ solver_protocol["method"] = "BFGS"
mbar, out = super()._do_MBAR(u_nk, N_k, solver_protocol)
return mbar, out
else:
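AutoMBAR._do_MBAR above simply tries solvers in order of speed and falls back when pymbar raises ParameterError. The same control flow sketched as a loop with a placeholder solve() (not the real pymbar call, and written as a loop rather than the nested try/except used above):

import logging

logger = logging.getLogger("autombar.sketch")  # placeholder name, not the real logger

def solve(method):
    """Stand-in for pymbar.MBAR(...); pretend only 'BFGS' converges here."""
    if method != "BFGS":
        raise RuntimeError(f"{method} did not converge")
    return f"converged with {method}"

result = None
for method in ("hybr", "adaptive", "BFGS"):  # fastest solver first
    try:
        result = solve(method)
        break
    except RuntimeError:
        logger.info("Falling back from %s.", method)
print(result)  # converged with BFGS
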
diff --git a/src/alchemlyb/estimators/ti_.py b/src/alchemlyb/estimators/ti_.py
index e01b8f72..bef379cc 100644
--- a/src/alchemlyb/estimators/ti_.py
+++ b/src/alchemlyb/estimators/ti_.py
@@ -1,10 +1,10 @@
import numpy as np
import pandas as pd
-
from sklearn.base import BaseEstimator
from .base import _EstimatorMixOut
+
class TI(BaseEstimator, _EstimatorMixOut):
"""Thermodynamic integration (TI).
@@ -71,43 +71,49 @@ def fit(self, dHdl):
dl = means.reset_index()[means.index.names[:]].diff().iloc[1:].values
# apply trapezoid rule to obtain DF between each adjacent state
- deltas = (dl * (means.iloc[:-1].values + means.iloc[1:].values)/2).sum(axis=1)
+ deltas = (dl * (means.iloc[:-1].values + means.iloc[1:].values) / 2).sum(axis=1)
# build matrix of deltas between each state
- adelta = np.zeros((len(deltas)+1, len(deltas)+1))
+ adelta = np.zeros((len(deltas) + 1, len(deltas) + 1))
ad_delta = np.zeros_like(adelta)
for j in range(len(deltas)):
out = []
dout = []
for i in range(len(deltas) - j):
- out.append(deltas[i] + deltas[i+1:i+j+1].sum())
+ out.append(deltas[i] + deltas[i + 1 : i + j + 1].sum())
# Define additional zero lambda
a = [0.0] * len(l_types)
# Define dl series' with additional zero lambda on the left and right
- dll = np.insert(dl[i:i + j + 1], 0, [a], axis=0)
- dlr = np.append(dl[i:i + j + 1], [a], axis=0)
+ dll = np.insert(dl[i : i + j + 1], 0, [a], axis=0)
+ dlr = np.append(dl[i : i + j + 1], [a], axis=0)
# Get a series of the form: x1, x1 + x2, ..., x(n-1) + x(n), x(n)
dllr = dll + dlr
# Append deviation of free energy difference between state i and i+j+1
- dout.append((dllr ** 2 * variances.iloc[i:i + j + 2].values / 4).sum(axis=1).sum())
- adelta += np.diagflat(np.array(out), k=j+1)
- ad_delta += np.diagflat(np.array(dout), k=j+1)
+ dout.append(
+ (dllr**2 * variances.iloc[i : i + j + 2].values / 4)
+ .sum(axis=1)
+ .sum()
+ )
+ adelta += np.diagflat(np.array(out), k=j + 1)
+ ad_delta += np.diagflat(np.array(dout), k=j + 1)
# yield standard delta_f_ free energies between each state
- self._delta_f_ = pd.DataFrame(adelta - adelta.T,
- columns=means.index.values,
- index=means.index.values)
+ self._delta_f_ = pd.DataFrame(
+ adelta - adelta.T, columns=means.index.values, index=means.index.values
+ )
self.dhdl = means
# yield standard deviation d_delta_f_ between each state
- self._d_delta_f_ = pd.DataFrame(np.sqrt(ad_delta + ad_delta.T),
- columns=variances.index.values,
- index=variances.index.values)
+ self._d_delta_f_ = pd.DataFrame(
+ np.sqrt(ad_delta + ad_delta.T),
+ columns=variances.index.values,
+ index=variances.index.values,
+ )
self._states_ = means.index.values.tolist()
@@ -135,7 +141,9 @@ def separate_dhdl(self):
"""
if len(self.dhdl.index.names) == 1:
name = self.dhdl.columns[0]
- return [self.dhdl[name], ]
+ return [
+ self.dhdl[name],
+ ]
dhdl_list = []
# get the lambda names
l_types = self.dhdl.index.names
@@ -143,14 +151,14 @@ def separate_dhdl(self):
# Fix issue #148, where for pandas == 1.3.0
# lambdas = self.dhdl.reset_index()[list(l_types)]
lambdas = self.dhdl.reset_index()[l_types]
- diff = lambdas.diff().to_numpy(dtype='bool')
+ diff = lambdas.diff().to_numpy(dtype="bool")
# diff will give the first row as NaN so need to fix that
diff[0, :] = diff[1, :]
# Make sure that the start point is set to true as well
diff[:-1, :] = diff[:-1, :] | diff[1:, :]
for i in range(len(l_types)):
- if any(diff[:,i]):
- new = self.dhdl.iloc[diff[:,i], i]
+ if any(diff[:, i]):
+ new = self.dhdl.iloc[diff[:, i], i]
# drop all other index
for l in l_types:
if l != l_types[i]:
@@ -158,4 +166,3 @@ def separate_dhdl(self):
new.attrs = self.dhdl.attrs
dhdl_list.append(new)
return dhdl_list
-
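The TI estimator applies the trapezoid rule to the mean gradients: each free energy step is the lambda spacing times the average of the two adjacent mean dH/dl values. A small standalone check against numpy.trapz with made-up gradients:

import numpy as np

lambdas = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
mean_dhdl = np.array([10.0, 8.0, 6.0, 5.0, 4.0])  # toy <dH/dl> per lambda, in kT

dl = np.diff(lambdas)
deltas = dl * (mean_dhdl[:-1] + mean_dhdl[1:]) / 2  # per-interval free energy step
assert np.isclose(deltas.sum(), np.trapz(mean_dhdl, lambdas))
print(deltas.sum())  # 6.5 kT from lambda=0 to lambda=1
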
diff --git a/src/alchemlyb/parsing/__init__.py b/src/alchemlyb/parsing/__init__.py
index 60165ac9..dc048732 100644
--- a/src/alchemlyb/parsing/__init__.py
+++ b/src/alchemlyb/parsing/__init__.py
@@ -1,33 +1,38 @@
from functools import wraps
+
def _init_attrs(func):
- '''Add temperature to the parsed dataframe.
+ """Add temperature to the parsed dataframe.
The temperature is added to the dataframe as dataframe.attrs['temperature']
and the energy unit is initiated as dataframe.attrs['energy_unit'] = 'kT'.
- '''
+ """
+
@wraps(func)
def wrapper(outfile, T, *args, **kwargs):
dataframe = func(outfile, T, *args, **kwargs)
if dataframe is not None:
- dataframe.attrs['temperature'] = T
- dataframe.attrs['energy_unit'] = 'kT'
+ dataframe.attrs["temperature"] = T
+ dataframe.attrs["energy_unit"] = "kT"
return dataframe
+
return wrapper
def _init_attrs_dict(func):
- '''Add temperature and energy units to the parsed dataframes.
+ """Add temperature and energy units to the parsed dataframes.
The temperature is added to the dataframe as dataframe.attrs['temperature']
and the energy unit is initiated as dataframe.attrs['energy_unit'] = 'kT'.
- '''
+ """
+
@wraps(func)
def wrapper(outfile, T, *args, **kwargs):
dict_with_df = func(outfile, T, *args, **kwargs)
for k in dict_with_df.keys():
if dict_with_df[k] is not None:
- dict_with_df[k].attrs['temperature'] = T
- dict_with_df[k].attrs['energy_unit'] = 'kT'
+ dict_with_df[k].attrs["temperature"] = T
+ dict_with_df[k].attrs["energy_unit"] = "kT"
return dict_with_df
+
return wrapper
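The decorators above only attach temperature and energy-unit metadata to whatever the wrapped parser returns. A minimal usage sketch, assuming alchemlyb is importable; dummy_parser is invented purely for illustration:

import pandas as pd
from alchemlyb.parsing import _init_attrs  # the decorator defined above

@_init_attrs
def dummy_parser(outfile, T):
    # hypothetical parser used only for illustration; ignores the file entirely
    return pd.DataFrame({"dHdl": [0.1, 0.2]})

df = dummy_parser("ignored.out", 300)
print(df.attrs)  # {'temperature': 300, 'energy_unit': 'kT'}
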
diff --git a/src/alchemlyb/parsing/amber.py b/src/alchemlyb/parsing/amber.py
index 8e23ada1..d129a064 100644
--- a/src/alchemlyb/parsing/amber.py
+++ b/src/alchemlyb/parsing/amber.py
@@ -11,21 +11,21 @@
"""
-import re
import logging
+import re
-import pandas as pd
import numpy as np
+import pandas as pd
-from .util import anyopen
from . import _init_attrs_dict
+from .util import anyopen
from ..postprocessors.units import R_kJmol, kJ2kcal
logger = logging.getLogger("alchemlyb.parsers.Amber")
k_b = R_kJmol * kJ2kcal
-_FP_RE = r'[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?'
+_FP_RE = r"[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?"
def convert_to_pandas(file_datum):
@@ -39,10 +39,13 @@ def convert_to_pandas(file_datum):
data_dic["lambdas"].append(file_datum.clambda)
frame_time = file_datum.t0 + (frame_index + 1) * file_datum.dt * file_datum.ntpr
data_dic["time"].append(frame_time)
- df = pd.DataFrame(data_dic["dHdl"], columns=["dHdl"],
- index=pd.Index(data_dic["time"], name='time', dtype='Float64'))
+ df = pd.DataFrame(
+ data_dic["dHdl"],
+ columns=["dHdl"],
+ index=pd.Index(data_dic["time"], name="time", dtype="Float64"),
+ )
df["lambdas"] = data_dic["lambdas"][0]
- df = df.reset_index().set_index(['time'] + ['lambdas'])
+ df = df.reset_index().set_index(["time"] + ["lambdas"])
return df
@@ -59,7 +62,7 @@ def _pre_gen(it, first):
return
-class SectionParser():
+class SectionParser:
"""
A simple parser to extract data values from sections.
"""
@@ -68,7 +71,7 @@ def __init__(self, filename):
"""Opens a file according to its file type."""
self.filename = filename
try:
- self.fileh = anyopen(self.filename, 'r')
+ self.fileh = anyopen(self.filename, "r")
except:
logger.exception("Cannot open file %s", filename)
raise
@@ -93,7 +96,7 @@ def skip_after(self, pattern):
break
return Found_pattern
- def extract_section(self, start, end, fields, limit=None, extra=''):
+ def extract_section(self, start, end, fields, limit=None, extra=""):
"""
Extract data values (int, float) in fields from a section
marked with start and end regexes. Do not read further than
@@ -109,15 +112,15 @@ def extract_section(self, start, end, fields, limit=None, extra=''):
if inside:
if re.search(end, line):
break
- lines.append(line.rstrip('\n'))
- line = ''.join(lines)
+ lines.append(line.rstrip("\n"))
+ line = "".join(lines)
result = []
for field in fields:
- match = re.search(fr' {field}\s*=\s*(\*+|{_FP_RE}|\d+)', line)
+ match = re.search(rf" {field}\s*=\s*(\*+|{_FP_RE}|\d+)", line)
if match:
value = match.group(1)
- if '*' in value: # catch fortran format overflow
- result.append(float('Inf'))
+ if "*" in value: # catch fortran format overflow
+ result.append(float("Inf"))
else:
try:
result.append(int(value))
@@ -146,12 +149,21 @@ def __exit__(self, typ, value, traceback):
self.close()
-class FEData():
+class FEData:
"""A simple struct container to collect data from individual files."""
- __slots__ = ['clambda', 't0', 'dt', 'T', 'ntpr', 'gradients',
- 'mbar_energies',
- 'have_mbar', 'mbar_lambdas', 'mbar_lambda_idx']
+ __slots__ = [
+ "clambda",
+ "t0",
+ "dt",
+ "T",
+ "ntpr",
+ "gradients",
+ "mbar_energies",
+ "have_mbar",
+ "mbar_lambdas",
+ "mbar_lambda_idx",
+ ]
def __init__(self):
self.clambda = -1.0
@@ -170,7 +182,7 @@ def file_validation(outfile):
"""
    Function that validates and parses an AMBER output file.
    A :exc:`ValueError` is raised if inconsistencies in the input file are found.
-
+
Parameters
----------
outfile : str
@@ -189,76 +201,81 @@ def file_validation(outfile):
if not line:
logger.error("The file %s does not contain any data, it's empty.", outfile)
- raise ValueError(f'file {outfile} does not contain any data.')
+ raise ValueError(f"file {outfile} does not contain any data.")
- if not secp.skip_after('^ 2. CONTROL DATA FOR THE RUN'):
+ if not secp.skip_after("^ 2. CONTROL DATA FOR THE RUN"):
logger.error('No "CONTROL DATA" section found in file %s.', outfile)
raise ValueError(f'no "CONTROL DATA" section found in file {outfile}')
- ntpr, = secp.extract_section('^Nature and format of output:', '^$',
- ['ntpr'])
- nstlim, dt = secp.extract_section('Molecular dynamics:', '^$',
- ['nstlim', 'dt'])
- T, = secp.extract_section('temperature regulation:', '^$',
- ['temp0'])
+ (ntpr,) = secp.extract_section("^Nature and format of output:", "^$", ["ntpr"])
+ nstlim, dt = secp.extract_section("Molecular dynamics:", "^$", ["nstlim", "dt"])
+ (T,) = secp.extract_section("temperature regulation:", "^$", ["temp0"])
if not T:
logger.error('No valid "temp0" record found in file %s.', outfile)
raise ValueError(f'no valid "temp0" record found in file {outfile}')
- clambda, = secp.extract_section('^Free energy options:', '^$',
- ['clambda'], '^---')
+ (clambda,) = secp.extract_section(
+ "^Free energy options:", "^$", ["clambda"], "^---"
+ )
if clambda is None:
- logger.error('No free energy section found in file %s, "clambda" was None.', outfile)
- raise ValueError(f'no free energy section found in file {outfile}')
+ logger.error(
+ 'No free energy section found in file %s, "clambda" was None.', outfile
+ )
+ raise ValueError(f"no free energy section found in file {outfile}")
mbar_ndata = 0
- have_mbar, mbar_ndata, mbar_states = secp.extract_section('^FEP MBAR options:',
- '^$',
- ['ifmbar',
- 'bar_intervall',
- 'mbar_states'],
- '^---')
+ have_mbar, mbar_ndata, mbar_states = secp.extract_section(
+ "^FEP MBAR options:",
+ "^$",
+ ["ifmbar", "bar_intervall", "mbar_states"],
+ "^---",
+ )
if have_mbar:
mbar_ndata = int(nstlim / mbar_ndata)
mbar_lambdas = _process_mbar_lambdas(secp)
file_datum.mbar_lambdas = mbar_lambdas
- clambda_str = f'{clambda:6.4f}'
+ clambda_str = f"{clambda:6.4f}"
if clambda_str not in mbar_lambdas:
- logger.warning('WARNING: lamba %s not contained in set of '
- 'MBAR lambas: %s\nNot using MBAR.',
- clambda_str, ', '.join(mbar_lambdas))
+ logger.warning(
+ "WARNING: lamba %s not contained in set of "
+ "MBAR lambas: %s\nNot using MBAR.",
+ clambda_str,
+ ", ".join(mbar_lambdas),
+ )
have_mbar = False
else:
mbar_nlambda = len(mbar_lambdas)
if mbar_nlambda != mbar_states:
logger.error(
- 'the number of lambda windows read (%s)'
- 'is different from what expected (%d)',
- ','.join(mbar_lambdas), mbar_states)
+ "the number of lambda windows read (%s)"
+ "is different from what expected (%d)",
+ ",".join(mbar_lambdas),
+ mbar_states,
+ )
raise ValueError(
- f'the number of lambda windows read ({mbar_nlambda})'
- f' is different from what expected ({mbar_states})')
+ f"the number of lambda windows read ({mbar_nlambda})"
+ f" is different from what expected ({mbar_states})"
+ )
mbar_lambda_idx = mbar_lambdas.index(clambda_str)
file_datum.mbar_lambda_idx = mbar_lambda_idx
for _ in range(mbar_nlambda):
file_datum.mbar_energies.append([])
- if not secp.skip_after('^ 3. ATOMIC '):
+ if not secp.skip_after("^ 3. ATOMIC "):
logger.error('No "ATOMIC" section found in the file %s.', outfile)
raise ValueError(f'no "ATOMIC" section found in file {outfile}')
- t0, = secp.extract_section('^ begin time', '^$', ['coords'])
+ (t0,) = secp.extract_section("^ begin time", "^$", ["coords"])
if t0 is None:
- logger.error('No starting simulation time in file %s.', outfile)
- raise ValueError(f'No starting simulation time in file {outfile}')
+ logger.error("No starting simulation time in file %s.", outfile)
+ raise ValueError(f"No starting simulation time in file {outfile}")
- if not secp.skip_after('^ 4. RESULTS'):
+ if not secp.skip_after("^ 4. RESULTS"):
logger.error('No "RESULTS" section found in the file %s.', outfile)
raise ValueError(f'no "RESULTS" section found in file {outfile}')
-
file_datum.clambda = clambda
file_datum.t0 = t0
file_datum.dt = dt
@@ -293,13 +310,13 @@ def extract(outfile, T):
"""
- beta = 1/(k_b * T)
+ beta = 1 / (k_b * T)
file_datum = file_validation(outfile)
if not np.isclose(T, file_datum.T, atol=0.01):
- msg = f'The temperature read from the input file ({file_datum.T:.2f} K)'
- msg += f' is different from the temperature passed as parameter ({T:.2f} K)'
+ msg = f"The temperature read from the input file ({file_datum.T:.2f} K)"
+ msg += f" is different from the temperature passed as parameter ({T:.2f} K)"
logger.error(msg)
raise ValueError(msg)
@@ -311,18 +328,19 @@ def extract(outfile, T):
old_nstep = -1
for line in secp:
if " A V E R A G E S O V E R" in line:
- _ = secp.skip_after('^|=========================================')
- elif line.startswith(' NSTEP'):
- nstep, dvdl = secp.extract_section('^ NSTEP', '^ ---',
- ['NSTEP', 'DV/DL'],
- extra=line)
+ _ = secp.skip_after("^|=========================================")
+ elif line.startswith(" NSTEP"):
+ nstep, dvdl = secp.extract_section(
+ "^ NSTEP", "^ ---", ["NSTEP", "DV/DL"], extra=line
+ )
if nstep != old_nstep and dvdl is not None and nstep is not None:
file_datum.gradients.append(dvdl)
nensec += 1
old_nstep = nstep
- elif line.startswith('MBAR Energy analysis') and file_datum.have_mbar:
- mbar = secp.extract_section('^MBAR', '^ ---', file_datum.mbar_lambdas,
- extra=line)
+ elif line.startswith("MBAR Energy analysis") and file_datum.have_mbar:
+ mbar = secp.extract_section(
+ "^MBAR", "^ ---", file_datum.mbar_lambdas, extra=line
+ )
if None in mbar:
msg = "Something strange parsing the following MBAR section."
@@ -335,40 +353,48 @@ def extract(outfile, T):
if energy > 0.0:
high_E_cnt += 1
- file_datum.mbar_energies[lmbda].append(beta * (energy - reference_energy))
- elif line == ' 5. TIMINGS\n':
+ file_datum.mbar_energies[lmbda].append(
+ beta * (energy - reference_energy)
+ )
+ elif line == " 5. TIMINGS\n":
finished = True
break
if high_E_cnt:
- logger.warning('%i MBAR energ%s > 0.0 kcal/mol',
- high_E_cnt, 'ies are' if high_E_cnt > 1 else 'y is')
+ logger.warning(
+ "%i MBAR energ%s > 0.0 kcal/mol",
+ high_E_cnt,
+ "ies are" if high_E_cnt > 1 else "y is",
+ )
if not finished:
- logger.warning('WARNING: file %s is a prematurely terminated run', outfile)
+ logger.warning("WARNING: file %s is a prematurely terminated run", outfile)
if file_datum.have_mbar:
mbar_time = [
file_datum.t0 + (frame_index + 1) * file_datum.dt * file_datum.ntpr
- for frame_index in range(len(file_datum.mbar_energies[0]))]
+ for frame_index in range(len(file_datum.mbar_energies[0]))
+ ]
mbar_df = pd.DataFrame(
file_datum.mbar_energies,
index=np.array(file_datum.mbar_lambdas, dtype=np.float64),
columns=pd.MultiIndex.from_arrays(
- [mbar_time, np.repeat(file_datum.clambda, len(mbar_time))], names=['time', 'lambdas'])
- ).T
+ [mbar_time, np.repeat(file_datum.clambda, len(mbar_time))],
+ names=["time", "lambdas"],
+ ),
+ ).T
else:
logger.info('WARNING: No MBAR energies found! "u_nk" entry will be None')
mbar_df = None
if not nensec:
- logger.warning('WARNING: File %s does not contain any dV/dl data', outfile)
+ logger.warning("WARNING: File %s does not contain any dV/dl data", outfile)
dHdl_df = None
else:
- logger.info('Read %s dV/dl data points in file %s', nensec, outfile)
+ logger.info("Read %s dV/dl data points in file %s", nensec, outfile)
dHdl_df = convert_to_pandas(file_datum)
- dHdl_df['dHdl'] *= beta
+ dHdl_df["dHdl"] *= beta
return {"u_nk": mbar_df, "dHdl": dHdl_df}
@@ -395,7 +421,7 @@ def extract_dHdl(outfile, T):
"""
extracted = extract(outfile, T)
- return extracted['dHdl']
+ return extracted["dHdl"]
def extract_u_nk(outfile, T):
@@ -421,7 +447,7 @@ def extract_u_nk(outfile, T):
"""
extracted = extract(outfile, T)
- return extracted['u_nk']
+ return extracted["u_nk"]
def _process_mbar_lambdas(secp):
@@ -441,15 +467,15 @@ def _process_mbar_lambdas(secp):
mbar_lambdas = []
for line in secp:
- if line.startswith(' MBAR - lambda values considered:'):
+ if line.startswith(" MBAR - lambda values considered:"):
in_mbar = True
continue
if in_mbar:
- if line.startswith(' Extra'):
+ if line.startswith(" Extra"):
break
- if 'total' in line:
+ if "total" in line:
data = line.split()
mbar_lambdas.extend(data[2:])
else:
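SectionParser.extract_section above joins a section into one string and pulls `field = value` pairs with the _FP_RE float regex, mapping runs of '*' (Fortran format overflow) to infinity. A standalone illustration on a fabricated record line:

import re

_FP_RE = r"[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?"
line = " NSTEP =     1000  DV/DL  =  -12.3456  OVERFLOW = ********"  # fabricated record

result = []
for field in ("NSTEP", "DV/DL", "OVERFLOW"):
    match = re.search(rf" {field}\s*=\s*(\*+|{_FP_RE}|\d+)", line)
    if match:
        value = match.group(1)
        if "*" in value:  # catch Fortran format overflow
            result.append(float("Inf"))
        else:
            try:
                result.append(int(value))
            except ValueError:
                result.append(float(value))
print(result)  # [1000, -12.3456, inf]
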
diff --git a/src/alchemlyb/parsing/gmx.py b/src/alchemlyb/parsing/gmx.py
index 00267c66..a9f83498 100644
--- a/src/alchemlyb/parsing/gmx.py
+++ b/src/alchemlyb/parsing/gmx.py
@@ -1,15 +1,16 @@
"""Parsers for extracting alchemical data from `Gromacs `_ output files.
"""
-import pandas as pd
import numpy as np
+import pandas as pd
-from .util import anyopen
from . import _init_attrs
+from .util import anyopen
from ..postprocessors.units import R_kJmol
k_b = R_kJmol
+
@_init_attrs
def extract_u_nk(xvg, T, filter=True):
r"""Return reduced potentials `u_nk` from a Hamiltonian differences XVG file.
@@ -60,9 +61,9 @@ def extract_u_nk(xvg, T, filter=True):
"""
h_col_match = r"\xD\f{}H \xl\f{}"
- pv_col_match = 'pV'
- u_col_match = ['Total Energy', 'Potential Energy']
- beta = 1/(k_b * T)
+ pv_col_match = "pV"
+ u_col_match = ["Total Energy", "Potential Energy"]
+ beta = 1 / (k_b * T)
state, lambdas, statevec = _extract_state(xvg)
@@ -82,7 +83,11 @@ def extract_u_nk(xvg, T, filter=True):
pv = df[pv_cols[0]]
# gromacs also gives us total/potential energy U directly; need this for reduced potential
- u_cols = [col for col in df.columns if any(single_u_col_match in col for single_u_col_match in u_col_match)]
+ u_cols = [
+ col
+ for col in df.columns
+ if any(single_u_col_match in col for single_u_col_match in u_col_match)
+ ]
u = None
if u_cols:
u = df[u_cols[0]]
@@ -90,7 +95,7 @@ def extract_u_nk(xvg, T, filter=True):
u_k = dict()
cols = list()
for col in dH:
- u_col = eval(col.split('to')[1])
+ u_col = eval(col.split("to")[1])
# calculate reduced potential u_k = dH + pV + U
u_k[u_col] = beta * dH[col].values
if pv_cols:
@@ -99,8 +104,9 @@ def extract_u_nk(xvg, T, filter=True):
u_k[u_col] += beta * u.values
cols.append(u_col)
- u_k = pd.DataFrame(u_k, columns=cols,
- index=pd.Index(times.values, name='time', dtype='Float64'))
+ u_k = pd.DataFrame(
+ u_k, columns=cols, index=pd.Index(times.values, name="time", dtype="Float64")
+ )
# create columns for each lambda, indicating state each row sampled from
# if state is None run as expanded ensemble data or REX
@@ -108,8 +114,8 @@ def extract_u_nk(xvg, T, filter=True):
# if thermodynamic state is specified map thermodynamic
# state data to lambda values, else (for REX)
# define state based on the legend
- if 'Thermodynamic state' in df:
- ts_index = df.columns.get_loc('Thermodynamic state')
+ if "Thermodynamic state" in df:
+ ts_index = df.columns.get_loc("Thermodynamic state")
thermo_state = df[df.columns[ts_index]]
for i, l in enumerate(lambdas):
v = []
@@ -128,13 +134,14 @@ def extract_u_nk(xvg, T, filter=True):
u_k[l] = statevec
# set up new multi-index
- newind = ['time'] + lambdas
+ newind = ["time"] + lambdas
u_k = u_k.reset_index().set_index(newind)
- u_k.name = 'u_nk'
+ u_k.name = "u_nk"
return u_k
+
@_init_attrs
def extract_dHdl(xvg, T, filter=True):
r"""Return gradients `dH/dl` from a Hamiltonian differences XVG file.
@@ -182,7 +189,7 @@ def extract_dHdl(xvg, T, filter=True):
parsed and is turned on by default.
"""
- beta = 1/(k_b * T)
+ beta = 1 / (k_b * T)
headers = _get_headers(xvg)
state, lambdas, statevec = _extract_state(xvg, headers)
@@ -204,10 +211,13 @@ def extract_dHdl(xvg, T, filter=True):
# rename columns to not include the word 'lambda', since we use this for
# index below
- cols = [l.split('-')[0] for l in lambdas]
+ cols = [l.split("-")[0] for l in lambdas]
- dHdl = pd.DataFrame(dHdl.values, columns=cols,
- index=pd.Index(times.values, name='time', dtype='Float64'))
+ dHdl = pd.DataFrame(
+ dHdl.values,
+ columns=cols,
+ index=pd.Index(times.values, name="time", dtype="Float64"),
+ )
# create columns for each lambda, indicating state each row sampled from
# if state is None run as expanded ensemble data or REX
@@ -215,8 +225,8 @@ def extract_dHdl(xvg, T, filter=True):
# if thermodynamic state is specified map thermodynamic
# state data to lambda values, else (for REX)
# define state based on the legend
- if 'Thermodynamic state' in df:
- ts_index = df.columns.get_loc('Thermodynamic state')
+ if "Thermodynamic state" in df:
+ ts_index = df.columns.get_loc("Thermodynamic state")
thermo_state = df[df.columns[ts_index]]
for i, l in enumerate(lambdas):
v = []
@@ -235,10 +245,10 @@ def extract_dHdl(xvg, T, filter=True):
dHdl[l] = statevec
# set up new multi-index
- newind = ['time'] + lambdas
- dHdl= dHdl.reset_index().set_index(newind)
+ newind = ["time"] + lambdas
+ dHdl = dHdl.reset_index().set_index(newind)
- dHdl.name='dH/dl'
+ dHdl.name = "dH/dl"
return dHdl
@@ -289,34 +299,44 @@ def _extract_state(xvg, headers=None):
state = None
if headers is None:
headers = _get_headers(xvg)
- subtitle = _get_value_by_key(headers, 'subtitle')
- if subtitle and 'state' in subtitle:
- state = int(subtitle.split('state')[1].split(':')[0])
- lambdas = [word.strip(')(,') for word in subtitle.split() if 'lambda' in word]
- statevec = eval(subtitle.strip().split(' = ')[-1].strip('"'))
+ subtitle = _get_value_by_key(headers, "subtitle")
+ if subtitle and "state" in subtitle:
+ state = int(subtitle.split("state")[1].split(":")[0])
+ lambdas = [word.strip(")(,") for word in subtitle.split() if "lambda" in word]
+ statevec = eval(subtitle.strip().split(" = ")[-1].strip('"'))
# if expanded ensemble data is used the state variable will never be assigned
# parsing expanded ensemble data
if state is None:
lambdas = []
statevec = []
- for line in headers['_raw_lines']:
- if ('legend' in line) and ('lambda' in line):
- lambdas.append([word.strip(')(,') for word in line.split() if 'lambda' in word][0])
- if ('legend' in line) and (' to ' in line):
- statevec.append(([float(i) for i in line.strip().split(' to ')[-1].strip('"()').split(',')]))
+ for line in headers["_raw_lines"]:
+ if ("legend" in line) and ("lambda" in line):
+ lambdas.append(
+ [word.strip(")(,") for word in line.split() if "lambda" in word][0]
+ )
+ if ("legend" in line) and (" to " in line):
+ statevec.append(
+ (
+ [
+ float(i)
+ for i in line.strip()
+ .split(" to ")[-1]
+ .strip('"()')
+ .split(",")
+ ]
+ )
+ )
return state, lambdas, statevec
def _extract_legend(xvg):
- """Extract information on state sampled for REX simulations.
-
- """
+ """Extract information on state sampled for REX simulations."""
state_legend = {}
- with anyopen(xvg, 'r') as f:
+ with anyopen(xvg, "r") as f:
for line in f:
- if ('legend' in line) and ('lambda' in line):
+ if ("legend" in line) and ("lambda" in line):
state_legend[line.split()[4]] = float(line.split()[6].strip('"'))
return state_legend
@@ -344,31 +364,49 @@ def _extract_dataframe(xvg, headers=None, filter=filter):
if headers is None:
headers = _get_headers(xvg)
- xaxis = _get_value_by_key(headers, 'xaxis', 'label')
- names = [_get_value_by_key(headers, 's{}'.format(x), 'legend') for x in
- range(len(headers)) if 's{}'.format(x) in headers]
+ xaxis = _get_value_by_key(headers, "xaxis", "label")
+ names = [
+ _get_value_by_key(headers, "s{}".format(x), "legend")
+ for x in range(len(headers))
+ if "s{}".format(x) in headers
+ ]
cols = [xaxis] + names
# march through column names, mark duplicates when found
- cols = [col + "{}[duplicated]".format(i) if col in cols[:i] else col
- for i, col, in enumerate(cols)]
+ cols = [
+ col + "{}[duplicated]".format(i) if col in cols[:i] else col
+ for i, col, in enumerate(cols)
+ ]
- header_cnt = len(headers['_raw_lines'])
+ header_cnt = len(headers["_raw_lines"])
if not filter:
# assumes clean files
- df = pd.read_csv(xvg, sep=r"\s+", header=None, skiprows=header_cnt,
- na_filter=True, memory_map=True, names=cols,
- dtype=np.float64,
- float_precision='high')
+ df = pd.read_csv(
+ xvg,
+ sep=r"\s+",
+ header=None,
+ skiprows=header_cnt,
+ na_filter=True,
+ memory_map=True,
+ names=cols,
+ dtype=np.float64,
+ float_precision="high",
+ )
else:
- df = pd.read_csv(xvg, sep=r"\s+", header=None, skiprows=header_cnt,
- memory_map=True, on_bad_lines='skip')
+ df = pd.read_csv(
+ xvg,
+ sep=r"\s+",
+ header=None,
+ skiprows=header_cnt,
+ memory_map=True,
+ on_bad_lines="skip",
+ )
# If names=cols is passed to read_csv, rows with more than the
# designated columns will be truncated and used instead of discarded.
df.rename(columns={i: name for i, name in enumerate(cols)}, inplace=True)
# If dtype=np.float64 and float_precision='high' are passed to read_csv,
# 12.345.56 and - cannot be read.
- df = df.apply(pd.to_numeric, errors='coerce')
+ df = df.apply(pd.to_numeric, errors="coerce")
# drop duplicate
df.dropna(inplace=True)
@@ -423,7 +461,7 @@ def _parse_header(line, headers={}, depth=2):
else:
break
- next_t["_val"] = ''.join(s[1:]).rstrip().strip('"')
+ next_t["_val"] = "".join(s[1:]).rstrip().strip('"')
def _get_headers(xvg):
@@ -484,17 +522,17 @@ def _get_headers(xvg):
headers: dict
"""
- with anyopen(xvg, 'r') as f:
- headers = { '_raw_lines': [] }
+ with anyopen(xvg, "r") as f:
+ headers = {"_raw_lines": []}
for line in f:
line = line.strip()
if len(line) == 0:
continue
- if line.startswith('@'):
+ if line.startswith("@"):
_parse_header(line, headers)
- headers['_raw_lines'].append(line)
- elif line.startswith('#'):
- headers['_raw_lines'].append(line)
+ headers["_raw_lines"].append(line)
+ elif line.startswith("#"):
+ headers["_raw_lines"].append(line)
continue
# assuming to start a body section
else:
@@ -522,8 +560,8 @@ def _get_value_by_key(headers, key1, key2=None):
val = None
if key1 in headers:
if key2 is not None and key2 in headers[key1]:
- val = headers[key1][key2]['_val']
+ val = headers[key1][key2]["_val"]
else:
- val = headers[key1]['_val']
+ val = headers[key1]["_val"]
return val
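The list comprehension in _extract_dataframe above marks repeated legend names so pandas gets unique columns; inside the comprehension, cols[:i] still refers to the original list (the re-binding happens only afterwards), so only later repeats get the suffix. In isolation, with invented column names:

cols = ["time", "dH", "pV", "dH"]  # invented legend names containing a duplicate
cols = [
    col + "{}[duplicated]".format(i) if col in cols[:i] else col
    for i, col in enumerate(cols)
]
print(cols)  # ['time', 'dH', 'pV', 'dH3[duplicated]']
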
diff --git a/src/alchemlyb/parsing/gomc.py b/src/alchemlyb/parsing/gomc.py
index 7cf03af4..90124687 100644
--- a/src/alchemlyb/parsing/gomc.py
+++ b/src/alchemlyb/parsing/gomc.py
@@ -3,12 +3,13 @@
"""
import pandas as pd
-from .util import anyopen
from . import _init_attrs
+from .util import anyopen
from ..postprocessors.units import R_kJmol
k_b = R_kJmol
+
@_init_attrs
def extract_u_nk(filename, T):
"""Return reduced potentials `u_nk` from a Hamiltonian differences dat file.
@@ -34,9 +35,9 @@ def extract_u_nk(filename, T):
dh_col_match = "dU/dL"
h_col_match = "DelE"
- pv_col_match = 'PV'
- u_col_match = ['Total_En']
- beta = 1/(k_b * T)
+ pv_col_match = "PV"
+ u_col_match = ["Total_En"]
+ beta = 1 / (k_b * T)
state, lambdas, statevec = _extract_state(filename)
@@ -56,7 +57,11 @@ def extract_u_nk(filename, T):
pv = df[pv_cols[0]]
# GOMC also gives us total energy U directly; need this for reduced potential
- u_cols = [col for col in df.columns if any(single_u_col_match in col for single_u_col_match in u_col_match)]
+ u_cols = [
+ col
+ for col in df.columns
+ if any(single_u_col_match in col for single_u_col_match in u_col_match)
+ ]
u = None
if u_cols:
u = df[u_cols[0]]
@@ -64,7 +69,7 @@ def extract_u_nk(filename, T):
u_k = dict()
cols = list()
for col in dH:
- u_col = eval(col.split('->')[1][:-1])
+ u_col = eval(col.split("->")[1][:-1])
# calculate reduced potential u_k = dH + pV + U
u_k[u_col] = beta * dH[col].values
if pv_cols:
@@ -73,8 +78,9 @@ def extract_u_nk(filename, T):
u_k[u_col] += beta * u.values
cols.append(u_col)
- u_k = pd.DataFrame(u_k, columns=cols,
- index=pd.Index(times.values, name='time', dtype='Float64'))
+ u_k = pd.DataFrame(
+ u_k, columns=cols, index=pd.Index(times.values, name="time", dtype="Float64")
+ )
# Need to modify the lambda name
cols = [l + "-lambda" for l in lambdas]
@@ -83,13 +89,14 @@ def extract_u_nk(filename, T):
u_k[l] = statevec[i]
# set up new multi-index
- newind = ['time'] + cols
+ newind = ["time"] + cols
u_k = u_k.reset_index().set_index(newind)
- u_k.name = 'u_nk'
+ u_k.name = "u_nk"
return u_k
+
@_init_attrs
def extract_dHdl(filename, T):
"""Return gradients `dH/dl` from a Hamiltonian differences free energy file.
@@ -112,7 +119,7 @@ def extract_dHdl(filename, T):
the constants used by the corresponding MD engine.
"""
- beta = 1/(k_b * T)
+ beta = 1 / (k_b * T)
state, lambdas, statevec = _extract_state(filename)
@@ -131,8 +138,11 @@ def extract_dHdl(filename, T):
# make dimensionless
dHdl *= beta
- dHdl = pd.DataFrame(dHdl.values, columns=lambdas,
- index=pd.Index(times.values, name='time', dtype='Float64'))
+ dHdl = pd.DataFrame(
+ dHdl.values,
+ columns=lambdas,
+ index=pd.Index(times.values, name="time", dtype="Float64"),
+ )
# Need to modify the lambda name
cols = [l + "-lambda" for l in lambdas]
@@ -141,10 +151,10 @@ def extract_dHdl(filename, T):
dHdl[l] = statevec[i]
# set up new multi-index
- newind = ['time'] + cols
- dHdl= dHdl.reset_index().set_index(newind)
+ newind = ["time"] + cols
+ dHdl = dHdl.reset_index().set_index(newind)
- dHdl.name='dH/dl'
+ dHdl.name = "dH/dl"
return dHdl
@@ -180,33 +190,29 @@ def extract(filename, T):
def _extract_state(filename):
- """Extract information on state sampled, names of lambdas.
-
- """
+ """Extract information on state sampled, names of lambdas."""
state = None
- with anyopen(filename, 'r') as f:
+ with anyopen(filename, "r") as f:
for line in f:
- if ('#' in line) and ('State' in line):
- state = int(line.split('State')[1].split(':')[0])
+ if ("#" in line) and ("State" in line):
+ state = int(line.split("State")[1].split(":")[0])
# GOMC always print these two fields
- lambdas = ['Coulomb', 'VDW']
- statevec = eval(line.strip().split(' = ')[-1])
+ lambdas = ["Coulomb", "VDW"]
+ statevec = eval(line.strip().split(" = ")[-1])
break
return state, lambdas, statevec
def _extract_dataframe(filename):
- """Extract a DataFrame from free energy data.
-
- """
+ """Extract a DataFrame from free energy data."""
dh_col_match = "dU/dL"
h_col_match = "DelE"
- pv_col_match = 'PV'
- u_col_match = 'Total_En'
+ pv_col_match = "PV"
+ u_col_match = "Total_En"
xaxis = "time"
- with anyopen(filename, 'r') as f:
+ with anyopen(filename, "r") as f:
names = []
rows = []
for line in f:
@@ -214,7 +220,7 @@ def _extract_dataframe(filename):
if len(line) == 0:
# avoid parsing empty line
continue
- elif line.startswith('#T'):
+ elif line.startswith("#T"):
# this line has state information. No need to be parsed
continue
elif line.startswith("#Steps"):
diff --git a/src/alchemlyb/parsing/namd.py b/src/alchemlyb/parsing/namd.py
index c4181b1d..1647467c 100644
--- a/src/alchemlyb/parsing/namd.py
+++ b/src/alchemlyb/parsing/namd.py
@@ -1,13 +1,15 @@
"""Parsers for extracting alchemical data from `NAMD `_ output files.
"""
-import pandas as pd
-import numpy as np
+import logging
from os.path import basename
from re import split
-import logging
-from .util import anyopen
+
+import numpy as np
+import pandas as pd
+
from . import _init_attrs
+from .util import anyopen
from ..postprocessors.units import R_kJmol, kJ2kcal
logger = logging.getLogger("alchemlyb.parsers.NAMD")
@@ -21,12 +23,12 @@ def _filename_sort_key(s):
This means that unlike with the standard Python sorted() function, "foo9" < "foo10".
"""
- return [int(t) if t.isdigit() else t.lower() for t in split(r'(\d+)', basename(s))]
+ return [int(t) if t.isdigit() else t.lower() for t in split(r"(\d+)", basename(s))]
def _get_lambdas(fep_files):
"""Retrieves all lambda values included in the FEP files provided.
-
+
We have to do this in order to tolerate truncated and restarted fepout files.
The IDWS lambda is not present at the termination of the window, presumably
for backwards compatibility with ParseFEP and probably other things.
@@ -48,25 +50,25 @@ def _get_lambdas(fep_files):
endpoint_windows = []
for fep_file in sorted(fep_files, key=_filename_sort_key):
- with anyopen(fep_file, 'r') as f:
+ with anyopen(fep_file, "r") as f:
for line in f:
l = line.strip().split()
# We might not have a #NEW line so make the best guess
- if l[0] == '#NEW':
+ if l[0] == "#NEW":
lambda1, lambda2 = float(l[6]), float(l[8])
- lambda_idws = float(l[10]) if 'LAMBDA_IDWS' in l else None
- elif l[0] == '#Free':
+ lambda_idws = float(l[10]) if "LAMBDA_IDWS" in l else None
+ elif l[0] == "#Free":
lambda1, lambda2, lambda_idws = float(l[7]), float(l[8]), None
else:
# We only care about lines with lambda values. No need to
# do all that other processing below for every line
- continue # pragma: no cover
+ continue # pragma: no cover
# Keep track of whether the lambda values are increasing or decreasing, so we can return
# a sorted list of the lambdas in the correct order.
# If it changes during parsing of this set of fepout files, then we know something is wrong
-
+
# Keep track of endpoints separately since in IDWS runs there must be one of opposite direction
if 0.0 in (lambda1, lambda2) or 1.0 in (lambda1, lambda2):
endpoint_windows.append((lambda1, lambda2))
@@ -78,23 +80,35 @@ def _get_lambdas(fep_files):
is_ascending.add(lambda1 > lambda_idws)
if len(is_ascending) > 1:
- raise ValueError(f'Lambda values change direction in {fep_file}, relative to the other files: {lambda1} -> {lambda2} (IDWS: {lambda_idws})')
+ raise ValueError(
+ f"Lambda values change direction in {fep_file}, relative to the other files: {lambda1} -> {lambda2} (IDWS: {lambda_idws})"
+ )
# Make sure the lambda2 values are consistent
if lambda1 in lambda_fwd_map and lambda_fwd_map[lambda1] != lambda2:
- logger.error(f'fwd: lambda1 {lambda1} has lambda2 {lambda_fwd_map[lambda1]} in {fep_file} but it has already been {lambda2}')
- raise ValueError('More than one lambda2 value for a particular lambda1')
+ logger.error(
+ f"fwd: lambda1 {lambda1} has lambda2 {lambda_fwd_map[lambda1]} in {fep_file} but it has already been {lambda2}"
+ )
+ raise ValueError(
+ "More than one lambda2 value for a particular lambda1"
+ )
lambda_fwd_map[lambda1] = lambda2
# Make sure the lambda_idws values are consistent
if lambda_idws is not None:
- if lambda1 in lambda_bwd_map and lambda_bwd_map[lambda1] != lambda_idws:
- logger.error(f'bwd: lambda1 {lambda1} has lambda_idws {lambda_bwd_map[lambda1]} but it has already been {lambda_idws}')
- raise ValueError('More than one lambda_idws value for a particular lambda1')
+ if (
+ lambda1 in lambda_bwd_map
+ and lambda_bwd_map[lambda1] != lambda_idws
+ ):
+ logger.error(
+ f"bwd: lambda1 {lambda1} has lambda_idws {lambda_bwd_map[lambda1]} but it has already been {lambda_idws}"
+ )
+ raise ValueError(
+ "More than one lambda_idws value for a particular lambda1"
+ )
lambda_bwd_map[lambda1] = lambda_idws
-
is_ascending = next(iter(is_ascending))
all_lambdas = set()
@@ -147,7 +161,7 @@ def extract_u_nk(fep_files, T):
`fep_files` can now be a list of filenames.
"""
- beta = 1/(k_b * T)
+ beta = 1 / (k_b * T)
# lists to get times and work values of each window
win_ts = []
@@ -156,7 +170,7 @@ def extract_u_nk(fep_files, T):
win_de_back = []
# create dataframe for results
- u_nk = pd.DataFrame(columns=['time','fep-lambda'])
+ u_nk = pd.DataFrame(columns=["time", "fep-lambda"])
# boolean flag to parse data after equil time
parsing = False
@@ -176,32 +190,36 @@ def extract_u_nk(fep_files, T):
for fep_file in sorted(fep_files, key=_filename_sort_key):
# Note we have not set parsing=False because we could be continuing one window across
# more than one fepout file
- with anyopen(fep_file, 'r') as f:
+ with anyopen(fep_file, "r") as f:
has_idws = False
for line in f:
l = line.strip().split()
# We don't know if IDWS was enabled just from the #Free line, and we might not have
# a #NEW line in this file, so we have to check for the existence of FepE_back lines
# We rely on short-circuit evaluation to avoid the string comparison most of the time
- if has_idws is False and l[0] == 'FepE_back:':
+ if has_idws is False and l[0] == "FepE_back:":
has_idws = True
# New window, get IDWS lambda if any
# We keep track of lambdas from the #NEW line and if they disagree with the #Free line
# within the same file, then complain. This can happen if truncated fepout files
# are presented in the wrong order.
- if l[0] == '#NEW':
+ if l[0] == "#NEW":
if parsing:
- logger.error(f'Window with lambda1: {lambda1_at_start} lambda2: {lambda2_at_start} lambda_idws: {lambda_idws_at_start} appears truncated')
- logger.error(f'because a new window was encountered in {fep_file} before the previous one finished.')
- raise ValueError('New window begun after truncated window')
+ logger.error(
+ f"Window with lambda1: {lambda1_at_start} lambda2: {lambda2_at_start} lambda_idws: {lambda_idws_at_start} appears truncated"
+ )
+ logger.error(
+ f"because a new window was encountered in {fep_file} before the previous one finished."
+ )
+ raise ValueError("New window begun after truncated window")
lambda1_at_start, lambda2_at_start = float(l[6]), float(l[8])
- lambda_idws_at_start = float(l[10]) if 'LAMBDA_IDWS' in l else None
+ lambda_idws_at_start = float(l[10]) if "LAMBDA_IDWS" in l else None
has_idws = True if lambda_idws_at_start is not None else False
# this line marks end of window; dump data into dataframe
- if l[0] == '#Free':
+ if l[0] == "#Free":
# extract lambda values for finished window
# lambda1 = sampling lambda (row), lambda2 = comparison lambda (col)
lambda1 = float(l[7])
@@ -210,17 +228,25 @@ def extract_u_nk(fep_files, T):
# If the lambdas are not what we thought they would be, raise an exception to ensure the calculation
# fails. This can happen if fepouts where one window spans multiple fepouts are processed out of order
# NB: There is no way to tell if lambda_idws changed because it isn't in the '#Free' line that ends a window
- if lambda1_at_start is not None \
- and (lambda1, lambda2) != (lambda1_at_start, lambda2_at_start):
- logger.error(f"Lambdas changed unexpectedly while processing {fep_file}")
- logger.error(f"l1, l2: {lambda1_at_start}, {lambda2_at_start} changed to {lambda1}, {lambda2}")
+ if lambda1_at_start is not None and (lambda1, lambda2) != (
+ lambda1_at_start,
+ lambda2_at_start,
+ ):
+ logger.error(
+ f"Lambdas changed unexpectedly while processing {fep_file}"
+ )
+ logger.error(
+ f"l1, l2: {lambda1_at_start}, {lambda2_at_start} changed to {lambda1}, {lambda2}"
+ )
logger.error(line)
- raise ValueError("Inconsistent lambda values within the same window")
+ raise ValueError(
+ "Inconsistent lambda values within the same window"
+ )
# As we are at the end of a window, convert last window's work and times values to np arrays
# (with energy unit kT since they were kcal/mol in the fepouts)
- win_de_arr = beta * np.asarray(win_de) # dE values
- win_ts_arr = np.asarray(win_ts) # timesteps
+ win_de_arr = beta * np.asarray(win_de) # dE values
+ win_ts_arr = np.asarray(win_ts) # timesteps
# This handles the special case where there are IDWS energies but no lambda_idws value in the
# current .fepout file. This can happen when the NAMD firsttimestep is not 0, because NAMD only emits
@@ -236,10 +262,16 @@ def extract_u_nk(fep_files, T):
# Test for the highly pathological case where the first window is both incomplete and has IDWS
# data but no lambda_idws value.
if l1_idx == 0:
- raise ValueError(f'IDWS data present in first window but lambda_idws not included; no way to infer the correct lambda_idws')
+ raise ValueError(
+ f"IDWS data present in first window but lambda_idws not included; no way to infer the correct lambda_idws"
+ )
lambda_idws_at_start = all_lambdas[l1_idx - 1]
- logger.warning(f'Warning: {fep_file} has IDWS data but lambda_idws not included.')
- logger.warning(f' lambda1 = {lambda1}, lambda2 = {lambda2}; inferring lambda_idws to be {lambda_idws_at_start}')
+ logger.warning(
+ f"Warning: {fep_file} has IDWS data but lambda_idws not included."
+ )
+ logger.warning(
+ f" lambda1 = {lambda1}, lambda2 = {lambda2}; inferring lambda_idws to be {lambda_idws_at_start}"
+ )
if lambda_idws_at_start is not None:
# Mimic classic DWS data
@@ -248,22 +280,28 @@ def extract_u_nk(fep_files, T):
win_de_back_arr = beta * np.asarray(win_de_back)
n = min(len(win_de_back_arr), len(win_de_arr))
- tempDF = pd.DataFrame({
- 'time': win_ts_arr[:n],
- 'fep-lambda': np.full(n,lambda1),
- lambda1: 0,
- lambda2: win_de_arr[:n],
- lambda_idws_at_start: win_de_back_arr[:n]})
+ tempDF = pd.DataFrame(
+ {
+ "time": win_ts_arr[:n],
+ "fep-lambda": np.full(n, lambda1),
+ lambda1: 0,
+ lambda2: win_de_arr[:n],
+ lambda_idws_at_start: win_de_back_arr[:n],
+ }
+ )
# print(f"{fep_file}: IDWS window {lambda1} {lambda2} {lambda_idws_at_start}")
else:
# print(f"{fep_file}: Forward-only window {lambda1} {lambda2}")
# create dataframe of times and work values
# this window's data goes in row LAMBDA1 and column LAMBDA2
- tempDF = pd.DataFrame({
- 'time': win_ts_arr,
- 'fep-lambda': np.full(len(win_de_arr), lambda1),
- lambda1: 0,
- lambda2: win_de_arr})
+ tempDF = pd.DataFrame(
+ {
+ "time": win_ts_arr,
+ "fep-lambda": np.full(len(win_de_arr), lambda1),
+ lambda1: 0,
+ lambda2: win_de_arr,
+ }
+ )
# join the new window's df to existing df
u_nk = pd.concat([u_nk, tempDF], sort=False)
@@ -275,38 +313,42 @@ def extract_u_nk(fep_files, T):
win_ts_back = []
parsing = False
has_idws = False
- lambda1_at_start, lambda2_at_start, lambda_idws_at_start = None, None, None
+ lambda1_at_start, lambda2_at_start, lambda_idws_at_start = (
+ None,
+ None,
+ None,
+ )
# append work value from 'dE' column of fepout file
if parsing:
- if l[0] == 'FepEnergy:':
+ if l[0] == "FepEnergy:":
win_de.append(float(l[6]))
win_ts.append(float(l[1]))
- elif l[0] == 'FepE_back:':
+ elif l[0] == "FepE_back:":
win_de_back.append(float(l[6]))
win_ts_back.append(float(l[1]))
# Turn parsing on after line 'STARTING COLLECTION OF ENSEMBLE AVERAGE'
- if '#STARTING' in l:
+ if "#STARTING" in l:
parsing = True
- if len(win_de) != 0 or len(win_de_back) != 0: # pragma: no cover
- logger.warning('Trailing data without footer line (\"#Free energy...\"). Interrupted run?')
- raise ValueError('Last window is truncated')
-
+ if len(win_de) != 0 or len(win_de_back) != 0: # pragma: no cover
+ logger.warning(
+ 'Trailing data without footer line ("#Free energy..."). Interrupted run?'
+ )
+ raise ValueError("Last window is truncated")
if lambda2 in (0.0, 1.0):
# this excludes the IDWS case where a dataframe already exists for both endpoints
# create last dataframe for fep-lambda at last LAMBDA2
- tempDF = pd.DataFrame({
- 'time': win_ts_arr,
- 'fep-lambda': lambda2})
+ tempDF = pd.DataFrame({"time": win_ts_arr, "fep-lambda": lambda2})
u_nk = pd.concat([u_nk, tempDF], sort=True)
- u_nk.set_index(['time','fep-lambda'], inplace=True)
+ u_nk.set_index(["time", "fep-lambda"], inplace=True)
return u_nk
+
def extract(fep_files, T):
"""Return reduced potentials `u_nk` from NAMD fepout file(s).
@@ -342,4 +384,6 @@ def extract(fep_files, T):
.. versionadded:: 1.0.0
"""
- return {"u_nk": extract_u_nk(fep_files, T)} # NOTE: maybe we should also have 'dHdl': None
+ return {
+ "u_nk": extract_u_nk(fep_files, T)
+ } # NOTE: maybe we should also have 'dHdl': None
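
For orientation, a minimal usage sketch of the NAMD parser reformatted above; it is illustrative only and assumes the alchemtest example data are installed:

from alchemtest.namd import load_restarted
from alchemlyb.parsing.namd import extract, extract_u_nk

# fepout files from a run with terminations/restarts; the parser sorts them
# and stitches windows that span more than one file back together.
filenames = load_restarted()["data"]["both"]

u_nk = extract_u_nk(filenames, T=300)   # dE values are rescaled by beta = 1/(k_b * T)
print(u_nk.index.names)                 # ('time', 'fep-lambda')

# extract() wraps the same result in a dict; NAMD FEP output carries no dH/dl data.
print(list(extract(filenames, T=300).keys()))   # ['u_nk']
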
diff --git a/src/alchemlyb/parsing/util.py b/src/alchemlyb/parsing/util.py
index f8259aa6..28e5a568 100644
--- a/src/alchemlyb/parsing/util.py
+++ b/src/alchemlyb/parsing/util.py
@@ -1,23 +1,24 @@
"""Collection of utilities used by many parsers.
"""
-import os
-from os import PathLike
-from typing import IO, Optional, Union
import bz2
import gzip
+import os
+from os import PathLike
+from typing import IO, Union
+
def bz2_open(filename, mode):
- mode += 't' if mode in ['r','w','a','x'] else ''
+ mode += "t" if mode in ["r", "w", "a", "x"] else ""
return bz2.open(filename, mode)
def gzip_open(filename, mode):
- mode += 't' if mode in ['r','w','a','x'] else ''
+ mode += "t" if mode in ["r", "w", "a", "x"] else ""
return gzip.open(filename, mode)
-def anyopen(datafile: Union[PathLike, IO], mode='r', compression=None):
+def anyopen(datafile: Union[PathLike, IO], mode="r", compression=None):
"""Return a file stream for file or stream, even if compressed.
Supports files compressed with bzip2 (.bz2) and gzip (.gz) compression
@@ -59,16 +60,15 @@ def anyopen(datafile: Union[PathLike, IO], mode='r', compression=None):
"""
# opener for each type of file
- extensions = {'.bz2': bz2_open,
- '.gz': gzip_open}
+ extensions = {".bz2": bz2_open, ".gz": gzip_open}
# compression selections available
- compressions = {'bzip2': bz2_open,
- 'gzip': gzip_open}
+ compressions = {"bzip2": bz2_open, "gzip": gzip_open}
# if `datafile` is a stream
- if ((hasattr(datafile, 'read') and any((i in mode for i in ('r',)))) or
- (hasattr(datafile, 'write') and any((i in mode for i in ('w', 'a', 'x'))))):
+ if (hasattr(datafile, "read") and any((i in mode for i in ("r",)))) or (
+ hasattr(datafile, "write") and any((i in mode for i in ("w", "a", "x")))
+ ):
# if no compression specified, just pass the stream through
if compression is None:
return datafile
@@ -76,7 +76,9 @@ def anyopen(datafile: Union[PathLike, IO], mode='r', compression=None):
compressor = compressions[compression]
return compressor(datafile, mode=mode)
else:
- raise ValueError("`datafile` is a stream, but specified `compression` '{compression}' is not supported")
+ raise ValueError(
+ "`datafile` is a stream, but specified `compression` '{compression}' is not supported"
+ )
# otherwise, treat as a file
# allow compression to override any extension on the file
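
A rough sketch of how the `anyopen` helper above is used (the file name is hypothetical): the extension selects the opener for paths, while streams need an explicit `compression` keyword.

import gzip
import io

from alchemlyb.parsing.util import anyopen

# Path-like input: ".bz2" / ".gz" extensions pick the matching opener.
with anyopen("dhdl.xvg.bz2") as f:          # hypothetical file
    header = f.readline()

# Stream input: no extension is available, so compression must be stated.
stream = io.BytesIO(gzip.compress(b"40000.0 27.0\n"))
with anyopen(stream, mode="r", compression="gzip") as f:
    print(f.read())                          # "40000.0 27.0\n"
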
diff --git a/src/alchemlyb/postprocessors/__init__.py b/src/alchemlyb/postprocessors/__init__.py
index 6e769ac4..932d2b06 100644
--- a/src/alchemlyb/postprocessors/__init__.py
+++ b/src/alchemlyb/postprocessors/__init__.py
@@ -1,3 +1,3 @@
__all__ = [
- 'units',
+ "units",
]
diff --git a/src/alchemlyb/postprocessors/units.py b/src/alchemlyb/postprocessors/units.py
index f5e1984d..510b4465 100644
--- a/src/alchemlyb/postprocessors/units.py
+++ b/src/alchemlyb/postprocessors/units.py
@@ -12,8 +12,9 @@
#: in :mod:`scipy.constants`
R_kJmol = R / 1000
+
def to_kT(df, T=None):
- """ Convert the unit of a DataFrame to `kT`.
+ """Convert the unit of a DataFrame to `kT`.
If temperature `T` is not provided, the DataFrame need to have attribute
`temperature` and `energy_unit`. Otherwise, the temperature of the output
@@ -33,28 +34,28 @@ def to_kT(df, T=None):
"""
new_df = df.copy()
if T is not None:
- new_df.attrs['temperature'] = T
- elif 'temperature' not in df.attrs:
- raise TypeError('Attribute temperature not found in the input '
- 'Dataframe.')
+ new_df.attrs["temperature"] = T
+ elif "temperature" not in df.attrs:
+ raise TypeError("Attribute temperature not found in the input " "Dataframe.")
- if 'energy_unit' not in df.attrs:
- raise TypeError('Attribute energy_unit not found in the input '
- 'Dataframe.')
+ if "energy_unit" not in df.attrs:
+ raise TypeError("Attribute energy_unit not found in the input " "Dataframe.")
- if df.attrs['energy_unit'] == 'kT':
+ if df.attrs["energy_unit"] == "kT":
return new_df
- elif df.attrs['energy_unit'] == 'kJ/mol':
- new_df /= R_kJmol * df.attrs['temperature']
- new_df.attrs['energy_unit'] = 'kT'
+ elif df.attrs["energy_unit"] == "kJ/mol":
+ new_df /= R_kJmol * df.attrs["temperature"]
+ new_df.attrs["energy_unit"] = "kT"
return new_df
- elif df.attrs['energy_unit'] == 'kcal/mol':
- new_df /= R_kJmol * df.attrs['temperature'] * kJ2kcal
- new_df.attrs['energy_unit'] = 'kT'
+ elif df.attrs["energy_unit"] == "kcal/mol":
+ new_df /= R_kJmol * df.attrs["temperature"] * kJ2kcal
+ new_df.attrs["energy_unit"] = "kT"
return new_df
else:
- raise ValueError('energy_unit {} can only be kT, kJ/mol or ' \
- 'kcal/mol.'.format(df.attrs['energy_unit']))
+ raise ValueError(
+ "energy_unit {} can only be kT, kJ/mol or "
+ "kcal/mol.".format(df.attrs["energy_unit"])
+ )
def to_kcalmol(df, T=None):
@@ -77,10 +78,11 @@ def to_kcalmol(df, T=None):
`df` converted.
"""
kt_df = to_kT(df, T)
- kt_df *= R_kJmol * df.attrs['temperature'] * kJ2kcal
- kt_df.attrs['energy_unit'] = 'kcal/mol'
+ kt_df *= R_kJmol * df.attrs["temperature"] * kJ2kcal
+ kt_df.attrs["energy_unit"] = "kcal/mol"
return kt_df
+
def to_kJmol(df, T=None):
"""Convert the unit of a DataFrame to kJ/mol.
@@ -101,12 +103,13 @@ def to_kJmol(df, T=None):
`df` converted.
"""
kt_df = to_kT(df, T)
- kt_df *= R_kJmol * df.attrs['temperature']
- kt_df.attrs['energy_unit'] = 'kJ/mol'
+ kt_df *= R_kJmol * df.attrs["temperature"]
+ kt_df.attrs["energy_unit"] = "kJ/mol"
return kt_df
+
def get_unit_converter(units):
- """ Obtain the converter according to the unit string.
+ """Obtain the converter according to the unit string.
If `units` is 'kT', the `to_kT` converter is returned. If `units` is
'kJ/mol', the `to_kJmol` converter is returned. If `units` is 'kcal/mol',
@@ -125,12 +128,12 @@ def get_unit_converter(units):
.. versionadded:: 0.5.0
"""
- converters = {'kT': to_kT, 'kJ/mol': to_kJmol,
- 'kcal/mol': to_kcalmol}
+ converters = {"kT": to_kT, "kJ/mol": to_kJmol, "kcal/mol": to_kcalmol}
try:
convert = converters[units]
except KeyError:
raise ValueError(
f"Energy unit {units} is not supported, "
- f"choose one of {list(converters.keys())}")
+ f"choose one of {list(converters.keys())}"
+ )
return convert
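
An illustrative sketch of the conversion logic above, with made-up numbers: energies in kJ/mol are divided by R_kJmol * T, so at T = 300 K one kT corresponds to about 2.494 kJ/mol.

import pandas as pd

from alchemlyb.postprocessors.units import get_unit_converter

df = pd.DataFrame({"dE": [2.494, 4.988]})   # hypothetical energies in kJ/mol
df.attrs["temperature"] = 300               # K
df.attrs["energy_unit"] = "kJ/mol"

to_kT = get_unit_converter("kT")            # dictionary lookup shown above
df_kT = to_kT(df)                           # divides by R_kJmol * T ~ 2.494 kJ/mol
print(df_kT["dE"].round(2).tolist())        # approximately [1.0, 2.0]
print(df_kT.attrs["energy_unit"])           # 'kT'
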
diff --git a/src/alchemlyb/preprocessing/__init__.py b/src/alchemlyb/preprocessing/__init__.py
index 6b759482..223c942e 100644
--- a/src/alchemlyb/preprocessing/__init__.py
+++ b/src/alchemlyb/preprocessing/__init__.py
@@ -3,16 +3,22 @@
preparing data for estimators.
"""
-from .subsampling import slicing, dhdl2series, u_nk2series, decorrelate_dhdl, decorrelate_u_nk
-from .subsampling import statistical_inefficiency
from .subsampling import equilibrium_detection
+from .subsampling import (
+ slicing,
+ dhdl2series,
+ u_nk2series,
+ decorrelate_dhdl,
+ decorrelate_u_nk,
+)
+from .subsampling import statistical_inefficiency
__all__ = [
- 'slicing',
- 'statistical_inefficiency',
- 'equilibrium_detection',
- 'decorrelate_dhdl',
- 'decorrelate_u_nk',
- 'dhdl2series',
- 'u_nk2series'
+ "slicing",
+ "statistical_inefficiency",
+ "equilibrium_detection",
+ "decorrelate_dhdl",
+ "decorrelate_u_nk",
+ "dhdl2series",
+ "u_nk2series",
]
diff --git a/src/alchemlyb/preprocessing/subsampling.py b/src/alchemlyb/preprocessing/subsampling.py
index 3b8bd8af..653adbca 100644
--- a/src/alchemlyb/preprocessing/subsampling.py
+++ b/src/alchemlyb/preprocessing/subsampling.py
@@ -4,13 +4,18 @@
import warnings
import pandas as pd
-from pymbar.timeseries import (statisticalInefficiency,
- detectEquilibration,
- subsampleCorrelatedData, )
+from pymbar.timeseries import (
+ statisticalInefficiency,
+ detectEquilibration,
+ subsampleCorrelatedData,
+)
+
from .. import pass_attrs
-def decorrelate_u_nk(df, method='dE', drop_duplicates=True,
- sort=True, remove_burnin=False, **kwargs):
+
+def decorrelate_u_nk(
+ df, method="dE", drop_duplicates=True, sort=True, remove_burnin=False, **kwargs
+):
"""Subsample an u_nk DataFrame based on the selected method.
The method can be either 'all' (obtained as a sum over all energy
@@ -57,8 +62,8 @@ def decorrelate_u_nk(df, method='dE', drop_duplicates=True,
deprecate the 'dhdl'.
"""
- kwargs['drop_duplicates'] = drop_duplicates
- kwargs['sort'] = sort
+ kwargs["drop_duplicates"] = drop_duplicates
+ kwargs["sort"] = sort
series = u_nk2series(df, method)
@@ -67,8 +72,10 @@ def decorrelate_u_nk(df, method='dE', drop_duplicates=True,
else:
return statistical_inefficiency(df, series, **kwargs)
-def decorrelate_dhdl(df, drop_duplicates=True, sort=True,
- remove_burnin=False, **kwargs):
+
+def decorrelate_dhdl(
+ df, drop_duplicates=True, sort=True, remove_burnin=False, **kwargs
+):
"""Subsample a dhdl DataFrame.
This is a wrapper function around the function
:func:`~alchemlyb.preprocessing.subsampling.statistical_inefficiency` and
@@ -111,8 +118,8 @@ def decorrelate_dhdl(df, drop_duplicates=True, sort=True,
"""
- kwargs['drop_duplicates'] = drop_duplicates
- kwargs['sort'] = sort
+ kwargs["drop_duplicates"] = drop_duplicates
+ kwargs["sort"] = sort
series = dhdl2series(df)
@@ -121,8 +128,9 @@ def decorrelate_dhdl(df, drop_duplicates=True, sort=True,
else:
return statistical_inefficiency(df, series, **kwargs)
+
@pass_attrs
-def u_nk2series(df, method='dE'):
+def u_nk2series(df, method="dE"):
"""Convert an u_nk DataFrame into a series based on the selected method
for subsampling.
@@ -152,18 +160,22 @@ def u_nk2series(df, method='dE'):
# deprecation: remove in 3.0.0
# (the deprecations should show up in the calling functions)
- if method == 'dhdl':
- warnings.warn("Method 'dhdl' has been deprecated, using 'dE' instead. "
- "'dhdl' will be removed in alchemlyb 3.0.0.",
- category=DeprecationWarning,
- stacklevel=2)
- method = 'dE'
- elif method == 'dhdl_all':
- warnings.warn("Method 'dhdl_all' has been deprecated, using 'all' instead. "
- "'dhdl_all' will be removed in alchemlyb 3.0.0.",
- category=DeprecationWarning,
- stacklevel=2)
- method = 'all'
+ if method == "dhdl":
+ warnings.warn(
+ "Method 'dhdl' has been deprecated, using 'dE' instead. "
+ "'dhdl' will be removed in alchemlyb 3.0.0.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+ method = "dE"
+ elif method == "dhdl_all":
+ warnings.warn(
+ "Method 'dhdl_all' has been deprecated, using 'all' instead. "
+ "'dhdl_all' will be removed in alchemlyb 3.0.0.",
+ category=DeprecationWarning,
+ stacklevel=2,
+ )
+ method = "all"
# Check if the input is u_nk
try:
@@ -172,11 +184,11 @@ def u_nk2series(df, method='dE'):
key = key[0]
df[key]
except KeyError:
- raise ValueError('The input should be u_nk')
+ raise ValueError("The input should be u_nk")
- if method == 'all':
+ if method == "all":
series = df.sum(axis=1)
- elif method == 'dE':
+ elif method == "dE":
# Using the same logic as alchemical-analysis
key = df.index.values[0][1:]
if len(key) == 1:
@@ -192,13 +204,12 @@ def u_nk2series(df, method='dE'):
else:
series = df.iloc[:, index - 1]
else:
- raise ValueError(
- 'Decorrelation method {} not found.'.format(method))
+ raise ValueError("Decorrelation method {} not found.".format(method))
return series
@pass_attrs
-def dhdl2series(df, method='all'):
+def dhdl2series(df, method="all"):
"""Convert a dhdl DataFrame to a series for subsampling.
The series is generated by summing over all energy components (axis 1 of
@@ -235,13 +246,15 @@ def dhdl2series(df, method='all'):
def _check_multiple_times(df):
if isinstance(df, pd.Series):
- return df.sort_index(axis=0).reset_index('time', name='').duplicated('time').any()
+ return (
+ df.sort_index(axis=0).reset_index("time", name="").duplicated("time").any()
+ )
else:
- return df.sort_index(axis=0).reset_index('time').duplicated('time').any()
+ return df.sort_index(axis=0).reset_index("time").duplicated("time").any()
def _check_sorted(df):
- return df.reset_index(0)['time'].is_monotonic_increasing
+ return df.reset_index(0)["time"].is_monotonic_increasing
def _drop_duplicates(df, series=None):
@@ -265,34 +278,44 @@ def _drop_duplicates(df, series=None):
"""
if isinstance(df, pd.Series):
# remove the duplicate based on time
- drop_duplicates_series = df.reset_index('time', name=''). \
- drop_duplicates('time')
+ drop_duplicates_series = df.reset_index("time", name="").drop_duplicates("time")
# Rest the time index
- lambda_names = ['time', ]
+ lambda_names = [
+ "time",
+ ]
lambda_names.extend(drop_duplicates_series.index.names)
- df = drop_duplicates_series.set_index('time', append=True). \
- reorder_levels(lambda_names)
+ df = drop_duplicates_series.set_index("time", append=True).reorder_levels(
+ lambda_names
+ )
else:
# remove the duplicate based on time
- drop_duplicates_df = df.reset_index('time').drop_duplicates('time')
+ drop_duplicates_df = df.reset_index("time").drop_duplicates("time")
# Rest the time index
- lambda_names = ['time', ]
+ lambda_names = [
+ "time",
+ ]
lambda_names.extend(drop_duplicates_df.index.names)
- df = drop_duplicates_df.set_index('time', append=True). \
- reorder_levels(lambda_names)
+ df = drop_duplicates_df.set_index("time", append=True).reorder_levels(
+ lambda_names
+ )
# Do the same withing with the series
if series is not None:
# remove the duplicate based on time
- drop_duplicates_series = series.reset_index('time', name=''). \
- drop_duplicates('time')
+ drop_duplicates_series = series.reset_index("time", name="").drop_duplicates(
+ "time"
+ )
# Rest the time index
- lambda_names = ['time', ]
+ lambda_names = [
+ "time",
+ ]
lambda_names.extend(drop_duplicates_series.index.names)
- series = drop_duplicates_series.set_index('time', append=True). \
- reorder_levels(lambda_names)
+ series = drop_duplicates_series.set_index("time", append=True).reorder_levels(
+ lambda_names
+ )
return df, series
+
def _sort_by_time(df, series=None):
"""Sort the ``df`` by time which could be Dataframe or
Series, if series is provided, sort the series as well.
@@ -311,12 +334,13 @@ def _sort_by_time(df, series=None):
series : Series
Formatted Series.
"""
- df = df.sort_index(level='time')
+ df = df.sort_index(level="time")
if series is not None:
- series = series.sort_index(level='time')
+ series = series.sort_index(level="time")
return df, series
+
def _prepare_input(df, series, drop_duplicates, sort):
"""Prepare and check the input to be used for statistical_inefficiency or equilibrium_detection.
@@ -341,7 +365,8 @@ def _prepare_input(df, series, drop_duplicates, sort):
raise KeyError(
"Duplicate time values found; statistical inefficiency "
"only works on a single, contiguous, "
- "and sorted timeseries.")
+ "and sorted timeseries."
+ )
if not _check_sorted(df):
if sort:
@@ -349,16 +374,17 @@ def _prepare_input(df, series, drop_duplicates, sort):
else:
raise KeyError(
"Statistical inefficiency only works as expected if "
- "values are sorted by time, increasing.")
+ "values are sorted by time, increasing."
+ )
if series is not None:
- if (len(series) != len(df) or
- not all(
- series.reset_index()['time'] == df.reset_index()['time'])):
- raise ValueError(
- "series and data must be sampled at the same times")
+ if len(series) != len(df) or not all(
+ series.reset_index()["time"] == df.reset_index()["time"]
+ ):
+ raise ValueError("series and data must be sampled at the same times")
return df, series
+
def slicing(df, lower=None, upper=None, step=None, force=False):
"""Subsample a DataFrame using simple slicing.
@@ -390,16 +416,25 @@ def slicing(df, lower=None, upper=None, step=None, force=False):
raise KeyError("DataFrame rows must be sorted by time, increasing.")
if not force and _check_multiple_times(df):
- raise KeyError("Duplicate time values found; it's generally advised "
- "to use slicing on DataFrames with unique time values "
- "for each row. Use `force=True` to ignore this error.")
+ raise KeyError(
+ "Duplicate time values found; it's generally advised "
+ "to use slicing on DataFrames with unique time values "
+ "for each row. Use `force=True` to ignore this error."
+ )
return df
-def statistical_inefficiency(df, series=None, lower=None, upper=None,
- step=None, conservative=True,
- drop_duplicates=False, sort=False):
+def statistical_inefficiency(
+ df,
+ series=None,
+ lower=None,
+ upper=None,
+ step=None,
+ conservative=True,
+ drop_duplicates=False,
+ sort=False,
+):
"""Subsample a DataFrame based on the calculated statistical inefficiency
of a timeseries.
@@ -480,8 +515,7 @@ def statistical_inefficiency(df, series=None, lower=None, upper=None,
statinef = statisticalInefficiency(series, fast=False)
# use the subsampleCorrelatedData function to get the subsample index
- indices = subsampleCorrelatedData(series, g=statinef,
- conservative=conservative)
+ indices = subsampleCorrelatedData(series, g=statinef, conservative=conservative)
df = df.iloc[indices]
else:
df = slicing(df, lower=lower, upper=upper, step=step)
@@ -489,8 +523,15 @@ def statistical_inefficiency(df, series=None, lower=None, upper=None,
return df
-def equilibrium_detection(df, series=None, lower=None, upper=None, step=None,
- drop_duplicates=False, sort=False):
+def equilibrium_detection(
+ df,
+ series=None,
+ lower=None,
+ upper=None,
+ step=None,
+ drop_duplicates=False,
+ sort=False,
+):
"""Subsample a DataFrame using automated equilibrium detection on a timeseries.
This function uses the :mod:`pymbar` implementation of the *simple
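
To make the decorrelation entry points above concrete, a minimal sketch assuming the alchemtest GROMACS benzene data set; `decorrelate_u_nk` keeps only (approximately) statistically independent rows and can discard the detected burn-in period.

from alchemtest.gmx import load_benzene
from alchemlyb.parsing.gmx import extract_u_nk
from alchemlyb.preprocessing import decorrelate_u_nk

u_nk = extract_u_nk(load_benzene()["data"]["Coulomb"][0], T=300)

# method="dE" uses the neighbouring-lambda column of u_nk as the subsampling
# series; remove_burnin=True switches from statistical_inefficiency to
# equilibrium_detection.
u_nk_sub = decorrelate_u_nk(u_nk, method="dE", remove_burnin=True)
print(len(u_nk), "->", len(u_nk_sub))
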
diff --git a/src/alchemlyb/tests/conftest.py b/src/alchemlyb/tests/conftest.py
index b5b485e9..2a9530b1 100644
--- a/src/alchemlyb/tests/conftest.py
+++ b/src/alchemlyb/tests/conftest.py
@@ -198,7 +198,6 @@ def amber_simplesolvated_charge_dHdl(amber_simplesolvated):
@pytest.fixture
def amber_simplesolvated_vdw_dHdl(amber_simplesolvated):
-
return [
amber.extract_dHdl(filename, T=298.0)
for filename in amber_simplesolvated["vdw"]
diff --git a/src/alchemlyb/tests/parsing/test_amber.py b/src/alchemlyb/tests/parsing/test_amber.py
index c1fe137d..0d186cc7 100644
--- a/src/alchemlyb/tests/parsing/test_amber.py
+++ b/src/alchemlyb/tests/parsing/test_amber.py
@@ -2,28 +2,30 @@
"""
import logging
+
+import pandas as pd
import pytest
+from alchemtest.amber import load_bace_example
+from alchemtest.amber import load_bace_improper
+from alchemtest.amber import load_simplesolvated
+from alchemtest.amber import load_testfiles
from numpy.testing import assert_allclose
-import pandas as pd
+from alchemlyb.parsing.amber import extract
from alchemlyb.parsing.amber import extract_dHdl
from alchemlyb.parsing.amber import extract_u_nk
-from alchemlyb.parsing.amber import extract
-from alchemtest.amber import load_simplesolvated
-from alchemtest.amber import load_bace_example
-from alchemtest.amber import load_bace_improper
-from alchemtest.amber import load_testfiles
##################################################################################
################ Check the parser behaviour with problematic files
##################################################################################
+
@pytest.fixture(name="testfiles", scope="module")
def fixture_testfiles():
- """ Returns the testfiles data dictionary """
+ """Returns the testfiles data dictionary"""
bunch = load_testfiles()
- return bunch['data']
+ return bunch["data"]
def test_file_not_found():
@@ -77,10 +79,10 @@ def test_no_control_data(caplog, testfiles):
def test_no_free_energy_info(caplog, testfiles):
"""Test if we raise an exception if there is no free energy section"""
filename = testfiles["no_free_energy_info"][0]
- with pytest.raises(ValueError, match='no free energy section found'):
+ with pytest.raises(ValueError, match="no free energy section found"):
with caplog.at_level(logging.ERROR):
_ = extract(str(filename), T=298.0)
- assert 'No free energy section found' in caplog.text
+ assert "No free energy section found" in caplog.text
def test_no_useful_data(caplog, testfiles):
@@ -89,7 +91,7 @@ def test_no_useful_data(caplog, testfiles):
with pytest.raises(ValueError, match="does not contain any data"):
with caplog.at_level(logging.ERROR):
_ = extract(str(filename), T=298.0)
- assert 'does not contain any data' in caplog.text
+ assert "does not contain any data" in caplog.text
def test_no_temp0_set(caplog, testfiles):
@@ -119,16 +121,16 @@ def test_long_and_wrong_number_MBAR(caplog, testfiles):
with pytest.raises(ValueError, match="the number of lambda windows read"):
with caplog.at_level(logging.ERROR):
_ = extract_u_nk(str(filename), T=300.0)
- assert 'the number of lambda windows read' in caplog.text
+ assert "the number of lambda windows read" in caplog.text
def test_no_starting_time(caplog, testfiles):
"""Test if raise an exception if the starting time is not read"""
filename = testfiles["no_starting_simulation_time"][0]
- with pytest.raises(ValueError, match='No starting simulation time in file'):
+ with pytest.raises(ValueError, match="No starting simulation time in file"):
with caplog.at_level(logging.ERROR):
_ = extract(str(filename), T=298.0)
- assert 'No starting simulation time in file' in caplog.text
+ assert "No starting simulation time in file" in caplog.text
def test_parse_without_spaces_around_equal(testfiles):
@@ -138,23 +140,24 @@ def test_parse_without_spaces_around_equal(testfiles):
"""
filename = testfiles["no_spaces_around_equal"][0]
df_dict = extract(str(filename), T=298.0)
- assert isinstance(df_dict['dHdl'], pd.DataFrame)
+ assert isinstance(df_dict["dHdl"], pd.DataFrame)
##################################################################################
################ Check the parser behaviour with standard single files
##################################################################################
+
@pytest.fixture(name="single_u_nk", scope="module")
def fixture_single_u_nk():
"""return a single file to check u_unk parsing"""
- return load_bace_example().data['complex']['vdw'][0]
+ return load_bace_example().data["complex"]["vdw"][0]
@pytest.fixture(name="single_dHdl", scope="module")
def fixture_single_dHdl():
"""return a single file to check dHdl parsing"""
- return load_simplesolvated().data['charge'][0]
+ return load_simplesolvated().data["charge"][0]
def test_dHdl_time_reading(single_dHdl):
@@ -175,18 +178,18 @@ def test_extract_with_both_data(single_u_nk):
"""Test that dHdl and u_nk have the correct form when
extracted from files with the single "extract" funcion."""
df_dict = extract(single_u_nk, T=298.0)
- assert df_dict['dHdl'].index.names == ('time', 'lambdas')
- assert df_dict['dHdl'].shape == (500, 1)
- assert df_dict['u_nk'].index.names == ('time', 'lambdas')
+ assert df_dict["dHdl"].index.names == ("time", "lambdas")
+ assert df_dict["dHdl"].shape == (500, 1)
+ assert df_dict["u_nk"].index.names == ("time", "lambdas")
def test_extract_with_only_dhdl_data(single_dHdl):
"""Test that parsing with the extract function a file
- with just dHdl gives the correct results"""
+ with just dHdl gives the correct results"""
df_dict = extract(single_dHdl, T=298.0)
- assert df_dict['dHdl'].index.names == ('time', 'lambdas')
- assert df_dict['dHdl'].shape == (500, 1)
- assert df_dict['u_nk'] is None
+ assert df_dict["dHdl"].index.names == ("time", "lambdas")
+ assert df_dict["dHdl"].shape == (500, 1)
+ assert df_dict["u_nk"] is None
def test_wrong_T_should_raise_warning(single_dHdl, T=300.0):
@@ -195,24 +198,21 @@ def test_wrong_T_should_raise_warning(single_dHdl, T=300.0):
read from the AMBER file gives a warning
"""
with pytest.raises(
- ValueError,
- match="is different from the temperature passed as parameter"):
+ ValueError, match="is different from the temperature passed as parameter"
+ ):
_ = extract(single_dHdl, T=T)
-
###################################################################
################ Check the behaviour on proper datasets
###################################################################
-@pytest.mark.parametrize("filename",
- [filename
- for leg in load_simplesolvated()['data'].values()
- for filename in leg])
-def test_dHdl(filename,
- names=('time', 'lambdas'),
- shape=(500, 1)):
+@pytest.mark.parametrize(
+ "filename",
+ [filename for leg in load_simplesolvated()["data"].values() for filename in leg],
+)
+def test_dHdl(filename, names=("time", "lambdas"), shape=(500, 1)):
"""Test that dHdl has the correct form when extracted from files."""
dHdl = extract_dHdl(filename, T=298.0)
@@ -220,27 +220,33 @@ def test_dHdl(filename,
assert dHdl.shape == shape
-@pytest.mark.parametrize("mbar_filename",
- [mbar_filename
- for leg in load_bace_example()['data']['complex'].values()
- for mbar_filename in leg])
-def test_u_nk(mbar_filename,
- names=('time', 'lambdas')):
+@pytest.mark.parametrize(
+ "mbar_filename",
+ [
+ mbar_filename
+ for leg in load_bace_example()["data"]["complex"].values()
+ for mbar_filename in leg
+ ],
+)
+def test_u_nk(mbar_filename, names=("time", "lambdas")):
"""Test the u_nk has the correct form when extracted from files"""
u_nk = extract_u_nk(mbar_filename, T=298.0)
assert u_nk.index.names == names
-@pytest.mark.parametrize("improper_filename",
- [improper_filename
- for leg in load_bace_improper()['data'].values()
- for improper_filename in leg])
-def test_u_nk_improper(improper_filename,
- names=('time', 'lambdas')):
+@pytest.mark.parametrize(
+ "improper_filename",
+ [
+ improper_filename
+ for leg in load_bace_improper()["data"].values()
+ for improper_filename in leg
+ ],
+)
+def test_u_nk_improper(improper_filename, names=("time", "lambdas")):
"""Test the u_nk has the correct form when extracted from files"""
try:
u_nk = extract_u_nk(improper_filename, T=298.0)
assert u_nk.index.names == names
except Exception:
- assert '0.5626' in improper_filename
+ assert "0.5626" in improper_filename
diff --git a/src/alchemlyb/tests/parsing/test_gmx.py b/src/alchemlyb/tests/parsing/test_gmx.py
index d85ad1bf..1959c925 100644
--- a/src/alchemlyb/tests/parsing/test_gmx.py
+++ b/src/alchemlyb/tests/parsing/test_gmx.py
@@ -3,118 +3,144 @@
"""
import bz2
-import pytest
-from alchemlyb.parsing.gmx import extract_dHdl, extract_u_nk, extract
+import pytest
from alchemtest.gmx import load_benzene
-from alchemtest.gmx import load_expanded_ensemble_case_1, load_expanded_ensemble_case_2, load_expanded_ensemble_case_3
-from alchemtest.gmx import load_water_particle_with_total_energy
+from alchemtest.gmx import (
+ load_expanded_ensemble_case_1,
+ load_expanded_ensemble_case_2,
+ load_expanded_ensemble_case_3,
+)
from alchemtest.gmx import load_water_particle_with_potential_energy
+from alchemtest.gmx import load_water_particle_with_total_energy
from alchemtest.gmx import load_water_particle_without_energy
from numpy.testing import assert_almost_equal
+from alchemlyb.parsing.gmx import extract_dHdl, extract_u_nk, extract
+
def test_dHdl():
- """Test that dHdl has the correct form when extracted from files.
-
- """
+ """Test that dHdl has the correct form when extracted from files."""
dataset = load_benzene()
- for leg in dataset['data']:
- for filename in dataset['data'][leg]:
+ for leg in dataset["data"]:
+ for filename in dataset["data"][leg]:
dHdl = extract_dHdl(filename, T=300)
- assert dHdl.index.names == ['time', 'fep-lambda']
+ assert dHdl.index.names == ["time", "fep-lambda"]
assert dHdl.shape == (4001, 1)
-def test_u_nk():
- """Test that u_nk has the correct form when extracted from files.
- """
+def test_u_nk():
+ """Test that u_nk has the correct form when extracted from files."""
dataset = load_benzene()
- for leg in dataset['data']:
- for filename in dataset['data'][leg]:
+ for leg in dataset["data"]:
+ for filename in dataset["data"][leg]:
u_nk = extract_u_nk(filename, T=300)
- assert u_nk.index.names == ['time', 'fep-lambda']
- if leg == 'Coulomb':
+ assert u_nk.index.names == ["time", "fep-lambda"]
+ if leg == "Coulomb":
assert u_nk.shape == (4001, 5)
- elif leg == 'VDW':
+ elif leg == "VDW":
assert u_nk.shape == (4001, 16)
-def test_u_nk_case1():
- """Test that u_nk has the correct form when extracted from expanded ensemble files (case 1).
- """
+def test_u_nk_case1():
+ """Test that u_nk has the correct form when extracted from expanded ensemble files (case 1)."""
dataset = load_expanded_ensemble_case_1()
- for leg in dataset['data']:
- for filename in dataset['data'][leg]:
+ for leg in dataset["data"]:
+ for filename in dataset["data"][leg]:
u_nk = extract_u_nk(filename, T=300, filter=False)
- assert u_nk.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda']
+ assert u_nk.index.names == [
+ "time",
+ "fep-lambda",
+ "coul-lambda",
+ "vdw-lambda",
+ "restraint-lambda",
+ ]
assert u_nk.shape == (50001, 28)
-def test_dHdl_case1():
- """Test that dHdl has the correct form when extracted from expanded ensemble files (case 1).
- """
+def test_dHdl_case1():
+ """Test that dHdl has the correct form when extracted from expanded ensemble files (case 1)."""
dataset = load_expanded_ensemble_case_1()
- for leg in dataset['data']:
- for filename in dataset['data'][leg]:
+ for leg in dataset["data"]:
+ for filename in dataset["data"][leg]:
dHdl = extract_dHdl(filename, T=300, filter=False)
- assert dHdl.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda']
+ assert dHdl.index.names == [
+ "time",
+ "fep-lambda",
+ "coul-lambda",
+ "vdw-lambda",
+ "restraint-lambda",
+ ]
assert dHdl.shape == (50001, 4)
-def test_u_nk_case2():
- """Test that u_nk has the correct form when extracted from expanded ensemble files (case 2).
- """
+def test_u_nk_case2():
+ """Test that u_nk has the correct form when extracted from expanded ensemble files (case 2)."""
dataset = load_expanded_ensemble_case_2()
- for leg in dataset['data']:
- for filename in dataset['data'][leg]:
+ for leg in dataset["data"]:
+ for filename in dataset["data"][leg]:
u_nk = extract_u_nk(filename, T=300, filter=False)
- assert u_nk.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda']
+ assert u_nk.index.names == [
+ "time",
+ "fep-lambda",
+ "coul-lambda",
+ "vdw-lambda",
+ "restraint-lambda",
+ ]
assert u_nk.shape == (25001, 28)
-def test_u_nk_case3():
- """Test that u_nk has the correct form when extracted from REX files (case 3).
- """
+def test_u_nk_case3():
+ """Test that u_nk has the correct form when extracted from REX files (case 3)."""
dataset = load_expanded_ensemble_case_3()
- for leg in dataset['data']:
- for filename in dataset['data'][leg]:
+ for leg in dataset["data"]:
+ for filename in dataset["data"][leg]:
u_nk = extract_u_nk(filename, T=300, filter=False)
- assert u_nk.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda']
+ assert u_nk.index.names == [
+ "time",
+ "fep-lambda",
+ "coul-lambda",
+ "vdw-lambda",
+ "restraint-lambda",
+ ]
assert u_nk.shape == (2500, 28)
-def test_dHdl_case3():
- """Test that dHdl has the correct form when extracted from REX files (case 3).
- """
+def test_dHdl_case3():
+ """Test that dHdl has the correct form when extracted from REX files (case 3)."""
dataset = load_expanded_ensemble_case_3()
- for leg in dataset['data']:
- for filename in dataset['data'][leg]:
+ for leg in dataset["data"]:
+ for filename in dataset["data"][leg]:
dHdl = extract_dHdl(filename, T=300, filter=False)
- assert dHdl.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda']
+ assert dHdl.index.names == [
+ "time",
+ "fep-lambda",
+ "coul-lambda",
+ "vdw-lambda",
+ "restraint-lambda",
+ ]
assert dHdl.shape == (2500, 4)
-def test_u_nk_with_total_energy():
- """Test that the reduced potential is calculated correctly when the total energy is given.
- """
+def test_u_nk_with_total_energy():
+ """Test that the reduced potential is calculated correctly when the total energy is given."""
# Load dataset
dataset = load_water_particle_with_total_energy()
@@ -124,15 +150,16 @@ def test_u_nk_with_total_energy():
# Check one specific value in the dataframe
assert_almost_equal(
- extract_u_nk(dataset['data']['AllStates'][0], T=300).loc[0][(0.0,0.0)].values[0],
+ extract_u_nk(dataset["data"]["AllStates"][0], T=300)
+ .loc[0][(0.0, 0.0)]
+ .values[0],
-11211.577658852531,
- decimal=6
+ decimal=6,
)
-def test_u_nk_with_potential_energy():
- """Test that the reduced potential is calculated correctly when the potential energy is given.
- """
+def test_u_nk_with_potential_energy():
+ """Test that the reduced potential is calculated correctly when the potential energy is given."""
# Load dataset
dataset = load_water_particle_with_potential_energy()
@@ -142,16 +169,16 @@ def test_u_nk_with_potential_energy():
# Check one specific value in the dataframe
assert_almost_equal(
- extract_u_nk(dataset['data']['AllStates'][0], T=300).loc[0][(0.0,0.0)].values[0],
+ extract_u_nk(dataset["data"]["AllStates"][0], T=300)
+ .loc[0][(0.0, 0.0)]
+ .values[0],
-15656.557252200757,
- decimal=6
+ decimal=6,
)
def test_u_nk_without_energy():
- """Test that the reduced potential is calculated correctly when no energy is given.
-
- """
+ """Test that the reduced potential is calculated correctly when no energy is given."""
# Load dataset
dataset = load_water_particle_without_energy()
@@ -161,105 +188,114 @@ def test_u_nk_without_energy():
# Check one specific value in the dataframe
assert_almost_equal(
- extract_u_nk(dataset['data']['AllStates'][0], T=300).loc[0][(0.0,0.0)].values[0],
+ extract_u_nk(dataset["data"]["AllStates"][0], T=300)
+ .loc[0][(0.0, 0.0)]
+ .values[0],
0.0,
- decimal=6
+ decimal=6,
)
def _diag_sum(dataset):
- """Calculate the sum of diagonal elements (i, i)
-
- """
+ """Calculate the sum of diagonal elements (i, i)"""
# Initialize the sum variable
ds = 0.0
- for leg in dataset['data']:
- for filename in dataset['data'][leg]:
+ for leg in dataset["data"]:
+ for filename in dataset["data"][leg]:
u_nk = extract_u_nk(filename, T=300)
# Calculate the sum of diagonal elements:
for i, lambda_ in enumerate(u_nk.columns):
- #18.6 is the time step
- ds += u_nk.loc[i*186/10][lambda_].values[0]
+ # 18.6 is the time step
+ ds += u_nk.loc[i * 186 / 10][lambda_].values[0]
return ds
+
def test_extract_u_nk_unit():
- '''Test if extract_u_nk assign the attr correctly'''
+ """Test if extract_u_nk assign the attr correctly"""
dataset = load_benzene()
- u_nk = extract_u_nk(dataset['data']['Coulomb'][0], 310)
- assert u_nk.attrs['temperature'] == 310
- assert u_nk.attrs['energy_unit'] == 'kT'
+ u_nk = extract_u_nk(dataset["data"]["Coulomb"][0], 310)
+ assert u_nk.attrs["temperature"] == 310
+ assert u_nk.attrs["energy_unit"] == "kT"
+
def test_extract_dHdl_unit():
- '''Test if extract_u_nk assign the attr correctly'''
+ """Test if extract_u_nk assign the attr correctly"""
dataset = load_benzene()
- dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310)
- assert dhdl.attrs['temperature'] == 310
- assert dhdl.attrs['energy_unit'] == 'kT'
+ dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310)
+ assert dhdl.attrs["temperature"] == 310
+ assert dhdl.attrs["energy_unit"] == "kT"
+
def test_calling_extract():
- '''Test if the extract function is working'''
+ """Test if the extract function is working"""
dataset = load_benzene()
- df_dict = extract(dataset['data']['Coulomb'][0], 310)
- assert df_dict['dHdl'].attrs['temperature'] == 310
- assert df_dict['dHdl'].attrs['energy_unit'] == 'kT'
- assert df_dict['u_nk'].attrs['temperature'] == 310
- assert df_dict['u_nk'].attrs['energy_unit'] == 'kT'
-
-class TestRobustGMX():
- '''Test dropping the row that is wrong in different way'''
+ df_dict = extract(dataset["data"]["Coulomb"][0], 310)
+ assert df_dict["dHdl"].attrs["temperature"] == 310
+ assert df_dict["dHdl"].attrs["energy_unit"] == "kT"
+ assert df_dict["u_nk"].attrs["temperature"] == 310
+ assert df_dict["u_nk"].attrs["energy_unit"] == "kT"
+
+
+class TestRobustGMX:
+ """Test dropping the row that is wrong in different way"""
+
@staticmethod
- @pytest.fixture(scope='class')
+ @pytest.fixture(scope="class")
def data():
- dhdl = extract_dHdl(load_benzene()['data']['Coulomb'][0], 310)
- with bz2.open(load_benzene()['data']['Coulomb'][0], "rt") as bz_file:
+ dhdl = extract_dHdl(load_benzene()["data"]["Coulomb"][0], 310)
+ with bz2.open(load_benzene()["data"]["Coulomb"][0], "rt") as bz_file:
text = bz_file.read()
return text, len(dhdl)
def test_sanity(self, data, tmp_path):
- '''Test if the test routine is working.'''
+ """Test if the test routine is working."""
text, length = data
- new_text = tmp_path / 'text.xvg'
+ new_text = tmp_path / "text.xvg"
new_text.write_text(text)
dhdl = extract_dHdl(new_text, 310)
assert len(dhdl) == length
def test_truncated_row(self, data, tmp_path):
- '''Test the case where the last row has been truncated.'''
+ """Test the case where the last row has been truncated."""
text, length = data
- new_text = tmp_path / 'text.xvg'
- new_text.write_text(text + '40010.0 27.0\n')
+ new_text = tmp_path / "text.xvg"
+ new_text.write_text(text + "40010.0 27.0\n")
dhdl = extract_dHdl(new_text, 310, filter=True)
assert len(dhdl) == length
def test_truncated_number(self, data, tmp_path):
- '''Test the case where the last row has been truncated and a - has
- been left.'''
+ """Test the case where the last row has been truncated and a - has
+ been left."""
text, length = data
- new_text = tmp_path / 'text.xvg'
- new_text.write_text(text + '40010.0 27.0 -\n')
+ new_text = tmp_path / "text.xvg"
+ new_text.write_text(text + "40010.0 27.0 -\n")
dhdl = extract_dHdl(new_text, 310, filter=True)
assert len(dhdl) == length
def test_weirdnumber(self, data, tmp_path):
- '''Test the case where the last number has been appended a weird
- number.'''
+ """Test the case where the last number has been appended a weird
+ number."""
text, length = data
- new_text = tmp_path / 'text.xvg'
+ new_text = tmp_path / "text.xvg"
# Note the 27.040010.0 which is the sum of 27.0 and 40010.0
- new_text.write_text(text + '40010.0 27.040010.0 27.0 0.0 6.7 13.5 20.2 27.0 0.7 27.0 0.0 6.7 '
- '13.5 20.2 27.0 0.7\n')
+ new_text.write_text(
+ text + "40010.0 27.040010.0 27.0 0.0 6.7 13.5 20.2 27.0 0.7 27.0 0.0 6.7 "
+ "13.5 20.2 27.0 0.7\n"
+ )
dhdl = extract_dHdl(new_text, 310, filter=True)
assert len(dhdl) == length
def test_too_many_cols(self, data, tmp_path):
- '''Test the case where the row has too many columns.'''
+ """Test the case where the row has too many columns."""
text, length = data
- new_text = tmp_path / 'text.xvg'
- new_text.write_text(text +
- '40010.0 27.0 0.0 6.7 13.5 20.2 27.0 0.7 27.0 0.0 6.7 13.5 20.2 27.0 0.7\n')
+ new_text = tmp_path / "text.xvg"
+ new_text.write_text(
+ text
+ + "40010.0 27.0 0.0 6.7 13.5 20.2 27.0 0.7 27.0 0.0 6.7 13.5 20.2 27.0 0.7\n"
+ )
dhdl = extract_dHdl(new_text, 310, filter=True)
assert len(dhdl) == length
diff --git a/src/alchemlyb/tests/parsing/test_gomc.py b/src/alchemlyb/tests/parsing/test_gomc.py
index d61b3789..241add45 100644
--- a/src/alchemlyb/tests/parsing/test_gomc.py
+++ b/src/alchemlyb/tests/parsing/test_gomc.py
@@ -2,43 +2,40 @@
"""
-from alchemlyb.parsing.gomc import extract_dHdl, extract_u_nk, extract
from alchemtest.gomc import load_benzene
+from alchemlyb.parsing.gomc import extract_dHdl, extract_u_nk, extract
+
def test_dHdl():
- """Test that dHdl has the correct form when extracted from files.
-
- """
+ """Test that dHdl has the correct form when extracted from files."""
dataset = load_benzene()
- for filename in dataset['data']:
+ for filename in dataset["data"]:
dHdl = extract_dHdl(filename, T=298)
- assert dHdl.index.names == ['time', 'Coulomb-lambda', 'VDW-lambda']
+ assert dHdl.index.names == ["time", "Coulomb-lambda", "VDW-lambda"]
assert dHdl.shape == (1000, 2)
-def test_u_nk():
- """Test that u_nk has the correct form when extracted from files.
- """
+def test_u_nk():
+ """Test that u_nk has the correct form when extracted from files."""
dataset = load_benzene()
- for filename in dataset['data']:
+ for filename in dataset["data"]:
u_nk = extract_u_nk(filename, T=298)
- assert u_nk.index.names == ['time', 'Coulomb-lambda', 'VDW-lambda']
+ assert u_nk.index.names == ["time", "Coulomb-lambda", "VDW-lambda"]
assert u_nk.shape == (1000, 23)
-def test_extract():
- """Test that u_nk and dHdl have the correct form when extracted from files.
- """
+def test_extract():
+ """Test that u_nk and dHdl have the correct form when extracted from files."""
dataset = load_benzene()
- df_dict = extract(dataset['data'][0], T=298)
+ df_dict = extract(dataset["data"][0], T=298)
- assert df_dict['u_nk'].index.names == ['time', 'Coulomb-lambda', 'VDW-lambda']
- assert df_dict['u_nk'].shape == (1000, 23)
- assert df_dict['dHdl'].index.names == ['time', 'Coulomb-lambda', 'VDW-lambda']
- assert df_dict['dHdl'].shape == (1000, 2)
+ assert df_dict["u_nk"].index.names == ["time", "Coulomb-lambda", "VDW-lambda"]
+ assert df_dict["u_nk"].shape == (1000, 23)
+ assert df_dict["dHdl"].index.names == ["time", "Coulomb-lambda", "VDW-lambda"]
+ assert df_dict["dHdl"].shape == (1000, 2)
diff --git a/src/alchemlyb/tests/parsing/test_namd.py b/src/alchemlyb/tests/parsing/test_namd.py
index 8c4c1858..f5168e8f 100644
--- a/src/alchemlyb/tests/parsing/test_namd.py
+++ b/src/alchemlyb/tests/parsing/test_namd.py
@@ -1,16 +1,17 @@
"""NAMD parser tests.
"""
+import bz2
from os.path import basename
from re import search
-import bz2
-import pytest
-from alchemlyb.parsing.namd import extract_u_nk, extract
-from alchemtest.namd import load_tyr2ala
+import pytest
from alchemtest.namd import load_idws
from alchemtest.namd import load_restarted
from alchemtest.namd import load_restarted_reversed
+from alchemtest.namd import load_tyr2ala
+
+from alchemlyb.parsing.namd import extract_u_nk, extract
# Indices of lambda values in the following line in NAMD fepout files:
# #NEW FEP WINDOW: LAMBDA SET TO 0.6 LAMBDA2 0.7 LAMBDA_IDWS 0.5
@@ -27,27 +28,30 @@
def dataset():
return load_tyr2ala()
-@pytest.mark.parametrize("direction,shape",
- [('forward', (21021, 21)),
- ('backward', (21021, 21)),
- ])
+
+@pytest.mark.parametrize(
+ "direction,shape",
+ [
+ ("forward", (21021, 21)),
+ ("backward", (21021, 21)),
+ ],
+)
def test_u_nk(dataset, direction, shape):
- """Test that u_nk has the correct form when extracted from files.
- """
- for filename in dataset['data'][direction]:
+ """Test that u_nk has the correct form when extracted from files."""
+ for filename in dataset["data"][direction]:
u_nk = extract_u_nk(filename, T=300)
- assert u_nk.index.names == ['time', 'fep-lambda']
+ assert u_nk.index.names == ["time", "fep-lambda"]
assert u_nk.shape == shape
+
def test_u_nk_idws():
- """Test that u_nk has the correct form when extracted from files.
- """
+ """Test that u_nk has the correct form when extracted from files."""
- filenames = load_idws()['data']['forward']
+ filenames = load_idws()["data"]["forward"]
u_nk = extract_u_nk(filenames, T=300)
- assert u_nk.index.names == ['time', 'fep-lambda']
+ assert u_nk.index.names == ["time", "fep-lambda"]
assert u_nk.shape == (29252, 11)
@@ -64,7 +68,7 @@ def _corrupt_fepout(fepout_in, params, tmp_path):
----------
fepout_in: str
Path to fepout file to be modified. This file will not be overwritten.
-
+
params: list of tuples
For each tuple, the first element must be a str that will be passed to
startswith() to identify the line(s) to modify (e.g. "#NEW"). The
@@ -82,13 +86,17 @@ def _corrupt_fepout(fepout_in, params, tmp_path):
"""
fepout_out = tmp_path / basename(fepout_in)
- with bz2.open(fepout_out, 'wt') as f_out:
- with bz2.open(fepout_in, 'rt') as f_in:
+ with bz2.open(fepout_out, "wt") as f_out:
+ with bz2.open(fepout_in, "rt") as f_in:
for line in f_in:
for prefix, func in params:
if line.startswith(prefix):
tokens_out = func(line.split())
- line = ' '.join(tokens_out) + '\n' if tokens_out is not None else None
+ line = (
+ " ".join(tokens_out) + "\n"
+ if tokens_out is not None
+ else None
+ )
if line is not None:
f_out.write(line)
return str(fepout_out)
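
The `params` contract of `_corrupt_fepout` documented above, as a small illustrative call (the input file name is hypothetical; in the tests the last argument is pytest's `tmp_path` fixture):

from pathlib import Path

def drop_new_line(tokens):
    return None        # returning None drops the matched "#NEW" line entirely

def keep_free_line(tokens):
    return tokens      # returning the token list keeps the line (rejoined with spaces)

corrupted = _corrupt_fepout(
    "forward_000.fepout.bz2",                              # hypothetical input file
    [("#NEW", drop_new_line), ("#Free", keep_free_line)],
    Path("/tmp/corrupted"),                                # stands in for tmp_path
)
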
@@ -99,9 +107,10 @@ def restarted_dataset_inconsistent(restarted_dataset, tmp_path):
"""Returns intentionally messed up dataset where lambda1 and lambda2 at start and end of
a window are different."""
- filenames = sorted(restarted_dataset['data']['both'])
+ filenames = sorted(restarted_dataset["data"]["both"])
changed = False
+
def func_free_line(l):
nonlocal changed
if float(l[7]) >= 0.7 and float(l[7]) < 0.9:
@@ -110,13 +119,15 @@ def func_free_line(l):
return l
for i in range(len(filenames)):
- filenames[i] = _corrupt_fepout(filenames[i], [('#Free', func_free_line)], tmp_path)
+ filenames[i] = _corrupt_fepout(
+ filenames[i], [("#Free", func_free_line)], tmp_path
+ )
# Only actually modify one window so we don't trigger the wrong exception
if changed is True:
break
# Don't directly modify the glob object
- restarted_dataset['data']['both'] = filenames
+ restarted_dataset["data"]["both"] = filenames
return restarted_dataset
@@ -129,26 +140,32 @@ def restarted_dataset_idws_without_lambda_idws(restarted_dataset, tmp_path):
# First window won't have any IDWS data so we just drop all its files and fudge the lambdas
# in the next window to include 0.0 or 1.0 (as appropriate) so we still have a nominally complete calculation
-
- filenames = [x for x in sorted(restarted_dataset['data']['both']) if search('000[a-z]?.fepout', x) is None]
+
+ filenames = [
+ x
+ for x in sorted(restarted_dataset["data"]["both"])
+ if search("000[a-z]?.fepout", x) is None
+ ]
def func_new_line(l):
- if float(l[LAMBDA1_IDX_NEW]) > 0.5: # 1->0 (reversed) calculation
- l[LAMBDA1_IDX_NEW] == '1.0'
- else: # regular 0->1 calculation
- l[LAMBDA1_IDX_NEW] = '0.0'
+ if float(l[LAMBDA1_IDX_NEW]) > 0.5: # 1->0 (reversed) calculation
+ l[LAMBDA1_IDX_NEW] = "1.0"
+ else: # regular 0->1 calculation
+ l[LAMBDA1_IDX_NEW] = "0.0"
# Drop the lambda_idws
return l[:9]
-
+
def func_free_line(l):
- if float(l[LAMBDA1_IDX_FREE]) > 0.5: # 1->0 (reversed) calculation
- l[LAMBDA1_IDX_FREE] == '1.0'
- else: # regular 0->1 calculation
- l[LAMBDA1_IDX_FREE] = '0.0'
+ if float(l[LAMBDA1_IDX_FREE]) > 0.5: # 1->0 (reversed) calculation
+ l[LAMBDA1_IDX_FREE] = "1.0"
+ else: # regular 0->1 calculation
+ l[LAMBDA1_IDX_FREE] = "0.0"
return l
-
- filenames[0] = _corrupt_fepout(filenames[0], [('#NEW', func_new_line), ('#Free', func_free_line)], tmp_path)
- restarted_dataset['data']['both'] = filenames
+
+ filenames[0] = _corrupt_fepout(
+ filenames[0], [("#NEW", func_new_line), ("#Free", func_free_line)], tmp_path
+ )
+ restarted_dataset["data"]["both"] = filenames
return restarted_dataset
@@ -157,7 +174,7 @@ def restarted_dataset_toomany_lambda2(restarted_dataset, tmp_path):
"""Returns intentionally messed up dataset, where there are too many lambda2 values for a
given lambda1."""
- filenames = sorted(restarted_dataset['data']['both'])
+ filenames = sorted(restarted_dataset["data"]["both"])
# For the same l1 and lidws we retain old lambda2 values thus ensuring a collision
# Also, don't make a window where lambda1 >= lambda2 because this will trigger the
@@ -165,22 +182,23 @@ def restarted_dataset_toomany_lambda2(restarted_dataset, tmp_path):
def func_new_line(l):
if float(l[LAMBDA2_IDX_NEW]) <= 0.2:
return l
- l[LAMBDA1_IDX_NEW] = '0.2'
- if len(l) > 9 and l[9] == 'LAMBDA_IDWS':
- l[LAMBDA_IDWS_IDX_NEW] = '0.1'
+ l[LAMBDA1_IDX_NEW] = "0.2"
+ if len(l) > 9 and l[9] == "LAMBDA_IDWS":
+ l[LAMBDA_IDWS_IDX_NEW] = "0.1"
return l
def func_free_line(l):
if float(l[LAMBDA2_IDX_FREE]) <= 0.2:
return l
- l[LAMBDA1_IDX_FREE] = '0.2'
+ l[LAMBDA1_IDX_FREE] = "0.2"
return l
for i in range(len(filenames)):
- filenames[i] = \
- _corrupt_fepout(filenames[i], [('#NEW', func_new_line), ('#Free', func_free_line)], tmp_path)
+ filenames[i] = _corrupt_fepout(
+ filenames[i], [("#NEW", func_new_line), ("#Free", func_free_line)], tmp_path
+ )
- restarted_dataset['data']['both'] = filenames
+ restarted_dataset["data"]["both"] = filenames
return restarted_dataset
@@ -189,7 +207,7 @@ def restarted_dataset_toomany_lambda_idws(restarted_dataset, tmp_path):
"""Returns intentionally messed up dataset, where there are too many lambda2 values for a
given lambda1."""
- filenames = sorted(restarted_dataset['data']['both'])
+ filenames = sorted(restarted_dataset["data"]["both"])
# For the same lambda1 and lambda2 we retain the first set of lambda1/lambda2 values
# and replicate them across all windows thus ensuring that there will be more than
@@ -198,7 +216,7 @@ def restarted_dataset_toomany_lambda_idws(restarted_dataset, tmp_path):
def func_new_line(l):
nonlocal this_lambda1, this_lambda2
-
+
if this_lambda1 is None:
this_lambda1, this_lambda2 = l[LAMBDA1_IDX_NEW], l[LAMBDA2_IDX_NEW]
# Ensure that changing these lambda values won't cause a reversal in direction and trigger
@@ -212,9 +230,11 @@ def func_free_line(l):
return l
for i in range(len(filenames)):
- filenames[i] = _corrupt_fepout(filenames[i], [('#NEW', func_new_line)], tmp_path)
+ filenames[i] = _corrupt_fepout(
+ filenames[i], [("#NEW", func_new_line)], tmp_path
+ )
- restarted_dataset['data']['both'] = filenames
+ restarted_dataset["data"]["both"] = filenames
return restarted_dataset
@@ -222,7 +242,7 @@ def func_free_line(l):
def restarted_dataset_direction_changed(restarted_dataset, tmp_path):
"""Returns intentionally messed up dataset, with one window where the lambda values are reversed."""
- filenames = sorted(restarted_dataset['data']['both'])
+ filenames = sorted(restarted_dataset["data"]["both"])
def func_new_line(l):
l[6], l[8], l[10] = l[10], l[8], l[6]
@@ -231,12 +251,16 @@ def func_new_line(l):
def func_free_line(l):
l[7], l[8] = l[8], l[7]
return l
-
+
# Reverse the direction of lambdas for this window
idx_to_corrupt = filenames.index(sorted(filenames)[-3])
- fname1 = _corrupt_fepout(filenames[idx_to_corrupt], [('#NEW', func_new_line), ('#Free', func_free_line)], tmp_path)
+ fname1 = _corrupt_fepout(
+ filenames[idx_to_corrupt],
+ [("#NEW", func_new_line), ("#Free", func_free_line)],
+ tmp_path,
+ )
filenames[idx_to_corrupt] = fname1
- restarted_dataset['data']['both'] = filenames
+ restarted_dataset["data"]["both"] = filenames
return restarted_dataset
@@ -244,15 +268,17 @@ def func_free_line(l):
def restarted_dataset_all_windows_truncated(restarted_dataset, tmp_path):
"""Returns dataset where all windows are truncated (no #Free... footer lines)."""
- filenames = sorted(restarted_dataset['data']['both'])
+ filenames = sorted(restarted_dataset["data"]["both"])
def func_free_line(l):
return None
for i in range(len(filenames)):
- filenames[i] = _corrupt_fepout(filenames[i], [('#Free', func_free_line)], tmp_path)
-
- restarted_dataset['data']['both'] = filenames
+ filenames[i] = _corrupt_fepout(
+ filenames[i], [("#Free", func_free_line)], tmp_path
+ )
+
+ restarted_dataset["data"]["both"] = filenames
return restarted_dataset
@@ -260,13 +286,15 @@ def func_free_line(l):
def restarted_dataset_last_window_truncated(restarted_dataset, tmp_path):
"""Returns dataset where the last window is truncated (no #Free... footer line)."""
- filenames = sorted(restarted_dataset['data']['both'])
+ filenames = sorted(restarted_dataset["data"]["both"])
def func_free_line(l):
return None
- filenames[-1] = _corrupt_fepout(filenames[-1], [('#Free', func_free_line)], tmp_path)
- restarted_dataset['data']['both'] = filenames
+ filenames[-1] = _corrupt_fepout(
+ filenames[-1], [("#Free", func_free_line)], tmp_path
+ )
+ restarted_dataset["data"]["both"] = filenames
return restarted_dataset
@@ -274,72 +302,91 @@ def test_u_nk_restarted():
"""Test that u_nk has the correct form when extracted from an IDWS
FEP run that includes terminations and restarts.
"""
- filenames = load_restarted()['data']['both']
+ filenames = load_restarted()["data"]["both"]
u_nk = extract_u_nk(filenames, T=300)
-
- assert u_nk.index.names == ['time', 'fep-lambda']
+
+ assert u_nk.index.names == ["time", "fep-lambda"]
assert u_nk.shape == (30061, 11)
def test_u_nk_restarted_missing_window_header(tmp_path):
"""Test that u_nk has the correct form when a #NEW line is missing from the restarted dataset
and the parser has to infer lambda_idws for that window."""
- filenames = sorted(load_restarted()['data']['both'])
+ filenames = sorted(load_restarted()["data"]["both"])
# Remove "#NEW" line
- filenames[4] = _corrupt_fepout(filenames[4], [('#NEW', lambda l: None),], tmp_path)
+ filenames[4] = _corrupt_fepout(
+ filenames[4],
+ [
+ ("#NEW", lambda l: None),
+ ],
+ tmp_path,
+ )
u_nk = extract_u_nk(filenames, T=300)
-
- assert u_nk.index.names == ['time', 'fep-lambda']
+
+ assert u_nk.index.names == ["time", "fep-lambda"]
assert u_nk.shape == (30061, 11)
def test_u_nk_restarted_reversed():
- filenames = load_restarted_reversed()['data']['both']
+ filenames = load_restarted_reversed()["data"]["both"]
u_nk = extract_u_nk(filenames, T=300)
-
- assert u_nk.index.names == ['time', 'fep-lambda']
+
+ assert u_nk.index.names == ["time", "fep-lambda"]
assert u_nk.shape == (30170, 11)
def test_extract():
- filenames = load_restarted_reversed()['data']['both']
+ filenames = load_restarted_reversed()["data"]["both"]
df_dict = extract(filenames, T=300)
- assert df_dict['u_nk'].index.names == ['time', 'fep-lambda']
- assert df_dict['u_nk'].shape == (30170, 11)
- assert 'dHdl' not in df_dict
+ assert df_dict["u_nk"].index.names == ["time", "fep-lambda"]
+ assert df_dict["u_nk"].shape == (30170, 11)
+ assert "dHdl" not in df_dict
def test_u_nk_restarted_reversed_missing_window_header(tmp_path):
"""Test that u_nk has the correct form when a #NEW line is missing from the restarted_reversed dataset
and the parser has to infer lambda_idws for that window."""
- filenames = sorted(load_restarted_reversed()['data']['both'])
+ filenames = sorted(load_restarted_reversed()["data"]["both"])
# Remove "#NEW" line
- filenames[4] = _corrupt_fepout(filenames[4], [('#NEW', lambda l: None),], tmp_path)
+ filenames[4] = _corrupt_fepout(
+ filenames[4],
+ [
+ ("#NEW", lambda l: None),
+ ],
+ tmp_path,
+ )
u_nk = extract_u_nk(filenames, T=300)
-
- assert u_nk.index.names == ['time', 'fep-lambda']
+
+ assert u_nk.index.names == ["time", "fep-lambda"]
assert u_nk.shape == (30170, 11)
def test_u_nk_restarted_direction_changed(restarted_dataset_direction_changed):
"""Test that when lambda values change direction within a dataset, parsing throws an error."""
- with pytest.raises(ValueError, match='Lambda values change direction'):
- u_nk = extract_u_nk(restarted_dataset_direction_changed['data']['both'], T=300)
+ with pytest.raises(ValueError, match="Lambda values change direction"):
+ u_nk = extract_u_nk(restarted_dataset_direction_changed["data"]["both"], T=300)
-def test_u_nk_restarted_idws_without_lambda_idws(restarted_dataset_idws_without_lambda_idws):
+def test_u_nk_restarted_idws_without_lambda_idws(
+ restarted_dataset_idws_without_lambda_idws,
+):
"""Test that when the first window has IDWS data but no lambda_idws, parsing throws an error.
-
+
In this situation, the lambda_idws cannot be inferred, because there's no previous lambda
value available.
"""
- with pytest.raises(ValueError, match='IDWS data present in first window but lambda_idws not included'):
- u_nk = extract_u_nk(restarted_dataset_idws_without_lambda_idws['data']['both'], T=300)
+ with pytest.raises(
+ ValueError,
+ match="IDWS data present in first window but lambda_idws not included",
+ ):
+ u_nk = extract_u_nk(
+ restarted_dataset_idws_without_lambda_idws["data"]["both"], T=300
+ )
def test_u_nk_restarted_inconsistent(restarted_dataset_inconsistent):
@@ -347,33 +394,45 @@ def test_u_nk_restarted_inconsistent(restarted_dataset_inconsistent):
parsing throws an error.
"""
- with pytest.raises(ValueError, match='Inconsistent lambda values within the same window'):
- u_nk = extract_u_nk(restarted_dataset_inconsistent['data']['both'], T=300)
+ with pytest.raises(
+ ValueError, match="Inconsistent lambda values within the same window"
+ ):
+ u_nk = extract_u_nk(restarted_dataset_inconsistent["data"]["both"], T=300)
def test_u_nk_restarted_toomany_lambda_idws(restarted_dataset_toomany_lambda_idws):
"""Test that when there is more than one lambda_idws for a given lambda1, parsing throws an error."""
- with pytest.raises(ValueError, match='More than one lambda_idws value for a particular lambda1'):
- u_nk = extract_u_nk(restarted_dataset_toomany_lambda_idws['data']['both'], T=300)
+ with pytest.raises(
+ ValueError, match="More than one lambda_idws value for a particular lambda1"
+ ):
+ u_nk = extract_u_nk(
+ restarted_dataset_toomany_lambda_idws["data"]["both"], T=300
+ )
def test_u_nk_restarted_toomany_lambda2(restarted_dataset_toomany_lambda2):
"""Test that when there is more than one lambda2 for a given lambda1, parsing throws an error."""
- with pytest.raises(ValueError, match='More than one lambda2 value for a particular lambda1'):
- u_nk = extract_u_nk(restarted_dataset_toomany_lambda2['data']['both'], T=300)
+ with pytest.raises(
+ ValueError, match="More than one lambda2 value for a particular lambda1"
+ ):
+ u_nk = extract_u_nk(restarted_dataset_toomany_lambda2["data"]["both"], T=300)
def test_u_nk_restarted_all_windows_truncated(restarted_dataset_all_windows_truncated):
"""Test that when there is more than one lambda2 for a given lambda1, parsing throws an error."""
- with pytest.raises(ValueError, match='New window begun after truncated window'):
- u_nk = extract_u_nk(restarted_dataset_all_windows_truncated['data']['both'], T=300)
+ with pytest.raises(ValueError, match="New window begun after truncated window"):
+ u_nk = extract_u_nk(
+ restarted_dataset_all_windows_truncated["data"]["both"], T=300
+ )
def test_u_nk_restarted_last_window_truncated(restarted_dataset_last_window_truncated):
"""Test that when there is more than one lambda2 for a given lambda1, parsing throws an error."""
- with pytest.raises(ValueError, match='Last window is truncated'):
- u_nk = extract_u_nk(restarted_dataset_last_window_truncated['data']['both'], T=300)
+ with pytest.raises(ValueError, match="Last window is truncated"):
+ u_nk = extract_u_nk(
+ restarted_dataset_last_window_truncated["data"]["both"], T=300
+ )
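For readers following along, a minimal sketch of the parsing calls these tests exercise. The import paths are an assumption here (they live earlier in the test module, outside this hunk):

    # Assumed imports: alchemtest.namd and alchemlyb.parsing.namd provide these names.
    from alchemtest.namd import load_restarted
    from alchemlyb.parsing.namd import extract, extract_u_nk

    filenames = load_restarted()["data"]["both"]
    # Build the reduced potential matrix from an IDWS run with terminations/restarts.
    u_nk = extract_u_nk(filenames, T=300)
    # Or get everything the parser can produce as a dict keyed by "u_nk"/"dHdl".
    df_dict = extract(filenames, T=300)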
diff --git a/src/alchemlyb/tests/parsing/test_util.py b/src/alchemlyb/tests/parsing/test_util.py
index 334ee0c2..85107d61 100644
--- a/src/alchemlyb/tests/parsing/test_util.py
+++ b/src/alchemlyb/tests/parsing/test_util.py
@@ -1,32 +1,29 @@
import io
-import pytest
+import pytest
from alchemtest.gmx import load_expanded_ensemble_case_1
+
from alchemlyb.parsing.util import anyopen
def test_gzip():
- """Test that gzip reads .gz files in the correct (text) mode.
-
- """
+ """Test that gzip reads .gz files in the correct (text) mode."""
dataset = load_expanded_ensemble_case_1()
- for leg in dataset['data']:
- for filename in dataset['data'][leg]:
- with anyopen(filename, 'r') as f:
+ for leg in dataset["data"]:
+ for filename in dataset["data"][leg]:
+ with anyopen(filename, "r") as f:
assert type(f.readline()) is str
def test_gzip_stream():
- """Test that `anyopen` reads streams with specified compression.
-
- """
+ """Test that `anyopen` reads streams with specified compression."""
dataset = load_expanded_ensemble_case_1()
- for leg in dataset['data']:
- for filename in dataset['data'][leg]:
- with open(filename, 'rb') as f:
- with anyopen(f, mode='r', compression='gzip') as f_uc:
+ for leg in dataset["data"]:
+ for filename in dataset["data"][leg]:
+ with open(filename, "rb") as f:
+ with anyopen(f, mode="r", compression="gzip") as f_uc:
assert type(f_uc.readline()) is str
@@ -37,11 +34,11 @@ def test_gzip_stream_wrong():
"""
dataset = load_expanded_ensemble_case_1()
- for leg in dataset['data']:
- for filename in dataset['data'][leg]:
- with open(filename, 'rb') as f:
- with anyopen(f, mode='r', compression='bzip2') as f_uc:
- with pytest.raises(OSError, match='Invalid data stream'):
+ for leg in dataset["data"]:
+ for filename in dataset["data"][leg]:
+ with open(filename, "rb") as f:
+ with anyopen(f, mode="r", compression="bzip2") as f_uc:
+ with pytest.raises(OSError, match="Invalid data stream"):
assert type(f_uc.readline()) is str
@@ -52,33 +49,30 @@ def test_gzip_stream_wrong_no_compression():
"""
dataset = load_expanded_ensemble_case_1()
- for leg in dataset['data']:
- for filename in dataset['data'][leg]:
- with open(filename, 'rb') as f:
- with anyopen(f, mode='r') as f_uc:
+ for leg in dataset["data"]:
+ for filename in dataset["data"][leg]:
+ with open(filename, "rb") as f:
+ with anyopen(f, mode="r") as f_uc:
assert type(f_uc.readline()) is bytes
-@pytest.mark.parametrize('extension', ['bz2', 'gz'])
+@pytest.mark.parametrize("extension", ["bz2", "gz"])
def test_file_roundtrip(extension, tmp_path):
- """Test that roundtripping write/read to a file works with `anyopen`.
-
- """
+ """Test that roundtripping write/read to a file works with `anyopen`."""
data = "my momma told me to pick the very best one and you are not it"
- filepath = tmp_path / f'testfile.txt.{extension}'
- with anyopen(filepath, mode='w') as f:
+ filepath = tmp_path / f"testfile.txt.{extension}"
+ with anyopen(filepath, mode="w") as f:
f.write(data)
- with anyopen(filepath, 'r') as f:
+ with anyopen(filepath, "r") as f:
data_out = f.read()
assert data_out == data
-@pytest.mark.parametrize('extension,compression',
- [('bz2', 'gzip'), ('gz', 'bzip2')])
+@pytest.mark.parametrize("extension,compression", [("bz2", "gzip"), ("gz", "bzip2")])
def test_file_roundtrip_force_compression(extension, compression, tmp_path):
"""Test that roundtripping write/read to a file works with `anyopen`,
in which we force compression despite different extension.
@@ -87,50 +81,45 @@ def test_file_roundtrip_force_compression(extension, compression, tmp_path):
data = "my momma told me to pick the very best one and you are not it"
- filepath = tmp_path / f'testfile.txt.{extension}'
- with anyopen(filepath, mode='w', compression=compression) as f:
+ filepath = tmp_path / f"testfile.txt.{extension}"
+ with anyopen(filepath, mode="w", compression=compression) as f:
f.write(data)
- with anyopen(filepath, 'r', compression=compression) as f:
+ with anyopen(filepath, "r", compression=compression) as f:
data_out = f.read()
assert data_out == data
-@pytest.mark.parametrize('compression', ['bzip2', 'gzip'])
+@pytest.mark.parametrize("compression", ["bzip2", "gzip"])
def test_stream_roundtrip(compression):
- """Test that roundtripping write/read to a stream works with `anyopen`
-
- """
+ """Test that roundtripping write/read to a stream works with `anyopen`"""
data = "my momma told me to pick the very best one and you are not it"
with io.BytesIO() as stream:
-
# write to stream
- with anyopen(stream, mode='w', compression=compression) as f:
+ with anyopen(stream, mode="w", compression=compression) as f:
f.write(data)
# start at the beginning
stream.seek(0)
# read from stream
- with anyopen(stream, 'r', compression=compression) as f:
+ with anyopen(stream, "r", compression=compression) as f:
data_out = f.read()
assert data_out == data
-def test_stream_unsupported_compression():
- """Test that we throw a ValueError when an unsupported compression is used.
- """
+def test_stream_unsupported_compression():
+ """Test that we throw a ValueError when an unsupported compression is used."""
- compression="fakez"
+ compression = "fakez"
data = b"my momma told me to pick the very best one and you are not it"
with io.BytesIO() as stream:
-
# write to stream
stream.write(data)
@@ -139,5 +128,5 @@ def test_stream_unsupported_compression():
# read from stream
with pytest.raises(ValueError):
- with anyopen(stream, 'r', compression=compression) as f:
+ with anyopen(stream, "r", compression=compression) as f:
data_out = f.read()
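As a quick illustration of the behaviour these tests pin down, a minimal anyopen round trip (compression inferred from the file extension, or forced with the compression keyword for streams) might look like this; the file name is arbitrary and is written to the working directory:

    import io
    from alchemlyb.parsing.util import anyopen

    data = "hello"
    # File: compression inferred from the .gz suffix.
    with anyopen("testfile.txt.gz", mode="w") as f:
        f.write(data)
    with anyopen("testfile.txt.gz", "r") as f:
        assert f.read() == data

    # Stream: compression must be given explicitly.
    with io.BytesIO() as stream:
        with anyopen(stream, mode="w", compression="gzip") as f:
            f.write(data)
        stream.seek(0)
        with anyopen(stream, "r", compression="gzip") as f:
            assert f.read() == data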
diff --git a/src/alchemlyb/tests/test_convergence.py b/src/alchemlyb/tests/test_convergence.py
index d0ffb2c2..32fa1eb9 100644
--- a/src/alchemlyb/tests/test_convergence.py
+++ b/src/alchemlyb/tests/test_convergence.py
@@ -2,8 +2,7 @@
import pandas as pd
import pytest
-from alchemlyb.convergence import forward_backward_convergence, \
- fwdrev_cumavg_Rc, A_c
+from alchemlyb.convergence import forward_backward_convergence, fwdrev_cumavg_Rc, A_c
from alchemlyb.convergence.convergence import _cummean
diff --git a/src/alchemlyb/tests/test_fep_estimators.py b/src/alchemlyb/tests/test_fep_estimators.py
index 05a2dd24..9d041eb8 100644
--- a/src/alchemlyb/tests/test_fep_estimators.py
+++ b/src/alchemlyb/tests/test_fep_estimators.py
@@ -13,7 +13,6 @@ class FEPestimatorMixin:
"""Mixin for all FEP Estimator test classes."""
def compare_delta_f(self, X_delta_f):
-
est = self.cls().fit(X_delta_f[0])
delta_f, d_delta_f = self.get_delta_f(est)
diff --git a/src/alchemlyb/tests/test_import.py b/src/alchemlyb/tests/test_import.py
index 50a5d933..ef467ec8 100644
--- a/src/alchemlyb/tests/test_import.py
+++ b/src/alchemlyb/tests/test_import.py
@@ -1,4 +1,5 @@
import alchemlyb
+
def test_name():
- assert alchemlyb.__name__ == 'alchemlyb'
+ assert alchemlyb.__name__ == "alchemlyb"
diff --git a/src/alchemlyb/tests/test_preprocessing.py b/src/alchemlyb/tests/test_preprocessing.py
index 5e2cc611..10085828 100644
--- a/src/alchemlyb/tests/test_preprocessing.py
+++ b/src/alchemlyb/tests/test_preprocessing.py
@@ -69,9 +69,9 @@ def test_unchanged(self, namd_idws):
# NAMD energy files only have dE for adjacent lambdas, this ensures
# that the slicer will not drop these rows as they have NaN values.
# Do the pre-processing as the u_nk are from all lambdas
- groups = namd_idws.groupby('fep-lambda')
+ groups = namd_idws.groupby("fep-lambda")
for key, group in groups:
- group = group[~group.index.duplicated(keep='first')]
+ group = group[~group.index.duplicated(keep="first")]
df = self.slicer(group, None, None, None)
assert len(df) == len(group)
@@ -252,7 +252,6 @@ def test_conservative(self, dataloader, size, conservative, request):
],
)
def test_raise_ValueError_for_mismatched_data(self, dataloader, end, step, request):
-
data = request.getfixturevalue(dataloader)
with pytest.raises(ValueError):
self.slicer(data, series=data[:end:step])
diff --git a/src/alchemlyb/tests/test_version.py b/src/alchemlyb/tests/test_version.py
index 4f2afc78..ddab2ab6 100644
--- a/src/alchemlyb/tests/test_version.py
+++ b/src/alchemlyb/tests/test_version.py
@@ -1,5 +1,6 @@
import alchemlyb
+
def test_version():
try:
version = alchemlyb.__version__
@@ -8,9 +9,10 @@ def test_version():
assert len(version) > 0
+
def test_version_get_versions():
import alchemlyb._version
+
version = alchemlyb._version.get_versions()
assert alchemlyb.__version__ == version["version"]
-
diff --git a/src/alchemlyb/tests/test_workflow.py b/src/alchemlyb/tests/test_workflow.py
index a4308145..cc31e611 100644
--- a/src/alchemlyb/tests/test_workflow.py
+++ b/src/alchemlyb/tests/test_workflow.py
@@ -1,11 +1,14 @@
+import os
+
+import pandas as pd
import pytest
+
from alchemlyb.workflows import base
-import pandas as pd
-import os
-class Test_automatic_base():
+
+class Test_automatic_base:
@staticmethod
- @pytest.fixture(scope='session')
+ @pytest.fixture(scope="session")
def workflow(tmp_path_factory):
outdir = tmp_path_factory.mktemp("out")
workflow = base.WorkflowBase(out=str(outdir))
@@ -13,9 +16,9 @@ def workflow(tmp_path_factory):
return workflow
def test_write(self, workflow):
- '''Patch the output directory to tmpdir'''
- workflow.result.to_pickle(os.path.join(workflow.out, 'result.pkl'))
- assert os.path.exists(os.path.join(workflow.out, 'result.pkl'))
+ """Patch the output directory to tmpdir"""
+ workflow.result.to_pickle(os.path.join(workflow.out, "result.pkl"))
+ assert os.path.exists(os.path.join(workflow.out, "result.pkl"))
def test_read(self, workflow):
assert len(workflow.u_nk_list) == 0
diff --git a/src/alchemlyb/tests/test_workflow_ABFE.py b/src/alchemlyb/tests/test_workflow_ABFE.py
index f282041b..e1d6a4f2 100644
--- a/src/alchemlyb/tests/test_workflow_ABFE.py
+++ b/src/alchemlyb/tests/test_workflow_ABFE.py
@@ -1,204 +1,235 @@
+import os
+
import numpy as np
import pytest
-import os
+from alchemtest.amber import load_bace_example
+from alchemtest.gmx import load_ABFE, load_benzene
from alchemlyb.workflows.abfe import ABFE
-from alchemtest.gmx import load_ABFE, load_benzene
-from alchemtest.amber import load_bace_example
-class Test_automatic_ABFE():
- '''Test the full automatic workflow for load_ABFE from alchemtest.gmx for
- three stage transformation.'''
+
+class Test_automatic_ABFE:
+ """Test the full automatic workflow for load_ABFE from alchemtest.gmx for
+ three stage transformation."""
@staticmethod
- @pytest.fixture(scope='session')
+ @pytest.fixture(scope="session")
def workflow(tmp_path_factory):
outdir = tmp_path_factory.mktemp("out")
- dir = os.path.dirname(load_ABFE()['data']['complex'][0])
- workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir,
- prefix='dhdl', suffix='xvg', T=310, outdirectory=str(outdir))
- workflow.run(skiptime=10, uncorr='dE', threshold=50,
- estimators=('MBAR', 'BAR', 'TI'), overlap='O_MBAR.pdf',
- breakdown=True, forwrev=10)
+ dir = os.path.dirname(load_ABFE()["data"]["complex"][0])
+ workflow = ABFE(
+ units="kcal/mol",
+ software="GROMACS",
+ dir=dir,
+ prefix="dhdl",
+ suffix="xvg",
+ T=310,
+ outdirectory=str(outdir),
+ )
+ workflow.run(
+ skiptime=10,
+ uncorr="dE",
+ threshold=50,
+ estimators=("MBAR", "BAR", "TI"),
+ overlap="O_MBAR.pdf",
+ breakdown=True,
+ forwrev=10,
+ )
return workflow
def test_read(self, workflow):
- '''test if the files has been loaded correctly.'''
+ """test if the files has been loaded correctly."""
assert len(workflow.u_nk_list) == 30
assert len(workflow.dHdl_list) == 30
assert all([len(u_nk) == 1001 for u_nk in workflow.u_nk_list])
assert all([len(dHdl) == 1001 for dHdl in workflow.dHdl_list])
def test_subsample(self, workflow):
- '''Test if the data has been shrinked by subsampling.'''
+ """Test if the data has been shrinked by subsampling."""
assert len(workflow.u_nk_sample_list) == 30
assert len(workflow.dHdl_sample_list) == 30
assert all([len(u_nk) < 1001 for u_nk in workflow.u_nk_sample_list])
assert all([len(dHdl) < 1001 for dHdl in workflow.dHdl_sample_list])
def test_estimator(self, workflow):
- '''Test if all three estimators have been used.'''
+ """Test if all three estimators have been used."""
assert len(workflow.estimator) == 3
- assert 'MBAR' in workflow.estimator
- assert 'TI' in workflow.estimator
- assert 'BAR' in workflow.estimator
+ assert "MBAR" in workflow.estimator
+ assert "TI" in workflow.estimator
+ assert "BAR" in workflow.estimator
def test_summary(self, workflow):
- '''Test if if the summary is right.'''
+ """Test if if the summary is right."""
summary = workflow.generate_result()
- assert np.isclose(summary['MBAR']['Stages']['TOTAL'], 21.8, 0.1)
- assert np.isclose(summary['TI']['Stages']['TOTAL'], 21.8, 0.1)
- assert np.isclose(summary['BAR']['Stages']['TOTAL'], 21.8, 0.1)
+ assert np.isclose(summary["MBAR"]["Stages"]["TOTAL"], 21.8, 0.1)
+ assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 21.8, 0.1)
+ assert np.isclose(summary["BAR"]["Stages"]["TOTAL"], 21.8, 0.1)
def test_plot_O_MBAR(self, workflow):
- '''test if the O_MBAR.pdf has been plotted.'''
- assert os.path.isfile(os.path.join(workflow.out, 'O_MBAR.pdf'))
+ """test if the O_MBAR.pdf has been plotted."""
+ assert os.path.isfile(os.path.join(workflow.out, "O_MBAR.pdf"))
def test_plot_dhdl_TI(self, workflow):
- '''test if the dhdl_TI.pdf has been plotted.'''
- assert os.path.isfile(os.path.join(workflow.out, 'dhdl_TI.pdf'))
+ """test if the dhdl_TI.pdf has been plotted."""
+ assert os.path.isfile(os.path.join(workflow.out, "dhdl_TI.pdf"))
def test_plot_dF_state(self, workflow):
- '''test if the dF_state.pdf has been plotted.'''
- assert os.path.isfile(os.path.join(workflow.out, 'dF_state.pdf'))
- assert os.path.isfile(os.path.join(workflow.out, 'dF_state_long.pdf'))
+ """test if the dF_state.pdf has been plotted."""
+ assert os.path.isfile(os.path.join(workflow.out, "dF_state.pdf"))
+ assert os.path.isfile(os.path.join(workflow.out, "dF_state_long.pdf"))
def test_check_convergence(self, workflow):
- '''test if the dF_state.pdf has been plotted.'''
- assert os.path.isfile(os.path.join(workflow.out, 'dF_t.pdf'))
+ """test if the dF_state.pdf has been plotted."""
+ assert os.path.isfile(os.path.join(workflow.out, "dF_t.pdf"))
assert len(workflow.convergence) == 10
def test_estimator_method(self, workflow, monkeypatch):
- '''Test if the method keyword could be passed to the AutoMBAR estimator.'''
- monkeypatch.setattr(workflow, 'estimator',
- dict())
- workflow.estimate(estimators='MBAR', method='adaptive')
- assert 'MBAR' in workflow.estimator
+ """Test if the method keyword could be passed to the AutoMBAR estimator."""
+ monkeypatch.setattr(workflow, "estimator", dict())
+ workflow.estimate(estimators="MBAR", method="adaptive")
+ assert "MBAR" in workflow.estimator
def test_convergence_method(self, workflow, monkeypatch):
- '''Test if the method keyword could be passed to the AutoMBAR estimator from convergence.'''
- monkeypatch.setattr(workflow, 'convergence', None)
- workflow.check_convergence(2, estimator='MBAR', method='adaptive')
+ """Test if the method keyword could be passed to the AutoMBAR estimator from convergence."""
+ monkeypatch.setattr(workflow, "convergence", None)
+ workflow.check_convergence(2, estimator="MBAR", method="adaptive")
assert len(workflow.convergence) == 2
+
class Test_manual_ABFE(Test_automatic_ABFE):
- '''Test the manual workflow for load_ABFE from alchemtest.gmx for three
- stage transformation.'''
+ """Test the manual workflow for load_ABFE from alchemtest.gmx for three
+ stage transformation."""
@staticmethod
- @pytest.fixture(scope='session')
+ @pytest.fixture(scope="session")
def workflow(tmp_path_factory):
outdir = tmp_path_factory.mktemp("out")
- dir = os.path.dirname(load_ABFE()['data']['complex'][0])
- workflow = ABFE(software='GROMACS', dir=dir, prefix='dhdl',
- suffix='xvg', T=310, outdirectory=str(outdir))
- workflow.update_units('kcal/mol')
+ dir = os.path.dirname(load_ABFE()["data"]["complex"][0])
+ workflow = ABFE(
+ software="GROMACS",
+ dir=dir,
+ prefix="dhdl",
+ suffix="xvg",
+ T=310,
+ outdirectory=str(outdir),
+ )
+ workflow.update_units("kcal/mol")
workflow.read()
- workflow.preprocess(skiptime=10, uncorr='dE', threshold=50)
- workflow.estimate(estimators=('MBAR', 'BAR', 'TI'))
- workflow.plot_overlap_matrix(overlap='O_MBAR.pdf')
- workflow.plot_ti_dhdl(dhdl_TI='dhdl_TI.pdf')
- workflow.plot_dF_state(dF_state='dF_state.pdf')
- workflow.check_convergence(10, dF_t='dF_t.pdf')
+ workflow.preprocess(skiptime=10, uncorr="dE", threshold=50)
+ workflow.estimate(estimators=("MBAR", "BAR", "TI"))
+ workflow.plot_overlap_matrix(overlap="O_MBAR.pdf")
+ workflow.plot_ti_dhdl(dhdl_TI="dhdl_TI.pdf")
+ workflow.plot_dF_state(dF_state="dF_state.pdf")
+ workflow.check_convergence(10, dF_t="dF_t.pdf")
return workflow
def test_plot_dF_state(self, workflow):
- '''test if the dF_state.pdf has been plotted.'''
- assert os.path.isfile(os.path.join(workflow.out, 'dF_state.pdf'))
+ """test if the dF_state.pdf has been plotted."""
+ assert os.path.isfile(os.path.join(workflow.out, "dF_state.pdf"))
def test_convergence_nosample_u_nk(self, workflow, monkeypatch):
- '''test if the convergence routine would use the unsampled data
- when the data has not been subsampled.'''
- monkeypatch.setattr(workflow, 'u_nk_sample_list',
- None)
+ """test if the convergence routine would use the unsampled data
+ when the data has not been subsampled."""
+ monkeypatch.setattr(workflow, "u_nk_sample_list", None)
workflow.check_convergence(10)
assert len(workflow.convergence) == 10
def test_dhdl_TI_noTI(self, workflow, monkeypatch):
- '''Test to plot the dhdl_TI when ti estimator is not there'''
+ """Test to plot the dhdl_TI when ti estimator is not there"""
no_TI = workflow.estimator
- no_TI.pop('TI')
- monkeypatch.setattr(workflow, 'estimator',
- no_TI)
+ no_TI.pop("TI")
+ monkeypatch.setattr(workflow, "estimator", no_TI)
with pytest.raises(ValueError):
- workflow.plot_ti_dhdl(dhdl_TI='dhdl_TI.pdf')
+ workflow.plot_ti_dhdl(dhdl_TI="dhdl_TI.pdf")
def test_noMBAR_for_plot_overlap_matrix(self, workflow, monkeypatch):
- monkeypatch.setattr(workflow, 'estimator', {})
+ monkeypatch.setattr(workflow, "estimator", {})
assert workflow.plot_overlap_matrix() is None
def test_no_u_nk_for_check_convergence(self, workflow, monkeypatch):
- monkeypatch.setattr(workflow, 'u_nk_list', None)
- monkeypatch.setattr(workflow, 'u_nk_sample_list', None)
+ monkeypatch.setattr(workflow, "u_nk_list", None)
+ monkeypatch.setattr(workflow, "u_nk_sample_list", None)
with pytest.raises(ValueError):
- workflow.check_convergence(10, estimator='MBAR')
+ workflow.check_convergence(10, estimator="MBAR")
def test_no_dHdl_for_check_convergence(self, workflow, monkeypatch):
- monkeypatch.setattr(workflow, 'dHdl_list', None)
- monkeypatch.setattr(workflow, 'dHdl_sample_list', None)
+ monkeypatch.setattr(workflow, "dHdl_list", None)
+ monkeypatch.setattr(workflow, "dHdl_sample_list", None)
with pytest.raises(ValueError):
- workflow.check_convergence(10, estimator='TI')
+ workflow.check_convergence(10, estimator="TI")
def test_no_update_units(self, workflow):
assert workflow.update_units() is None
def test_no_name_estimate(self, workflow):
with pytest.raises(ValueError):
- workflow.estimate('aaa')
+ workflow.estimate("aaa")
-class Test_automatic_benzene():
- '''Test the full automatic workflow for load_benzene from alchemtest.gmx for
- single stage transformation.'''
+class Test_automatic_benzene:
+ """Test the full automatic workflow for load_benzene from alchemtest.gmx for
+ single stage transformation."""
@staticmethod
- @pytest.fixture(scope='session')
+ @pytest.fixture(scope="session")
def workflow(tmp_path_factory):
outdir = tmp_path_factory.mktemp("out")
- dir = os.path.dirname(os.path.dirname(
- load_benzene()['data']['Coulomb'][0]))
- dir = os.path.join(dir, '*')
- workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir,
- prefix='dhdl', suffix='bz2', T=310,
- outdirectory=outdir)
- workflow.run(skiptime=0, uncorr='dE', threshold=50,
- estimators=('MBAR', 'BAR', 'TI'), overlap='O_MBAR.pdf',
- breakdown=True, forwrev=10)
+ dir = os.path.dirname(os.path.dirname(load_benzene()["data"]["Coulomb"][0]))
+ dir = os.path.join(dir, "*")
+ workflow = ABFE(
+ units="kcal/mol",
+ software="GROMACS",
+ dir=dir,
+ prefix="dhdl",
+ suffix="bz2",
+ T=310,
+ outdirectory=outdir,
+ )
+ workflow.run(
+ skiptime=0,
+ uncorr="dE",
+ threshold=50,
+ estimators=("MBAR", "BAR", "TI"),
+ overlap="O_MBAR.pdf",
+ breakdown=True,
+ forwrev=10,
+ )
return workflow
def test_read(self, workflow):
- '''test if the files has been loaded correctly.'''
+ """test if the files has been loaded correctly."""
assert len(workflow.u_nk_list) == 5
assert len(workflow.dHdl_list) == 5
assert all([len(u_nk) == 4001 for u_nk in workflow.u_nk_list])
assert all([len(dHdl) == 4001 for dHdl in workflow.dHdl_list])
def test_estimator(self, workflow):
- '''Test if all three estimators have been used.'''
+ """Test if all three estimators have been used."""
assert len(workflow.estimator) == 3
- assert 'MBAR' in workflow.estimator
- assert 'TI' in workflow.estimator
- assert 'BAR' in workflow.estimator
+ assert "MBAR" in workflow.estimator
+ assert "TI" in workflow.estimator
+ assert "BAR" in workflow.estimator
def test_O_MBAR(self, workflow):
- '''test if the O_MBAR.pdf has been plotted.'''
- assert os.path.isfile(os.path.join(workflow.out, 'O_MBAR.pdf'))
+ """test if the O_MBAR.pdf has been plotted."""
+ assert os.path.isfile(os.path.join(workflow.out, "O_MBAR.pdf"))
def test_dhdl_TI(self, workflow):
- '''test if the dhdl_TI.pdf has been plotted.'''
- assert os.path.isfile(os.path.join(workflow.out, 'dhdl_TI.pdf'))
+ """test if the dhdl_TI.pdf has been plotted."""
+ assert os.path.isfile(os.path.join(workflow.out, "dhdl_TI.pdf"))
def test_dF_state(self, workflow):
- '''test if the dF_state.pdf has been plotted.'''
- assert os.path.isfile(os.path.join(workflow.out, 'dF_state.pdf'))
+ """test if the dF_state.pdf has been plotted."""
+ assert os.path.isfile(os.path.join(workflow.out, "dF_state.pdf"))
def test_convergence(self, workflow):
- '''test if the dF_state.pdf has been plotted.'''
- assert os.path.isfile(os.path.join(workflow.out, 'dF_t.pdf'))
+ """test if the dF_state.pdf has been plotted."""
+ assert os.path.isfile(os.path.join(workflow.out, "dF_t.pdf"))
assert len(workflow.convergence) == 10
-class Test_unpertubed_lambda():
- '''Test the if two lamdas present and one of them is not pertubed.
+
+class Test_unpertubed_lambda:
+ """Test the if two lamdas present and one of them is not pertubed.
fep bound
time fep-lambda bound-lambda
@@ -209,87 +240,118 @@ class Test_unpertubed_lambda():
40.0 0.5 0 7.768072 0
Where only fep-lambda changes but the bonded-lambda is always 0.
- '''
+ """
@staticmethod
- @pytest.fixture(scope='session')
+ @pytest.fixture(scope="session")
def workflow(tmp_path_factory):
outdir = tmp_path_factory.mktemp("out")
- dir = os.path.dirname(os.path.dirname(
- load_benzene()['data']['Coulomb'][0]))
- dir = os.path.join(dir, '*')
- workflow = ABFE(software='GROMACS', dir=dir, prefix='dhdl',
- suffix='bz2', T=310, outdirectory=outdir)
+ dir = os.path.dirname(os.path.dirname(load_benzene()["data"]["Coulomb"][0]))
+ dir = os.path.join(dir, "*")
+ workflow = ABFE(
+ software="GROMACS",
+ dir=dir,
+ prefix="dhdl",
+ suffix="bz2",
+ T=310,
+ outdirectory=outdir,
+ )
workflow.read()
# Block the n_uk
workflow.u_nk_list = []
# Add another lambda column
for dHdl in workflow.dHdl_list:
- dHdl.insert(1, 'bound-lambda', [1.0, ] * len(dHdl))
- dHdl.insert(1, 'bound', [1.0, ] * len(dHdl))
- dHdl.set_index('bound-lambda', append=True, inplace=True)
-
- workflow.estimate(estimators=('TI', ))
- workflow.plot_ti_dhdl(dhdl_TI='dhdl_TI.pdf')
- workflow.plot_dF_state(dF_state='dF_state.pdf')
- workflow.check_convergence(10, dF_t='dF_t.pdf', estimator='TI')
+ dHdl.insert(
+ 1,
+ "bound-lambda",
+ [
+ 1.0,
+ ]
+ * len(dHdl),
+ )
+ dHdl.insert(
+ 1,
+ "bound",
+ [
+ 1.0,
+ ]
+ * len(dHdl),
+ )
+ dHdl.set_index("bound-lambda", append=True, inplace=True)
+
+ workflow.estimate(estimators=("TI",))
+ workflow.plot_ti_dhdl(dhdl_TI="dhdl_TI.pdf")
+ workflow.plot_dF_state(dF_state="dF_state.pdf")
+ workflow.check_convergence(10, dF_t="dF_t.pdf", estimator="TI")
return workflow
def test_dhdl_TI(self, workflow):
- '''test if the dhdl_TI.pdf has been plotted.'''
- assert os.path.isfile(os.path.join(workflow.out, 'dhdl_TI.pdf'))
+ """test if the dhdl_TI.pdf has been plotted."""
+ assert os.path.isfile(os.path.join(workflow.out, "dhdl_TI.pdf"))
def test_dF_state(self, workflow):
- '''test if the dF_state.pdf has been plotted.'''
- assert os.path.isfile(os.path.join(workflow.out, 'dF_state.pdf'))
+ """test if the dF_state.pdf has been plotted."""
+ assert os.path.isfile(os.path.join(workflow.out, "dF_state.pdf"))
def test_convergence(self, workflow):
- '''test if the dF_state.pdf has been plotted.'''
- assert os.path.isfile(os.path.join(workflow.out, 'dF_t.pdf'))
+ """test if the dF_state.pdf has been plotted."""
+ assert os.path.isfile(os.path.join(workflow.out, "dF_t.pdf"))
assert len(workflow.convergence) == 10
def test_single_estimator_ti(self, workflow):
- workflow.estimate(estimators='TI')
+ workflow.estimate(estimators="TI")
summary = workflow.generate_result()
- assert np.isclose(summary['TI']['Stages']['TOTAL'], 2.946, 0.1)
+ assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 2.946, 0.1)
+
-class Test_methods():
- '''Test various methods.'''
+class Test_methods:
+ """Test various methods."""
@staticmethod
- @pytest.fixture(scope='class')
+ @pytest.fixture(scope="class")
def workflow(tmp_path_factory):
outdir = tmp_path_factory.mktemp("out")
- dir = os.path.dirname(os.path.dirname(
- load_benzene()['data']['Coulomb'][0]))
- dir = os.path.join(dir, '*')
- workflow = ABFE(software='GROMACS', dir=dir, prefix='dhdl',
- suffix='bz2', T=310, outdirectory=outdir)
+ dir = os.path.dirname(os.path.dirname(load_benzene()["data"]["Coulomb"][0]))
+ dir = os.path.join(dir, "*")
+ workflow = ABFE(
+ software="GROMACS",
+ dir=dir,
+ prefix="dhdl",
+ suffix="bz2",
+ T=310,
+ outdirectory=outdir,
+ )
workflow.read()
return workflow
def test_run_none(self, workflow):
- '''Don't run anything'''
- workflow.run(uncorr=None, estimators=None, overlap=None, breakdown=None,
- forwrev=None)
+ """Don't run anything"""
+ workflow.run(
+ uncorr=None, estimators=None, overlap=None, breakdown=None, forwrev=None
+ )
def test_run_single_estimator(self, workflow, monkeypatch):
- monkeypatch.setattr(workflow, 'u_nk_list', [])
- monkeypatch.setattr(workflow, 'dHdl_list', [])
- workflow.run(uncorr=None, estimators='MBAR', overlap=None, breakdown=True,
- forwrev=None)
+ monkeypatch.setattr(workflow, "u_nk_list", [])
+ monkeypatch.setattr(workflow, "dHdl_list", [])
+ workflow.run(
+ uncorr=None, estimators="MBAR", overlap=None, breakdown=True, forwrev=None
+ )
def test_run_invalid_estimator(self, workflow):
- with pytest.raises(ValueError,
- match=r'Estimator aaa is not supported.'):
- workflow.run(uncorr=None, estimators='aaa', overlap=None, breakdown=None,
- forwrev=None)
-
- @pytest.mark.parametrize('read_u_nk', [True, False])
- @pytest.mark.parametrize('read_dHdl', [True, False])
+ with pytest.raises(ValueError, match=r"Estimator aaa is not supported."):
+ workflow.run(
+ uncorr=None,
+ estimators="aaa",
+ overlap=None,
+ breakdown=None,
+ forwrev=None,
+ )
+
+ @pytest.mark.parametrize("read_u_nk", [True, False])
+ @pytest.mark.parametrize("read_dHdl", [True, False])
def test_read_TI_FEP(self, workflow, monkeypatch, read_u_nk, read_dHdl):
- monkeypatch.setattr(workflow, 'u_nk_list', [])
- monkeypatch.setattr(workflow, 'dHdl_list', [])
+ monkeypatch.setattr(workflow, "u_nk_list", [])
+ monkeypatch.setattr(workflow, "dHdl_list", [])
workflow.read(read_u_nk, read_dHdl)
if read_u_nk:
assert len(workflow.u_nk_list) == 5
@@ -303,104 +365,112 @@ def test_read_TI_FEP(self, workflow, monkeypatch, read_u_nk, read_dHdl):
def test_read_invalid_u_nk(self, workflow, monkeypatch):
def extract_u_nk(self, T):
- raise IOError('Error read u_nk.')
- monkeypatch.setattr(workflow, '_extract_u_nk',
- extract_u_nk)
- with pytest.raises(OSError,
- match=r'Error reading u_nk .*dhdl\.xvg\.bz2'):
+ raise IOError("Error read u_nk.")
+
+ monkeypatch.setattr(workflow, "_extract_u_nk", extract_u_nk)
+ with pytest.raises(OSError, match=r"Error reading u_nk .*dhdl\.xvg\.bz2"):
workflow.read()
def test_read_invalid_dHdl(self, workflow, monkeypatch):
def extract_dHdl(self, T):
- raise IOError('Error read dHdl.')
- monkeypatch.setattr(workflow, '_extract_dHdl',
- extract_dHdl)
- with pytest.raises(OSError,
- match=r'Error reading dHdl .*dhdl\.xvg\.bz2'):
+ raise IOError("Error read dHdl.")
+
+ monkeypatch.setattr(workflow, "_extract_dHdl", extract_dHdl)
+ with pytest.raises(OSError, match=r"Error reading dHdl .*dhdl\.xvg\.bz2"):
workflow.read()
def test_uncorr_threshold(self, workflow, monkeypatch):
- '''Test if the full data will be used when the number of data points
- are less than the threshold.'''
- monkeypatch.setattr(workflow, 'u_nk_list',
- [u_nk[:40] for u_nk in workflow.u_nk_list])
- monkeypatch.setattr(workflow, 'dHdl_list',
- [dHdl[:40] for dHdl in workflow.dHdl_list])
+ """Test if the full data will be used when the number of data points
+ is less than the threshold."""
+ monkeypatch.setattr(
+ workflow, "u_nk_list", [u_nk[:40] for u_nk in workflow.u_nk_list]
+ )
+ monkeypatch.setattr(
+ workflow, "dHdl_list", [dHdl[:40] for dHdl in workflow.dHdl_list]
+ )
workflow.preprocess(threshold=50)
assert all([len(u_nk) == 40 for u_nk in workflow.u_nk_sample_list])
assert all([len(dHdl) == 40 for dHdl in workflow.dHdl_sample_list])
def test_no_u_nk_preprocess(self, workflow, monkeypatch):
- monkeypatch.setattr(workflow, 'u_nk_list', [])
+ monkeypatch.setattr(workflow, "u_nk_list", [])
workflow.preprocess(threshold=50)
assert len(workflow.u_nk_list) == 0
def test_no_dHdl_preprocess(self, workflow, monkeypatch):
- monkeypatch.setattr(workflow, 'dHdl_list', [])
+ monkeypatch.setattr(workflow, "dHdl_list", [])
workflow.preprocess(threshold=50)
assert len(workflow.dHdl_list) == 0
def test_single_estimator_mbar(self, workflow):
- workflow.estimate(estimators='MBAR')
+ workflow.estimate(estimators="MBAR")
assert len(workflow.estimator) == 1
- assert 'MBAR' in workflow.estimator
+ assert "MBAR" in workflow.estimator
summary = workflow.generate_result()
- assert np.isclose(summary['MBAR']['Stages']['TOTAL'], 2.946, 0.1)
+ assert np.isclose(summary["MBAR"]["Stages"]["TOTAL"], 2.946, 0.1)
def test_single_estimator_ti(self, workflow):
- workflow.estimate(estimators='TI')
+ workflow.estimate(estimators="TI")
summary = workflow.generate_result()
- assert np.isclose(summary['TI']['Stages']['TOTAL'], 2.946, 0.1)
+ assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 2.946, 0.1)
def test_bar_convergence(self, workflow):
- workflow.check_convergence(10, estimator='BAR')
+ workflow.check_convergence(10, estimator="BAR")
assert len(workflow.convergence) == 10
def test_convergence_invalid_estimator(self, workflow):
with pytest.raises(ValueError):
- workflow.check_convergence(10, estimator='aaa')
+ workflow.check_convergence(10, estimator="aaa")
def test_ti_convergence(self, workflow):
- workflow.check_convergence(10, estimator='TI')
+ workflow.check_convergence(10, estimator="TI")
assert len(workflow.convergence) == 10
def test_unprocessed_n_uk(self, workflow, monkeypatch):
- monkeypatch.setattr(workflow, 'u_nk_sample_list',
- None)
+ monkeypatch.setattr(workflow, "u_nk_sample_list", None)
workflow.estimate()
assert len(workflow.estimator) == 3
- assert 'MBAR' in workflow.estimator
+ assert "MBAR" in workflow.estimator
def test_unprocessed_dhdl(self, workflow, monkeypatch):
- monkeypatch.setattr(workflow, 'dHdl_sample_list',
- None)
- workflow.check_convergence(10, estimator='TI')
+ monkeypatch.setattr(workflow, "dHdl_sample_list", None)
+ workflow.check_convergence(10, estimator="TI")
assert len(workflow.convergence) == 10
-class Test_automatic_amber():
- '''Test the full automatic workflow for load_ABFE from alchemtest.amber for
- three stage transformation.'''
+
+class Test_automatic_amber:
+ """Test the full automatic workflow for load_ABFE from alchemtest.amber for
+ three stage transformation."""
@staticmethod
- @pytest.fixture(scope='session')
+ @pytest.fixture(scope="session")
def workflow(tmp_path_factory):
outdir = tmp_path_factory.mktemp("out")
dir, _ = os.path.split(
- os.path.dirname(load_bace_example()['data']['complex']['vdw'][0]))
-
- workflow = ABFE(units='kcal/mol', software='AMBER', dir=dir,
- prefix='ti', suffix='bz2', T=298.0, outdirectory=str(
- outdir))
+ os.path.dirname(load_bace_example()["data"]["complex"]["vdw"][0])
+ )
+
+ workflow = ABFE(
+ units="kcal/mol",
+ software="AMBER",
+ dir=dir,
+ prefix="ti",
+ suffix="bz2",
+ T=298.0,
+ outdirectory=str(outdir),
+ )
workflow.read()
- workflow.estimate(estimators='TI')
+ workflow.estimate(estimators="TI")
return workflow
def test_summary(self, workflow):
- '''Test if if the summary is right.'''
+ """Test if if the summary is right."""
summary = workflow.generate_result()
- assert np.isclose(summary['TI']['Stages']['TOTAL'], 1.40405980473, 0.1)
+ assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 1.40405980473, 0.1)
+
def test_no_parser():
with pytest.raises(NotImplementedError):
- workflow = ABFE(units='kcal/mol', software='aaa',
- prefix='ti', suffix='bz2', T=298.0)
+ workflow = ABFE(
+ units="kcal/mol", software="aaa", prefix="ti", suffix="bz2", T=298.0
+ )
diff --git a/src/alchemlyb/visualisation/__init__.py b/src/alchemlyb/visualisation/__init__.py
index d58b367e..6955dcaf 100644
--- a/src/alchemlyb/visualisation/__init__.py
+++ b/src/alchemlyb/visualisation/__init__.py
@@ -1,4 +1,4 @@
+from .convergence import plot_convergence
+from .dF_state import plot_dF_state
from .mbar_matrix import plot_mbar_overlap_matrix
from .ti_dhdl import plot_ti_dhdl
-from .dF_state import plot_dF_state
-from .convergence import plot_convergence
\ No newline at end of file
diff --git a/src/alchemlyb/visualisation/convergence.py b/src/alchemlyb/visualisation/convergence.py
index fcef3e50..c1cc477a 100644
--- a/src/alchemlyb/visualisation/convergence.py
+++ b/src/alchemlyb/visualisation/convergence.py
@@ -1,92 +1,91 @@
import matplotlib.pyplot as plt
-import pandas as pd
-from matplotlib.font_manager import FontProperties as FP
import numpy as np
+from matplotlib.font_manager import FontProperties as FP
from ..postprocessors.units import get_unit_converter
+
def plot_convergence(dataframe, units=None, final_error=None, ax=None):
"""Plot the forward and backward convergence.
- The input could be the result from
- :func:`~alchemlyb.convergence.forward_backward_convergence` or
- :func:`~alchemlyb.convergence.fwdrev_cumavg_Rc`. The input should be a
- :class:`pandas.DataFrame` which has column `Forward`, `Backward` and
- :attr:`pandas.DataFrame.attrs` should compile with :ref:`note-on-units`.
- The errorbar will be plotted if column `Forward_Error` and `Backward_Error`
- is present.
-
- `Forward`: A column of free energy estimate from the first X% of data,
- where optional `Forward_Error` column is the corresponding error.
-
- `Backward`: A column of free energy estimate from the last X% of data.,
- where optional `Backward_Error` column is the corresponding error.
-
- `final_error` is the error of the final value and is shown as the error band around the
- final value. It can be provided in case an estimate is available that is more appropriate
- than the default, which is the error of the last value in `Backward`.
-
- Parameters
- ----------
- dataframe : Dataframe
- Output Dataframe has column `Forward`, `Backward` or optionally
- `Forward_Error`, `Backward_Error` see :ref:`plot_convergence `.
- units : str
- The unit of the estimate. The default is `None`, which is to use the
- unit in the input. Setting this will change the output unit.
- final_error : float
- The error of the final value in ``units``. If not given, takes the last
- error in `backward_error`.
- ax : matplotlib.axes.Axes
- Matplotlib axes object where the plot will be drawn on. If ``ax=None``,
- a new axes will be generated.
-
- Returns
- -------
- matplotlib.axes.Axes
- An axes with the forward and backward convergence drawn.
-
- Note
- ----
- The code is taken and modified from
- `Alchemical Analysis `_.
-
-
- .. versionchanged:: 1.0.0
- Keyword arg final_error for plotting a horizontal error bar.
- The array input has been deprecated.
- The units default to `None` which uses the units in the input.
-
- .. versionchanged:: 0.6.0
- data now takes in dataframe
-
- .. versionadded:: 0.4.0
+ The input could be the result from
+ :func:`~alchemlyb.convergence.forward_backward_convergence` or
+ :func:`~alchemlyb.convergence.fwdrev_cumavg_Rc`. The input should be a
+ :class:`pandas.DataFrame` which has columns `Forward` and `Backward`, and whose
+ :attr:`pandas.DataFrame.attrs` should comply with :ref:`note-on-units`.
+ The error bars will be plotted if the columns `Forward_Error` and `Backward_Error`
+ are present.
+
+ `Forward`: A column of free energy estimate from the first X% of data,
+ where optional `Forward_Error` column is the corresponding error.
+
+ `Backward`: A column of free energy estimate from the last X% of data,
+ where optional `Backward_Error` column is the corresponding error.
+
+ `final_error` is the error of the final value and is shown as the error band around the
+ final value. It can be provided in case an estimate is available that is more appropriate
+ than the default, which is the error of the last value in `Backward`.
+
+ Parameters
+ ----------
+ dataframe : Dataframe
+ Output Dataframe has column `Forward`, `Backward` or optionally
+ `Forward_Error`, `Backward_Error` see :ref:`plot_convergence `.
+ units : str
+ The unit of the estimate. The default is `None`, which is to use the
+ unit in the input. Setting this will change the output unit.
+ final_error : float
+ The error of the final value in ``units``. If not given, takes the last
+ error in `backward_error`.
+ ax : matplotlib.axes.Axes
+ Matplotlib axes object where the plot will be drawn on. If ``ax=None``,
+ a new axes will be generated.
+
+ Returns
+ -------
+ matplotlib.axes.Axes
+ An axes with the forward and backward convergence drawn.
+
+ Note
+ ----
+ The code is taken and modified from
+ `Alchemical Analysis `_.
+
+
+ .. versionchanged:: 1.0.0
+ Keyword arg final_error for plotting a horizontal error bar.
+ The array input has been deprecated.
+ The units default to `None` which uses the units in the input.
+
+ .. versionchanged:: 0.6.0
+ data now takes in dataframe
+
+ .. versionadded:: 0.4.0
"""
if units is not None:
dataframe = get_unit_converter(units)(dataframe)
- forward = dataframe['Forward'].to_numpy()
- if 'Forward_Error' in dataframe:
- forward_error = dataframe['Forward_Error'].to_numpy()
+ forward = dataframe["Forward"].to_numpy()
+ if "Forward_Error" in dataframe:
+ forward_error = dataframe["Forward_Error"].to_numpy()
else:
forward_error = np.zeros(len(forward))
- backward = dataframe['Backward'].to_numpy()
- if 'Backward_Error' in dataframe:
- backward_error = dataframe['Backward_Error'].to_numpy()
+ backward = dataframe["Backward"].to_numpy()
+ if "Backward_Error" in dataframe:
+ backward_error = dataframe["Backward_Error"].to_numpy()
else:
backward_error = np.zeros(len(backward))
-
- if ax is None: # pragma: no cover
+ if ax is None: # pragma: no cover
fig, ax = plt.subplots(figsize=(8, 6))
- plt.setp(ax.spines['bottom'], color='#D2B9D3', lw=3, zorder=-2)
- plt.setp(ax.spines['left'], color='#D2B9D3', lw=3, zorder=-2)
+ plt.setp(ax.spines["bottom"], color="#D2B9D3", lw=3, zorder=-2)
+ plt.setp(ax.spines["left"], color="#D2B9D3", lw=3, zorder=-2)
- for dire in ['top', 'right']:
- ax.spines[dire].set_color('none')
+ for dire in ["top", "right"]:
+ ax.spines[dire].set_color("none")
- ax.xaxis.set_ticks_position('bottom')
- ax.yaxis.set_ticks_position('left')
+ ax.xaxis.set_ticks_position("bottom")
+ ax.yaxis.set_ticks_position("left")
f_ts = np.linspace(0, 1, len(forward) + 1)[1:]
r_ts = np.linspace(0, 1, len(backward) + 1)[1:]
@@ -94,28 +93,54 @@ def plot_convergence(dataframe, units=None, final_error=None, ax=None):
if final_error is None:
final_error = backward_error[-1]
- line0 = ax.fill_between([0, 1], backward[-1] - final_error,
- backward[-1] + final_error, color='#D2B9D3',
- zorder=1)
- line1 = ax.errorbar(f_ts, forward, yerr=forward_error, color='#736AFF',
- lw=3, zorder=2, marker='o',
- mfc='w', mew=2.5, mec='#736AFF', ms=12,)
- line2 = ax.errorbar(r_ts, backward, yerr=backward_error, color='#C11B17',
- lw=3, zorder=3, marker='o',
- mfc='w', mew=2.5, mec='#C11B17', ms=12, )
+ line0 = ax.fill_between(
+ [0, 1],
+ backward[-1] - final_error,
+ backward[-1] + final_error,
+ color="#D2B9D3",
+ zorder=1,
+ )
+ line1 = ax.errorbar(
+ f_ts,
+ forward,
+ yerr=forward_error,
+ color="#736AFF",
+ lw=3,
+ zorder=2,
+ marker="o",
+ mfc="w",
+ mew=2.5,
+ mec="#736AFF",
+ ms=12,
+ )
+ line2 = ax.errorbar(
+ r_ts,
+ backward,
+ yerr=backward_error,
+ color="#C11B17",
+ lw=3,
+ zorder=3,
+ marker="o",
+ mfc="w",
+ mew=2.5,
+ mec="#C11B17",
+ ms=12,
+ )
xticks_spacing = len(r_ts) // 10 or 1
xticks = r_ts[::xticks_spacing]
- plt.xticks(xticks, ['%.2f' % i for i in xticks], fontsize=10)
+ plt.xticks(xticks, ["%.2f" % i for i in xticks], fontsize=10)
plt.yticks(fontsize=10)
- ax.legend((line1[0], line2[0]), ('Forward', 'Reverse'), loc=9,
- prop=FP(size=18), frameon=False)
- ax.set_xlabel(r'Fraction of the simulation time', fontsize=16,
- color='#151B54')
- ax.set_ylabel(r'$\Delta G$ ({})'.format(units), fontsize=16, color='#151B54')
- plt.tick_params(axis='x', color='#D2B9D3')
- plt.tick_params(axis='y', color='#D2B9D3')
+ ax.legend(
+ (line1[0], line2[0]),
+ ("Forward", "Reverse"),
+ loc=9,
+ prop=FP(size=18),
+ frameon=False,
+ )
+ ax.set_xlabel(r"Fraction of the simulation time", fontsize=16, color="#151B54")
+ ax.set_ylabel(r"$\Delta G$ ({})".format(units), fontsize=16, color="#151B54")
+ plt.tick_params(axis="x", color="#D2B9D3")
+ plt.tick_params(axis="y", color="#D2B9D3")
return ax
-
-
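A small, self-contained sketch of how this function is typically fed, with column names taken from the docstring above; the attrs keys follow the alchemlyb units convention and are an assumption here, they are not part of this diff:

    import pandas as pd
    from alchemlyb.visualisation import plot_convergence

    df = pd.DataFrame({
        "Forward": [1.20, 1.05, 1.02, 1.01, 1.00],
        "Forward_Error": [0.20, 0.10, 0.07, 0.05, 0.04],
        "Backward": [0.90, 0.98, 1.00, 1.00, 1.00],
        "Backward_Error": [0.20, 0.10, 0.07, 0.05, 0.04],
    })
    # Assumed attrs keys ("temperature", "energy_unit"); "energy_unit" appears in this diff.
    df.attrs = {"temperature": 300, "energy_unit": "kT"}
    ax = plot_convergence(df, units="kT")
    ax.figure.savefig("dF_t.pdf")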
diff --git a/src/alchemlyb/visualisation/dF_state.py b/src/alchemlyb/visualisation/dF_state.py
index 8f5a1409..e36fbc21 100644
--- a/src/alchemlyb/visualisation/dF_state.py
+++ b/src/alchemlyb/visualisation/dF_state.py
@@ -9,15 +9,17 @@
"""
import matplotlib.pyplot as plt
-from matplotlib.font_manager import FontProperties as FP
import numpy as np
+from matplotlib.font_manager import FontProperties as FP
from ..estimators import TI, BAR, MBAR
from ..postprocessors.units import get_unit_converter
-def plot_dF_state(estimators, labels=None, colors=None, units=None,
- orientation='portrait', nb=10):
- '''Plot the dhdl of TI.
+
+def plot_dF_state(
+ estimators, labels=None, colors=None, units=None, orientation="portrait", nb=10
+):
+ """Plot the dhdl of TI.
Parameters
----------
@@ -57,11 +59,13 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None,
changing the figure legend.
.. versionadded:: 0.4.0
- '''
+ """
try:
len(estimators)
except TypeError:
- estimators = [estimators, ]
+ estimators = [
+ estimators,
+ ]
formatted_data = []
for dhdl in estimators:
@@ -69,10 +73,14 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None,
len(dhdl)
formatted_data.append(dhdl)
except TypeError:
- formatted_data.append([dhdl, ])
+ formatted_data.append(
+ [
+ dhdl,
+ ]
+ )
if units is None:
- units = formatted_data[0][0].delta_f_.attrs['energy_unit']
+ units = formatted_data[0][0].delta_f_.attrs["energy_unit"]
estimators = formatted_data
@@ -96,47 +104,69 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None,
error_list.append(error)
# Determine the orientation
- if orientation == 'landscape':
+ if orientation == "landscape":
if max_length < 8:
fig, ax = plt.subplots(figsize=(8, 6))
else:
fig, ax = plt.subplots(figsize=(max_length, 6))
- axs = [ax, ]
- xs = [np.arange(max_length), ]
- elif orientation == 'portrait':
+ axs = [
+ ax,
+ ]
+ xs = [
+ np.arange(max_length),
+ ]
+ elif orientation == "portrait":
if max_length < nb:
- xs = [np.arange(max_length), ]
+ xs = [
+ np.arange(max_length),
+ ]
fig, ax = plt.subplots(figsize=(8, 6))
- axs = [ax, ]
+ axs = [
+ ax,
+ ]
else:
xs = np.array_split(np.arange(max_length), max_length / nb + 1)
fig, axs = plt.subplots(nrows=len(xs), figsize=(8, 6))
mnb = max([len(i) for i in xs])
else:
- raise ValueError("Not recognising {}, only supports 'landscape' or 'portrait'.".format(orientation))
+ raise ValueError(
+ "Not recognising {}, only supports 'landscape' or 'portrait'.".format(
+ orientation
+ )
+ )
# Sort out the colors
if colors is None:
- colors_dict = {'TI': '#C45AEC', 'TI-CUBIC': '#33CC33',
- 'DEXP': '#F87431', 'IEXP': '#FF3030', 'GINS': '#EAC117',
- 'GDEL': '#347235', 'BAR': '#6698FF', 'UBAR': '#817339',
- 'RBAR': '#C11B17', 'MBAR': '#F9B7FF'}
+ colors_dict = {
+ "TI": "#C45AEC",
+ "TI-CUBIC": "#33CC33",
+ "DEXP": "#F87431",
+ "IEXP": "#FF3030",
+ "GINS": "#EAC117",
+ "GDEL": "#347235",
+ "BAR": "#6698FF",
+ "UBAR": "#817339",
+ "RBAR": "#C11B17",
+ "MBAR": "#F9B7FF",
+ }
colors = []
for dhdl in estimators:
dhdl = dhdl[0]
if isinstance(dhdl, TI):
- colors.append(colors_dict['TI'])
+ colors.append(colors_dict["TI"])
elif isinstance(dhdl, BAR):
- colors.append(colors_dict['BAR'])
+ colors.append(colors_dict["BAR"])
elif isinstance(dhdl, MBAR):
- colors.append(colors_dict['MBAR'])
+ colors.append(colors_dict["MBAR"])
else:
if len(colors) >= len(estimators):
pass
else:
raise ValueError(
- 'Number of colors ({}) should be larger than the number of data ({})'.format(
- len(colors), len(estimators)))
+ "Number of colors ({}) should be larger than the number of data ({})".format(
+ len(colors), len(estimators)
+ )
+ )
# Sort out the labels
if labels is None:
@@ -144,21 +174,23 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None,
for dhdl in estimators:
dhdl = dhdl[0]
if isinstance(dhdl, TI):
- labels.append('TI')
+ labels.append("TI")
elif isinstance(dhdl, BAR):
- labels.append('BAR')
+ labels.append("BAR")
elif isinstance(dhdl, MBAR):
- labels.append('MBAR')
+ labels.append("MBAR")
else:
if len(labels) == len(estimators):
pass
else:
raise ValueError(
- 'Length of labels ({}) should be the same as the number of data ({})'.format(
- len(labels), len(estimators)))
+ "Length of labels ({}) should be the same as the number of data ({})".format(
+ len(labels), len(estimators)
+ )
+ )
# Plot the figure
- width = 1. / (len(estimators) + 1)
+ width = 1.0 / (len(estimators) + 1)
elw = 30 * width
ndx = 1
for x, ax in zip(xs, axs):
@@ -166,35 +198,49 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None,
for i, (dF, error) in enumerate(zip(dF_list, error_list)):
y = [dF[j] for j in x]
ye = [error[j] for j in x]
- if orientation == 'landscape':
+ if orientation == "landscape":
lw = 0.1 * elw
- elif orientation == 'portrait':
+ elif orientation == "portrait":
lw = 0.05 * elw
- line = ax.bar(x + len(lines) * width, y, width,
- color=colors[i], yerr=ye, lw=lw,
- error_kw=dict(elinewidth=elw, ecolor='black',
- capsize=0.5 * elw))
+ line = ax.bar(
+ x + len(lines) * width,
+ y,
+ width,
+ color=colors[i],
+ yerr=ye,
+ lw=lw,
+ error_kw=dict(elinewidth=elw, ecolor="black", capsize=0.5 * elw),
+ )
lines += (line[0],)
- for dir in ['left', 'right', 'top', 'bottom']:
- if dir == 'left':
+ for dir in ["left", "right", "top", "bottom"]:
+ if dir == "left":
ax.yaxis.set_ticks_position(dir)
else:
- ax.spines[dir].set_color('none')
+ ax.spines[dir].set_color("none")
- if orientation == 'landscape':
+ if orientation == "landscape":
plt.yticks(fontsize=8)
- ax.set_xlim(x[0]-width, x[-1] + len(lines) * width)
- plt.xticks(x + 0.5 * width * len(estimators),
- tuple(['%d--%d' % (i, i + 1) for i in x]), fontsize=8)
- elif orientation == 'portrait':
+ ax.set_xlim(x[0] - width, x[-1] + len(lines) * width)
+ plt.xticks(
+ x + 0.5 * width * len(estimators),
+ tuple(["%d--%d" % (i, i + 1) for i in x]),
+ fontsize=8,
+ )
+ elif orientation == "portrait":
plt.yticks(fontsize=10)
ax.xaxis.set_ticks([])
for i in x + 0.5 * width * len(estimators):
- ax.annotate(r'$\mathrm{%d-%d}$' % (i, i + 1), xy=(i, 0),
- xycoords=('data', 'axes fraction'), xytext=(0, -2),
- size=10, textcoords='offset points', va='top',
- ha='center')
- ax.set_xlim(x[0]-width, x[-1]+len(lines)*width + (mnb - len(x)))
+ ax.annotate(
+ r"$\mathrm{%d-%d}$" % (i, i + 1),
+ xy=(i, 0),
+ xycoords=("data", "axes fraction"),
+ xytext=(0, -2),
+ size=10,
+ textcoords="offset points",
+ va="top",
+ ha="center",
+ )
+ ax.set_xlim(x[0] - width, x[-1] + len(lines) * width + (mnb - len(x)))
ndx += 1
x = np.arange(max_length)
@@ -202,18 +248,21 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None,
for tick in ax.get_xticklines():
tick.set_visible(False)
- if orientation == 'landscape':
- leg = plt.legend(lines, labels, loc=3, ncol=2, prop=FP(size=10),
- fancybox=True)
- plt.title('The free energy change breakdown', fontsize=12)
- plt.xlabel('States', fontsize=12, color='#151B54')
- plt.ylabel(r'$\Delta G$ ({})'.format(units), fontsize=12, color='#151B54')
- elif orientation == 'portrait':
- leg = ax.legend(lines, labels, loc=0, ncol=2,
- prop=FP(size=8),
- title=r'$\Delta G$ ({})'.format(units) +
- r'$\mathit{vs.}$ lambda pair',
- fancybox=True)
+ if orientation == "landscape":
+ leg = plt.legend(lines, labels, loc=3, ncol=2, prop=FP(size=10), fancybox=True)
+ plt.title("The free energy change breakdown", fontsize=12)
+ plt.xlabel("States", fontsize=12, color="#151B54")
+ plt.ylabel(r"$\Delta G$ ({})".format(units), fontsize=12, color="#151B54")
+ elif orientation == "portrait":
+ leg = ax.legend(
+ lines,
+ labels,
+ loc=0,
+ ncol=2,
+ prop=FP(size=8),
+ title=r"$\Delta G$ ({})".format(units) + r"$\mathit{vs.}$ lambda pair",
+ fancybox=True,
+ )
leg.get_frame().set_alpha(0.5)
return fig
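For orientation, a hedged usage sketch of plot_dF_state; the parser and estimator imports belong to the wider library and are not part of this diff:

    import alchemlyb
    from alchemtest.gmx import load_benzene
    from alchemlyb.parsing.gmx import extract_u_nk
    from alchemlyb.estimators import BAR, MBAR
    from alchemlyb.visualisation import plot_dF_state

    u_nk = alchemlyb.concat(
        [extract_u_nk(xvg, T=300) for xvg in load_benzene()["data"]["Coulomb"]]
    )
    estimators = [BAR().fit(u_nk), MBAR().fit(u_nk)]
    # One group of bars per adjacent state pair, one colour per estimator.
    fig = plot_dF_state(estimators, orientation="portrait", nb=10)
    fig.savefig("dF_state.pdf")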
diff --git a/src/alchemlyb/visualisation/mbar_matrix.py b/src/alchemlyb/visualisation/mbar_matrix.py
index 4b2bd952..6bdc068e 100644
--- a/src/alchemlyb/visualisation/mbar_matrix.py
+++ b/src/alchemlyb/visualisation/mbar_matrix.py
@@ -13,8 +13,9 @@
import matplotlib.pyplot as plt
import numpy as np
+
def plot_mbar_overlap_matrix(matrix, skip_lambda_index=[], ax=None):
- '''Plot the MBAR overlap matrix.
+ """Plot the MBAR overlap matrix.
Parameters
----------
@@ -41,7 +42,7 @@ def plot_mbar_overlap_matrix(matrix, skip_lambda_index=[], ax=None):
.. versionadded:: 0.4.0
- '''
+ """
# Compute the size of the figure, if ax is not given.
max_prob = matrix.max()
size = len(matrix)
@@ -49,25 +50,36 @@ def plot_mbar_overlap_matrix(matrix, skip_lambda_index=[], ax=None):
fig, ax = plt.subplots(figsize=(size / 2, size / 2))
ax.set_xticks([])
ax.set_yticks([])
- ax.axis('off')
+ ax.axis("off")
for i in range(size):
if i != 0:
- ax.axvline(x=i, ls='-', lw=0.5, color='k', alpha=0.25)
- ax.axhline(y=i, ls='-', lw=0.5, color='k', alpha=0.25)
+ ax.axvline(x=i, ls="-", lw=0.5, color="k", alpha=0.25)
+ ax.axhline(y=i, ls="-", lw=0.5, color="k", alpha=0.25)
for j in range(size):
if matrix[j, i] < 0.005:
- ii = ''
+ ii = ""
elif matrix[j, i] > 0.995:
- ii = '1.00'
+ ii = "1.00"
else:
- ii = ("{:.2f}".format(matrix[j, i])[1:])
+ ii = "{:.2f}".format(matrix[j, i])[1:]
alf = matrix[j, i] / max_prob
- ax.fill_between([i, i + 1], [size - j, size - j],
- [size - (j + 1), size - (j + 1)], color='k',
- alpha=alf)
- ax.annotate(ii, xy=(i, j), xytext=(i + 0.5, size - (j + 0.5)),
- size=8, textcoords='data', va='center',
- ha='center', color=('k' if alf < 0.5 else 'w'))
+ ax.fill_between(
+ [i, i + 1],
+ [size - j, size - j],
+ [size - (j + 1), size - (j + 1)],
+ color="k",
+ alpha=alf,
+ )
+ ax.annotate(
+ ii,
+ xy=(i, j),
+ xytext=(i + 0.5, size - (j + 0.5)),
+ size=8,
+ textcoords="data",
+ va="center",
+ ha="center",
+ color=("k" if alf < 0.5 else "w"),
+ )
if skip_lambda_index:
ks = [int(l) for l in skip_lambda_index]
@@ -75,31 +87,48 @@ def plot_mbar_overlap_matrix(matrix, skip_lambda_index=[], ax=None):
else:
ks = range(size)
for i in range(size):
- ax.annotate(ks[i], xy=(i + 0.5, 1), xytext=(i + 0.5, size + 0.5),
- size=10, textcoords=('data', 'data'),
- va='center', ha='center', color='k')
- ax.annotate(ks[i], xy=(-0.5, size - (size - 0.5)),
- xytext=(-0.5, size - (i + 0.5)),
- size=10, textcoords=('data', 'data'),
- va='center', ha='center', color='k')
- ax.annotate(r'$\lambda$', xy=(-0.5, size - (size - 0.5)),
- xytext=(-0.5, size + 0.5),
- size=10, textcoords=('data', 'data'),
- va='center', ha='center', color='k')
- ax.plot([0, size], [0, 0], 'k-', lw=4.0, solid_capstyle='butt')
- ax.plot([size, size], [0, size], 'k-', lw=4.0, solid_capstyle='butt')
- ax.plot([0, 0], [0, size], 'k-', lw=2.0, solid_capstyle='butt')
- ax.plot([0, size], [size, size], 'k-', lw=2.0, solid_capstyle='butt')
+ ax.annotate(
+ ks[i],
+ xy=(i + 0.5, 1),
+ xytext=(i + 0.5, size + 0.5),
+ size=10,
+ textcoords=("data", "data"),
+ va="center",
+ ha="center",
+ color="k",
+ )
+ ax.annotate(
+ ks[i],
+ xy=(-0.5, size - (size - 0.5)),
+ xytext=(-0.5, size - (i + 0.5)),
+ size=10,
+ textcoords=("data", "data"),
+ va="center",
+ ha="center",
+ color="k",
+ )
+ ax.annotate(
+ r"$\lambda$",
+ xy=(-0.5, size - (size - 0.5)),
+ xytext=(-0.5, size + 0.5),
+ size=10,
+ textcoords=("data", "data"),
+ va="center",
+ ha="center",
+ color="k",
+ )
+ ax.plot([0, size], [0, 0], "k-", lw=4.0, solid_capstyle="butt")
+ ax.plot([size, size], [0, size], "k-", lw=4.0, solid_capstyle="butt")
+ ax.plot([0, 0], [0, size], "k-", lw=2.0, solid_capstyle="butt")
+ ax.plot([0, size], [size, size], "k-", lw=2.0, solid_capstyle="butt")
cx = np.repeat(range(size + 1), 2)
cy = sorted(np.repeat(range(size + 1), 2), reverse=True)
- ax.plot(cx[2:-1], cy[1:-2], 'k-', lw=2.0)
- ax.plot(np.array(cx[2:-3]) + 1, cy[1:-4], 'k-', lw=2.0)
- ax.plot(cx[1:-2], np.array(cy[:-3]) - 1, 'k-', lw=2.0)
- ax.plot(cx[1:-4], np.array(cy[:-5]) - 2, 'k-', lw=2.0)
+ ax.plot(cx[2:-1], cy[1:-2], "k-", lw=2.0)
+ ax.plot(np.array(cx[2:-3]) + 1, cy[1:-4], "k-", lw=2.0)
+ ax.plot(cx[1:-2], np.array(cy[:-3]) - 1, "k-", lw=2.0)
+ ax.plot(cx[1:-4], np.array(cy[:-5]) - 2, "k-", lw=2.0)
ax.set_xlim(-1, size)
ax.set_ylim(0, size + 1)
return ax
-
-
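
``plot_mbar_overlap_matrix`` is likewise only re-wrapped by black. A hedged sketch of how it is typically called, assuming ``mbar`` is an MBAR estimator fitted on u_nk data::

    from alchemlyb.visualisation import plot_mbar_overlap_matrix

    # overlap_matrix is the state-overlap probability matrix of the fitted estimator
    ax = plot_mbar_overlap_matrix(mbar.overlap_matrix)
    ax.figure.savefig("O_MBAR.pdf", bbox_inches="tight", pad_inches=0.0)
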
diff --git a/src/alchemlyb/visualisation/ti_dhdl.py b/src/alchemlyb/visualisation/ti_dhdl.py
index c071a97d..6dacb6dc 100644
--- a/src/alchemlyb/visualisation/ti_dhdl.py
+++ b/src/alchemlyb/visualisation/ti_dhdl.py
@@ -10,14 +10,14 @@
"""
import matplotlib.pyplot as plt
-from matplotlib.font_manager import FontProperties as FP
import numpy as np
+from matplotlib.font_manager import FontProperties as FP
from ..postprocessors.units import get_unit_converter
-def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None,
- ax=None):
- '''Plot the dhdl of TI.
+
+def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, ax=None):
+ """Plot the dhdl of TI.
Parameters
----------
@@ -55,7 +55,7 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None,
changing the figure legend.
.. versionadded:: 0.4.0
- '''
+ """
# Make it into a list
# separate_dhdl method is used so that the input for the actual plotting
# Function are a uniformed list of series object which only contains one
@@ -69,7 +69,7 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None,
# Convert unit
if units is None:
- units = dhdl_list[0].attrs['energy_unit']
+ units = dhdl_list[0].attrs["energy_unit"]
new_unit = []
convert = get_unit_converter(units)
@@ -80,11 +80,11 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None,
if ax is None:
fig, ax = plt.subplots(figsize=(8, 6))
- ax.spines['bottom'].set_position('zero')
- ax.spines['top'].set_color('none')
- ax.spines['right'].set_color('none')
- ax.xaxis.set_ticks_position('bottom')
- ax.yaxis.set_ticks_position('left')
+ ax.spines["bottom"].set_position("zero")
+ ax.spines["top"].set_color("none")
+ ax.spines["right"].set_color("none")
+ ax.xaxis.set_ticks_position("bottom")
+ ax.yaxis.set_ticks_position("left")
for k, spine in ax.spines.items():
spine.set_zorder(12.2)
@@ -98,20 +98,24 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None,
else:
if len(labels) == len(dhdl_list):
lv_names2 = labels
- else: # pragma: no cover
+ else: # pragma: no cover
raise ValueError(
- 'Length of labels ({}) should be the same as the number of data ({})'.format(
- len(labels), len(dhdl_list)))
+ "Length of labels ({}) should be the same as the number of data ({})".format(
+ len(labels), len(dhdl_list)
+ )
+ )
if colors is None:
- colors = ['r', 'g', '#7F38EC', '#9F000F', 'b', 'y']
+ colors = ["r", "g", "#7F38EC", "#9F000F", "b", "y"]
else:
if len(colors) >= len(dhdl_list):
pass
- else: # pragma: no cover
+ else: # pragma: no cover
raise ValueError(
- 'Number of colors ({}) should be larger than the number of data ({})'.format(
- len(labels), len(dhdl_list)))
+ "Number of colors ({}) should be larger than the number of data ({})".format(
+                     len(colors), len(dhdl_list)
+ )
+ )
# Get the real data out
xs, ndx, dx = [0], 0, 0.001
@@ -125,16 +129,22 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None,
for i in range(len(x) - 1):
if i % 2 == 0:
- ax.fill_between(x[i:i + 2] + ndx, 0, y[i:i + 2],
- color=colors[ndx], alpha=1.0)
+ ax.fill_between(
+ x[i : i + 2] + ndx, 0, y[i : i + 2], color=colors[ndx], alpha=1.0
+ )
else:
- ax.fill_between(x[i:i + 2] + ndx, 0, y[i:i + 2],
- color=colors[ndx], alpha=0.5)
+ ax.fill_between(
+ x[i : i + 2] + ndx, 0, y[i : i + 2], color=colors[ndx], alpha=0.5
+ )
xlegend = [-100 * wnum for wnum in range(len(lv_names2))]
- ax.plot(xlegend, [0 * wnum for wnum in xlegend], ls='-',
- color=colors[ndx],
- label=lv_names2[ndx])
+ ax.plot(
+ xlegend,
+ [0 * wnum for wnum in xlegend],
+ ls="-",
+ color=colors[ndx],
+ label=lv_names2[ndx],
+ )
xs += (x + ndx).tolist()[1:]
ndx += 1
@@ -159,7 +169,7 @@ def getInd(r=ri, z=[0]):
if i in getInd():
xt.append(i)
else:
- xt.append('')
+ xt.append("")
plt.xticks(xs[1:], xt[1:], fontsize=10)
ax.yaxis.label.set_size(10)
@@ -172,31 +182,46 @@ def getInd(r=ri, z=[0]):
max_y *= 1.01
# Modified so that the x label won't conflict with the lambda label
- min_y -= (max_y-min_y)*0.1
+ min_y -= (max_y - min_y) * 0.1
ax.set_ylim(min_y, max_y)
for i, j in zip(xs[1:], xt[1:]):
ax.annotate(
- ('%.2f' % (i - 1.0 if i > 1.0 else i) if not j == '' else ''),
- xy=(i, 0), size=10, rotation=90, va='bottom', ha='center',
- color='#151B54')
+ ("%.2f" % (i - 1.0 if i > 1.0 else i) if not j == "" else ""),
+ xy=(i, 0),
+ size=10,
+ rotation=90,
+ va="bottom",
+ ha="center",
+ color="#151B54",
+ )
if ndx > 1:
lenticks = len(ax.get_ymajorticklabels()) - 1
- if min_y < 0: lenticks -= 1
+ if min_y < 0:
+ lenticks -= 1
if lenticks < 5: # pragma: no cover
from matplotlib.ticker import AutoMinorLocator as AML
+
ax.yaxis.set_minor_locator(AML())
- ax.grid(which='both', color='w', lw=0.25, axis='y', zorder=12)
+ ax.grid(which="both", color="w", lw=0.25, axis="y", zorder=12)
ax.set_ylabel(
- r'$\langle{\frac{\partial U}{\partial\lambda}}\rangle_{\lambda}$' +
- '({})'.format(units),
- fontsize=20, color='#151B54')
- ax.annotate(r'$\mathit{\lambda}$', xy=(0, 0), xytext=(0.5, -0.05), size=18,
- textcoords='axes fraction', va='top', ha='center',
- color='#151B54')
+ r"$\langle{\frac{\partial U}{\partial\lambda}}\rangle_{\lambda}$"
+ + "({})".format(units),
+ fontsize=20,
+ color="#151B54",
+ )
+ ax.annotate(
+ r"$\mathit{\lambda}$",
+ xy=(0, 0),
+ xytext=(0.5, -0.05),
+ size=18,
+ textcoords="axes fraction",
+ va="top",
+ ha="center",
+ color="#151B54",
+ )
lege = ax.legend(prop=FP(size=14), frameon=False, loc=1)
for l in lege.legendHandles:
l.set_linewidth(10)
return ax
-
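
The behaviour of ``plot_ti_dhdl`` is also unchanged; it accepts a single fitted TI estimator or a list of them. A minimal sketch, assuming ``ti`` was fitted on concatenated dHdl data::

    from alchemlyb.visualisation import plot_ti_dhdl

    ax = plot_ti_dhdl(ti, units="kT")   # or plot_ti_dhdl([ti_coul, ti_vdw], ...)
    ax.figure.savefig("dhdl_TI.pdf")
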
diff --git a/src/alchemlyb/workflows/__init__.py b/src/alchemlyb/workflows/__init__.py
index 6b35d460..a6a156cf 100644
--- a/src/alchemlyb/workflows/__init__.py
+++ b/src/alchemlyb/workflows/__init__.py
@@ -1,4 +1,5 @@
__all__ = [
- 'base',
+ "base",
]
+
from .abfe import ABFE
diff --git a/src/alchemlyb/workflows/abfe.py b/src/alchemlyb/workflows/abfe.py
index 32e51a7c..9fef25c1 100644
--- a/src/alchemlyb/workflows/abfe.py
+++ b/src/alchemlyb/workflows/abfe.py
@@ -1,26 +1,31 @@
+import logging
import os
-from os.path import join
from glob import glob
-import pandas as pd
-import numpy as np
-import logging
+from os.path import join
+
import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
from .base import WorkflowBase
-from ..parsing import gmx, amber
-from ..preprocessing.subsampling import decorrelate_dhdl, decorrelate_u_nk
-from ..estimators import BAR, TI, FEP_ESTIMATORS, TI_ESTIMATORS
+from .. import __version__
+from .. import concat
+from ..convergence import forward_backward_convergence
from ..estimators import AutoMBAR as MBAR
-from ..visualisation import (plot_mbar_overlap_matrix, plot_ti_dhdl,
- plot_dF_state, plot_convergence)
+from ..estimators import BAR, TI, FEP_ESTIMATORS, TI_ESTIMATORS
+from ..parsing import gmx, amber
from ..postprocessors.units import get_unit_converter
-from ..convergence import forward_backward_convergence
-from .. import concat
-from .. import __version__
+from ..preprocessing.subsampling import decorrelate_dhdl, decorrelate_u_nk
+from ..visualisation import (
+ plot_mbar_overlap_matrix,
+ plot_ti_dhdl,
+ plot_dF_state,
+ plot_convergence,
+)
class ABFE(WorkflowBase):
- '''Workflow for absolute and relative binding free energy calculations.
+ """Workflow for absolute and relative binding free energy calculations.
This workflow provides functionality similar to the ``alchemical-analysis.py`` script.
It loads multiple input files from alchemical free energy calculations and computes the
@@ -58,42 +63,50 @@ class ABFE(WorkflowBase):
.. versionadded:: 1.0.0
- '''
- def __init__(self, T, units='kT', software='GROMACS', dir=os.path.curdir,
- prefix='dhdl', suffix='xvg',
- outdirectory=os.path.curdir):
+ """
+
+ def __init__(
+ self,
+ T,
+ units="kT",
+ software="GROMACS",
+ dir=os.path.curdir,
+ prefix="dhdl",
+ suffix="xvg",
+ outdirectory=os.path.curdir,
+ ):
super().__init__(units, software, T, outdirectory)
- self.logger = logging.getLogger('alchemlyb.workflows.ABFE')
- self.logger.info('Initialise Alchemlyb ABFE Workflow')
- self.logger.info(f'Alchemlyb Version: f{__version__}')
- self.logger.info(f'Set Temperature to {T} K.')
- self.logger.info(f'Set Software to {software}.')
+ self.logger = logging.getLogger("alchemlyb.workflows.ABFE")
+ self.logger.info("Initialise Alchemlyb ABFE Workflow")
+         self.logger.info(f"Alchemlyb Version: {__version__}")
+ self.logger.info(f"Set Temperature to {T} K.")
+ self.logger.info(f"Set Software to {software}.")
self.update_units(units)
- self.logger.info(f'Finding files with prefix: {prefix}, suffix: '
- f'{suffix} under directory {dir} produced by '
- f'{software}')
- self.file_list = glob(dir + '/**/' + prefix + '*' + suffix,
- recursive=True)
+ self.logger.info(
+ f"Finding files with prefix: {prefix}, suffix: "
+ f"{suffix} under directory {dir} produced by "
+ f"{software}"
+ )
+ self.file_list = glob(dir + "/**/" + prefix + "*" + suffix, recursive=True)
- self.logger.info(f'Found {len(self.file_list)} xvg files.')
- self.logger.info("Unsorted file list: \n %s", '\n'.join(
- self.file_list))
+ self.logger.info(f"Found {len(self.file_list)} xvg files.")
+ self.logger.info("Unsorted file list: \n %s", "\n".join(self.file_list))
- if software == 'GROMACS':
- self.logger.info(f'Using {software} parser to read the data.')
+ if software == "GROMACS":
+ self.logger.info(f"Using {software} parser to read the data.")
self._extract_u_nk = gmx.extract_u_nk
self._extract_dHdl = gmx.extract_dHdl
- elif software == 'AMBER':
+ elif software == "AMBER":
self._extract_u_nk = amber.extract_u_nk
self._extract_dHdl = amber.extract_dHdl
else:
- raise NotImplementedError(f'{software} parser not found.')
+ raise NotImplementedError(f"{software} parser not found.")
def read(self, read_u_nk=True, read_dHdl=True):
- '''Read the u_nk and dHdL data from the
+ """Read the u_nk and dHdL data from the
:attr:`~alchemlyb.workflows.ABFE.file_list`
Parameters
@@ -109,7 +122,7 @@ def read(self, read_u_nk=True, read_dHdl=True):
A list of :class:`pandas.DataFrame` of u_nk.
dHdl_list : list
A list of :class:`pandas.DataFrame` of dHdl.
- '''
+ """
self.u_nk_sample_list = None
self.dHdl_sample_list = None
@@ -119,46 +132,46 @@ def read(self, read_u_nk=True, read_dHdl=True):
if read_u_nk:
try:
u_nk = self._extract_u_nk(file, T=self.T)
- self.logger.info(
- f'Reading {len(u_nk)} lines of u_nk from {file}')
+ self.logger.info(f"Reading {len(u_nk)} lines of u_nk from {file}")
u_nk_list.append(u_nk)
except Exception as exc:
- msg = f'Error reading u_nk from {file}.'
+ msg = f"Error reading u_nk from {file}."
self.logger.error(msg)
raise OSError(msg) from exc
if read_dHdl:
try:
dhdl = self._extract_dHdl(file, T=self.T)
- self.logger.info(
- f'Reading {len(dhdl)} lines of dhdl from {file}')
+ self.logger.info(f"Reading {len(dhdl)} lines of dhdl from {file}")
dHdl_list.append(dhdl)
except Exception as exc:
- msg = f'Error reading dHdl from {file}.'
+ msg = f"Error reading dHdl from {file}."
self.logger.error(msg)
raise OSError(msg) from exc
# Sort the files according to the state
if read_u_nk:
- self.logger.info('Sort files according to the u_nk.')
+ self.logger.info("Sort files according to the u_nk.")
column_names = u_nk_list[0].columns.values.tolist()
- index_list = sorted(range(len(self.file_list)),
- key=lambda x: column_names.index(
- u_nk_list[x].reset_index(
- 'time').index.values[0]))
+ index_list = sorted(
+ range(len(self.file_list)),
+ key=lambda x: column_names.index(
+ u_nk_list[x].reset_index("time").index.values[0]
+ ),
+ )
elif read_dHdl:
- self.logger.info('Sort files according to the dHdl.')
- index_list = sorted(range(len(self.file_list)),
- key=lambda x:
- dHdl_list[x].reset_index(
- 'time').index.values[0])
+ self.logger.info("Sort files according to the dHdl.")
+ index_list = sorted(
+ range(len(self.file_list)),
+ key=lambda x: dHdl_list[x].reset_index("time").index.values[0],
+ )
else:
self.u_nk_list = []
self.dHdl_list = []
return
self.file_list = [self.file_list[i] for i in index_list]
- self.logger.info("Sorted file list: \n%s", '\n'.join(self.file_list))
+ self.logger.info("Sorted file list: \n%s", "\n".join(self.file_list))
if read_u_nk:
self.u_nk_list = [u_nk_list[i] for i in index_list]
else:
@@ -169,11 +182,19 @@ def read(self, read_u_nk=True, read_dHdl=True):
else:
self.dHdl_list = []
-
- def run(self, skiptime=0, uncorr='dE', threshold=50,
- estimators=('MBAR', 'BAR', 'TI'), overlap='O_MBAR.pdf',
- breakdown=True, forwrev=None, *args, **kwargs):
- ''' The method for running the automatic analysis.
+ def run(
+ self,
+ skiptime=0,
+ uncorr="dE",
+ threshold=50,
+ estimators=("MBAR", "BAR", "TI"),
+ overlap="O_MBAR.pdf",
+ breakdown=True,
+ forwrev=None,
+ *args,
+ **kwargs,
+ ):
+ """The method for running the automatic analysis.
Parameters
----------
@@ -214,29 +235,32 @@ def run(self, skiptime=0, uncorr='dE', threshold=50,
The summary of the convergence results. See
:func:`~alchemlyb.convergence.forward_backward_convergence` for
further explanation.
- '''
+ """
use_FEP = False
use_TI = False
if estimators is not None:
if isinstance(estimators, str):
- estimators = [estimators, ]
+ estimators = [
+ estimators,
+ ]
for estimator in estimators:
if estimator in FEP_ESTIMATORS:
use_FEP = True
elif estimator in TI_ESTIMATORS:
use_TI = True
else:
- msg = f"Estimator {estimator} is not supported. Choose one from " \
- f"{FEP_ESTIMATORS + TI_ESTIMATORS}."
+ msg = (
+ f"Estimator {estimator} is not supported. Choose one from "
+ f"{FEP_ESTIMATORS + TI_ESTIMATORS}."
+ )
self.logger.error(msg)
raise ValueError(msg)
self.read(use_FEP, use_TI)
if uncorr is not None:
- self.preprocess(skiptime=skiptime, uncorr=uncorr,
- threshold=threshold)
+ self.preprocess(skiptime=skiptime, uncorr=uncorr, threshold=threshold)
if estimators is not None:
self.estimate(estimators)
self.generate_result()
@@ -251,31 +275,30 @@ def run(self, skiptime=0, uncorr='dE', threshold=50,
plt.close(ax.figure)
fig = self.plot_dF_state()
plt.close(fig)
- fig = self.plot_dF_state(dF_state='dF_state_long.pdf',
- orientation='landscape')
+ fig = self.plot_dF_state(
+ dF_state="dF_state_long.pdf", orientation="landscape"
+ )
plt.close(fig)
if forwrev is not None:
- ax = self.check_convergence(forwrev, estimator='MBAR',
- dF_t='dF_t.pdf')
+ ax = self.check_convergence(forwrev, estimator="MBAR", dF_t="dF_t.pdf")
plt.close(ax.figure)
-
def update_units(self, units=None):
- '''Update the unit.
+ """Update the unit.
Parameters
----------
units : {'kcal/mol', 'kJ/mol', 'kT'}
The unit used for printing and plotting results.
- '''
+ """
if units is not None:
- self.logger.info(f'Set unit to {units}.')
+ self.logger.info(f"Set unit to {units}.")
self.units = units or None
- def preprocess(self, skiptime=0, uncorr='dE', threshold=50):
- '''Preprocess the data by removing the equilibration time and
+ def preprocess(self, skiptime=0, uncorr="dE", threshold=50):
+ """Preprocess the data by removing the equilibration time and
             decorrelating the data.
Parameters
@@ -296,54 +319,65 @@ def preprocess(self, skiptime=0, uncorr='dE', threshold=50):
The list of u_nk after decorrelation.
dHdl_sample_list : list
The list of dHdl after decorrelation.
- '''
- self.logger.info(f'Start preprocessing with skiptime of {skiptime} '
- f'uncorrelation method of {uncorr} and threshold of '
- f'{threshold}')
+ """
+ self.logger.info(
+ f"Start preprocessing with skiptime of {skiptime} "
+ f"uncorrelation method of {uncorr} and threshold of "
+ f"{threshold}"
+ )
if len(self.u_nk_list) > 0:
self.logger.info(
- f'Processing the u_nk data set with skiptime of {skiptime}.')
+ f"Processing the u_nk data set with skiptime of {skiptime}."
+ )
self.u_nk_sample_list = []
for index, u_nk in enumerate(self.u_nk_list):
# Find the starting frame
- u_nk = u_nk[u_nk.index.get_level_values('time') >= skiptime]
+ u_nk = u_nk[u_nk.index.get_level_values("time") >= skiptime]
subsample = decorrelate_u_nk(u_nk, uncorr, remove_burnin=True)
if len(subsample) < threshold:
- self.logger.warning(f'Number of u_nk {len(subsample)} '
- f'for state {index} is less than the '
- f'threshold {threshold}.')
- self.logger.info(f'Take all the u_nk for state {index}.')
+ self.logger.warning(
+ f"Number of u_nk {len(subsample)} "
+ f"for state {index} is less than the "
+ f"threshold {threshold}."
+ )
+ self.logger.info(f"Take all the u_nk for state {index}.")
self.u_nk_sample_list.append(u_nk)
else:
- self.logger.info(f'Take {len(subsample)} uncorrelated '
- f'u_nk for state {index}.')
+ self.logger.info(
+ f"Take {len(subsample)} uncorrelated "
+ f"u_nk for state {index}."
+ )
self.u_nk_sample_list.append(subsample)
else:
- self.logger.info('No u_nk data being subsampled')
+ self.logger.info("No u_nk data being subsampled")
if len(self.dHdl_list) > 0:
self.dHdl_sample_list = []
for index, dHdl in enumerate(self.dHdl_list):
- dHdl = dHdl[dHdl.index.get_level_values('time') >= skiptime]
+ dHdl = dHdl[dHdl.index.get_level_values("time") >= skiptime]
subsample = decorrelate_dhdl(dHdl, remove_burnin=True)
if len(subsample) < threshold:
- self.logger.warning(f'Number of dHdl {len(subsample)} for '
- f'state {index} is less than the '
- f'threshold {threshold}.')
- self.logger.info(f'Take all the dHdl for state {index}.')
+ self.logger.warning(
+ f"Number of dHdl {len(subsample)} for "
+ f"state {index} is less than the "
+ f"threshold {threshold}."
+ )
+ self.logger.info(f"Take all the dHdl for state {index}.")
self.dHdl_sample_list.append(dHdl)
else:
- self.logger.info(f'Take {len(subsample)} uncorrelated '
- f'dHdl for state {index}.')
+ self.logger.info(
+ f"Take {len(subsample)} uncorrelated "
+ f"dHdl for state {index}."
+ )
self.dHdl_sample_list.append(subsample)
else:
- self.logger.info('No dHdl data being subsampled')
+ self.logger.info("No dHdl data being subsampled")
- def estimate(self, estimators=('MBAR', 'BAR', 'TI'), **kwargs):
- '''Estimate the free energy using the selected estimator.
+ def estimate(self, estimators=("MBAR", "BAR", "TI"), **kwargs):
+ """Estimate the free energy using the selected estimator.
Parameters
----------
@@ -368,10 +402,10 @@ def estimate(self, estimators=('MBAR', 'BAR', 'TI'), **kwargs):
behavior of :class:`~alchemlyb.estimators.MBAR`.
(:code:`estimate(estimators='MBAR', method='adaptive')`)
- '''
+ """
# Make estimators into a tuple
if isinstance(estimators, str):
- estimators = (estimators, )
+ estimators = (estimators,)
for estimator in estimators:
if estimator not in (FEP_ESTIMATORS + TI_ESTIMATORS):
@@ -379,42 +413,38 @@ def estimate(self, estimators=('MBAR', 'BAR', 'TI'), **kwargs):
self.logger.error(msg)
raise ValueError(msg)
- self.logger.info(
- f"Start running estimator: {','.join(estimators)}.")
+ self.logger.info(f"Start running estimator: {','.join(estimators)}.")
self.estimator = {}
# Use unprocessed data if preprocess is not performed.
- if 'TI' in estimators:
+ if "TI" in estimators:
if self.dHdl_sample_list is not None:
dHdl = concat(self.dHdl_sample_list)
else:
dHdl = concat(self.dHdl_list)
- self.logger.warning('dHdl has not been preprocessed.')
- self.logger.info(
- f'A total {len(dHdl)} lines of dHdl is used.')
+ self.logger.warning("dHdl has not been preprocessed.")
+ self.logger.info(f"A total {len(dHdl)} lines of dHdl is used.")
- if 'BAR' in estimators or 'MBAR' in estimators:
+ if "BAR" in estimators or "MBAR" in estimators:
if self.u_nk_sample_list is not None:
u_nk = concat(self.u_nk_sample_list)
else:
u_nk = concat(self.u_nk_list)
- self.logger.warning('u_nk has not been preprocessed.')
- self.logger.info(
- f'A total {len(u_nk)} lines of u_nk is used.')
+ self.logger.warning("u_nk has not been preprocessed.")
+ self.logger.info(f"A total {len(u_nk)} lines of u_nk is used.")
for estimator in estimators:
- if estimator == 'MBAR':
- self.logger.info('Run MBAR estimator.')
+ if estimator == "MBAR":
+ self.logger.info("Run MBAR estimator.")
self.estimator[estimator] = MBAR(**kwargs).fit(u_nk)
- elif estimator == 'BAR':
- self.logger.info('Run BAR estimator.')
+ elif estimator == "BAR":
+ self.logger.info("Run BAR estimator.")
self.estimator[estimator] = BAR(**kwargs).fit(u_nk)
- elif estimator == 'TI':
- self.logger.info('Run TI estimator.')
+ elif estimator == "TI":
+ self.logger.info("Run TI estimator.")
self.estimator[estimator] = TI(**kwargs).fit(dHdl)
-
def generate_result(self):
- '''Summarise the result into a dataframe.
+ """Summarise the result into a dataframe.
Returns
-------
@@ -460,38 +490,37 @@ def generate_result(self):
----------
summary : Dataframe
The summary of the free energy estimate.
- '''
+ """
# Write estimate
- self.logger.info('Summarise the estimate into a dataframe.')
+ self.logger.info("Summarise the estimate into a dataframe.")
# Make the header name
- self.logger.info('Generate the row names.')
+ self.logger.info("Generate the row names.")
estimator_names = list(self.estimator.keys())
num_states = len(self.estimator[estimator_names[0]].states_)
- data_dict = {'name': [],
- 'state': []}
+ data_dict = {"name": [], "state": []}
for i in range(num_states - 1):
- data_dict['name'].append(str(i) + ' -- ' + str(i+1))
- data_dict['state'].append('States')
+ data_dict["name"].append(str(i) + " -- " + str(i + 1))
+ data_dict["state"].append("States")
try:
u_nk = self.u_nk_list[0]
- stages = u_nk.reset_index('time').index.names
- self.logger.info('use the stage name from u_nk')
+ stages = u_nk.reset_index("time").index.names
+ self.logger.info("use the stage name from u_nk")
except:
dHdl = self.dHdl_list[0]
- stages = dHdl.reset_index('time').index.names
- self.logger.info('use the stage name from dHdl')
+ stages = dHdl.reset_index("time").index.names
+ self.logger.info("use the stage name from dHdl")
for stage in stages:
- data_dict['name'].append(stage.split('-')[0])
- data_dict['state'].append('Stages')
- data_dict['name'].append('TOTAL')
- data_dict['state'].append('Stages')
+ data_dict["name"].append(stage.split("-")[0])
+ data_dict["state"].append("Stages")
+ data_dict["name"].append("TOTAL")
+ data_dict["state"].append("Stages")
col_names = []
for estimator_name, estimator in self.estimator.items():
- self.logger.info(f'Read the results from estimator {estimator_name}')
+ self.logger.info(f"Read the results from estimator {estimator_name}")
# Do the unit conversion
delta_f_ = estimator.delta_f_
@@ -499,26 +528,26 @@ def generate_result(self):
# Write the estimator header
col_names.append(estimator_name)
- col_names.append(estimator_name + '_Error')
+ col_names.append(estimator_name + "_Error")
data_dict[estimator_name] = []
- data_dict[estimator_name + '_Error'] = []
+ data_dict[estimator_name + "_Error"] = []
for index in range(1, num_states):
- data_dict[estimator_name].append(
- delta_f_.iloc[index-1, index])
- data_dict[estimator_name + '_Error'].append(
- d_delta_f_.iloc[index - 1, index])
+ data_dict[estimator_name].append(delta_f_.iloc[index - 1, index])
+ data_dict[estimator_name + "_Error"].append(
+ d_delta_f_.iloc[index - 1, index]
+ )
- self.logger.info(f'Generate the staged result from estimator {estimator_name}')
+ self.logger.info(
+ f"Generate the staged result from estimator {estimator_name}"
+ )
for index, stage in enumerate(stages):
if len(stages) == 1:
start = 0
end = len(estimator.states_) - 1
else:
# Get the start and the end of the state
- lambda_min = min(
- [state[index] for state in estimator.states_])
- lambda_max = max(
- [state[index] for state in estimator.states_])
+ lambda_min = min([state[index] for state in estimator.states_])
+ lambda_max = max([state[index] for state in estimator.states_])
if lambda_min == lambda_max:
# Deal with the case where a certain lambda is used but
# not perturbed
@@ -529,35 +558,39 @@ def generate_result(self):
start = list(reversed(states)).index(lambda_min)
start = num_states - start - 1
end = states.index(lambda_max)
- self.logger.info(
- f'Stage {stage} is from state {start} to state {end}.')
+ self.logger.info(f"Stage {stage} is from state {start} to state {end}.")
# This assumes that the indexes are sorted as the
# preprocessing should sort the index of the df.
result = delta_f_.iloc[start, end]
- if estimator_name != 'BAR':
+ if estimator_name != "BAR":
error = d_delta_f_.iloc[start, end]
else:
- error = np.sqrt(sum(
- [d_delta_f_.iloc[start, start+1]**2
- for i in range(start, end + 1)]))
+ error = np.sqrt(
+ sum(
+ [
+                                 d_delta_f_.iloc[i, i + 1] ** 2
+                                 for i in range(start, end)
+ ]
+ )
+ )
data_dict[estimator_name].append(result)
- data_dict[estimator_name + '_Error'].append(error)
+ data_dict[estimator_name + "_Error"].append(error)
# Total result
# This assumes that the indexes are sorted as the
# preprocessing should sort the index of the df.
result = delta_f_.iloc[0, -1]
- if estimator_name != 'BAR':
+ if estimator_name != "BAR":
error = d_delta_f_.iloc[0, -1]
else:
- error = np.sqrt(sum(
- [d_delta_f_.iloc[i, i + 1] ** 2
- for i in range(num_states - 1)]))
+ error = np.sqrt(
+ sum([d_delta_f_.iloc[i, i + 1] ** 2 for i in range(num_states - 1)])
+ )
data_dict[estimator_name].append(result)
- data_dict[estimator_name + '_Error'].append(error)
+ data_dict[estimator_name + "_Error"].append(error)
summary = pd.DataFrame.from_dict(data_dict)
- summary = summary.set_index(['state', 'name'])
+ summary = summary.set_index(["state", "name"])
# Make sure that the columns are in the right order
summary = summary[col_names]
# Remove the name of the index column to make it prettier
@@ -567,11 +600,11 @@ def generate_result(self):
converter = get_unit_converter(self.units)
summary = converter(summary)
self.summary = summary
- self.logger.info(f'Write results:\n{summary.to_string()}')
+ self.logger.info(f"Write results:\n{summary.to_string()}")
return summary
- def plot_overlap_matrix(self, overlap='O_MBAR.pdf', ax=None):
- '''Plot the overlap matrix for MBAR estimator using
+ def plot_overlap_matrix(self, overlap="O_MBAR.pdf", ax=None):
+ """Plot the overlap matrix for MBAR estimator using
:func:`~alchemlyb.visualisation.plot_mbar_overlap_matrix`.
Parameters
@@ -586,21 +619,20 @@ def plot_overlap_matrix(self, overlap='O_MBAR.pdf', ax=None):
-------
matplotlib.axes.Axes
An axes with the overlap matrix drawn.
- '''
- self.logger.info('Plot overlap matrix.')
- if 'MBAR' in self.estimator:
- ax = plot_mbar_overlap_matrix(self.estimator['MBAR'].overlap_matrix,
- ax=ax)
+ """
+ self.logger.info("Plot overlap matrix.")
+ if "MBAR" in self.estimator:
+ ax = plot_mbar_overlap_matrix(self.estimator["MBAR"].overlap_matrix, ax=ax)
ax.figure.savefig(join(self.out, overlap))
- self.logger.info(f'Plot overlap matrix to {self.out} under {overlap}.')
+ self.logger.info(f"Plot overlap matrix to {self.out} under {overlap}.")
return ax
else:
- self.logger.warning('MBAR estimator not found. '
- 'Overlap matrix not plotted.')
+ self.logger.warning(
+ "MBAR estimator not found. " "Overlap matrix not plotted."
+ )
- def plot_ti_dhdl(self, dhdl_TI='dhdl_TI.pdf', labels=None, colors=None,
- ax=None):
- '''Plot the dHdl for TI estimator using
+ def plot_ti_dhdl(self, dhdl_TI="dhdl_TI.pdf", labels=None, colors=None, ax=None):
+ """Plot the dHdl for TI estimator using
:func:`~alchemlyb.visualisation.plot_ti_dhdl`.
Parameters
@@ -620,20 +652,31 @@ def plot_ti_dhdl(self, dhdl_TI='dhdl_TI.pdf', labels=None, colors=None,
-------
matplotlib.axes.Axes
An axes with the TI dhdl drawn.
- '''
- self.logger.info('Plot TI dHdl.')
- if 'TI' in self.estimator:
- ax = plot_ti_dhdl(self.estimator['TI'], units=self.units,
- labels=labels, colors=colors, ax=ax)
+ """
+ self.logger.info("Plot TI dHdl.")
+ if "TI" in self.estimator:
+ ax = plot_ti_dhdl(
+ self.estimator["TI"],
+ units=self.units,
+ labels=labels,
+ colors=colors,
+ ax=ax,
+ )
ax.figure.savefig(join(self.out, dhdl_TI))
- self.logger.info(f'Plot TI dHdl to {dhdl_TI} under {self.out}.')
+ self.logger.info(f"Plot TI dHdl to {dhdl_TI} under {self.out}.")
return ax
else:
- raise ValueError('No TI data available in estimators.')
-
- def plot_dF_state(self, dF_state='dF_state.pdf', labels=None, colors=None,
- orientation='portrait', nb=10):
- '''Plot the dF states using
+ raise ValueError("No TI data available in estimators.")
+
+ def plot_dF_state(
+ self,
+ dF_state="dF_state.pdf",
+ labels=None,
+ colors=None,
+ orientation="portrait",
+ nb=10,
+ ):
+ """Plot the dF states using
:func:`~alchemlyb.visualisation.plot_dF_state`.
Parameters
@@ -653,18 +696,24 @@ def plot_dF_state(self, dF_state='dF_state.pdf', labels=None, colors=None,
-------
matplotlib.figure.Figure
An Figure with the dF states drawn.
- '''
- self.logger.info('Plot dF states.')
- fig = plot_dF_state(self.estimator.values(), labels=labels, colors=colors,
- units=self.units,
- orientation=orientation, nb=nb)
+ """
+ self.logger.info("Plot dF states.")
+ fig = plot_dF_state(
+ self.estimator.values(),
+ labels=labels,
+ colors=colors,
+ units=self.units,
+ orientation=orientation,
+ nb=nb,
+ )
fig.savefig(join(self.out, dF_state))
- self.logger.info(f'Plot dF state to {dF_state} under {self.out}.')
+ self.logger.info(f"Plot dF state to {dF_state} under {self.out}.")
return fig
- def check_convergence(self, forwrev, estimator='MBAR', dF_t='dF_t.pdf',
- ax=None, **kwargs):
- '''Compute the forward and backward convergence using
+ def check_convergence(
+ self, forwrev, estimator="MBAR", dF_t="dF_t.pdf", ax=None, **kwargs
+ ):
+ """Compute the forward and backward convergence using
:func:`~alchemlyb.convergence.forward_backward_convergence`and
plot with
:func:`~alchemlyb.visualisation.plot_convergence`.
@@ -701,59 +750,63 @@ def check_convergence(self, forwrev, estimator='MBAR', dF_t='dF_t.pdf',
of :class:`~alchemlyb.estimators.MBAR`.
(:code:`check_convergence(10, estimator='MBAR', method='adaptive')`)
- '''
- self.logger.info('Start convergence analysis.')
- self.logger.info('Checking data availability.')
+ """
+ self.logger.info("Start convergence analysis.")
+ self.logger.info("Checking data availability.")
if estimator in FEP_ESTIMATORS:
if self.u_nk_sample_list is not None:
u_nk_list = self.u_nk_sample_list
- self.logger.info('Subsampled u_nk is available.')
+ self.logger.info("Subsampled u_nk is available.")
else:
if self.u_nk_list is not None:
u_nk_list = self.u_nk_list
- self.logger.info('Subsampled u_nk not available, '
- 'use original data instead.')
+ self.logger.info(
+ "Subsampled u_nk not available, " "use original data instead."
+ )
else:
- msg = f"u_nk is needed for the f{estimator} estimator. " \
- f"If the dataset only has dHdl, " \
- f"run ABFE.check_convergence(estimator='TI') to " \
- f"use a TI estimator."
+ msg = (
+                         f"u_nk is needed for the {estimator} estimator. "
+ f"If the dataset only has dHdl, "
+ f"run ABFE.check_convergence(estimator='TI') to "
+ f"use a TI estimator."
+ )
self.logger.error(msg)
raise ValueError(msg)
- convergence = forward_backward_convergence(u_nk_list,
- estimator=estimator,
- num=forwrev, **kwargs)
+ convergence = forward_backward_convergence(
+ u_nk_list, estimator=estimator, num=forwrev, **kwargs
+ )
elif estimator in TI_ESTIMATORS:
- self.logger.warning('No valid FEP estimator or dataset found. '
- 'Fallback to TI.')
+ self.logger.warning(
+ "No valid FEP estimator or dataset found. " "Fallback to TI."
+ )
if self.dHdl_sample_list is not None:
dHdl_list = self.dHdl_sample_list
- self.logger.info('Subsampled dHdl is available.')
+ self.logger.info("Subsampled dHdl is available.")
else:
if self.dHdl_list is not None:
dHdl_list = self.dHdl_list
- self.logger.info('Subsampled dHdl not available, '
- 'use original data instead.')
+ self.logger.info(
+ "Subsampled dHdl not available, " "use original data instead."
+ )
else:
- self.logger.error(
- f'dHdl is needed for the f{estimator} estimator.')
- raise ValueError(
- f'dHdl is needed for the f{estimator} estimator.')
- convergence = forward_backward_convergence(dHdl_list,
- estimator=estimator,
- num=forwrev, **kwargs)
+                     self.logger.error(f"dHdl is needed for the {estimator} estimator.")
+                     raise ValueError(f"dHdl is needed for the {estimator} estimator.")
+ convergence = forward_backward_convergence(
+ dHdl_list, estimator=estimator, num=forwrev, **kwargs
+ )
else:
- msg = f"Estimator {estimator} is not supported. Choose one from " \
- f"{FEP_ESTIMATORS+TI_ESTIMATORS}."
+ msg = (
+ f"Estimator {estimator} is not supported. Choose one from "
+ f"{FEP_ESTIMATORS + TI_ESTIMATORS}."
+ )
self.logger.error(msg)
raise ValueError(msg)
self.convergence = get_unit_converter(self.units)(convergence)
- self.logger.info(f'Plot convergence analysis to {dF_t} under {self.out}.')
+ self.logger.info(f"Plot convergence analysis to {dF_t} under {self.out}.")
- ax = plot_convergence(self.convergence,
- units=self.units, ax=ax)
+ ax = plot_convergence(self.convergence, units=self.units, ax=ax)
ax.figure.savefig(join(self.out, dF_t))
return ax
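
None of the hunks above alter the ``ABFE`` API itself. A hedged end-to-end sketch of the workflow, assuming GROMACS ``dhdl*.xvg`` files under ``./abfe/`` (paths are illustrative only)::

    from alchemlyb.workflows import ABFE

    workflow = ABFE(T=300, units="kcal/mol", software="GROMACS",
                    dir="./abfe", prefix="dhdl", suffix="xvg",
                    outdirectory="./abfe")
    # read, decorrelate, estimate (MBAR/BAR/TI), and write the standard plots
    workflow.run(skiptime=10, uncorr="dE", threshold=50,
                 estimators=("MBAR", "BAR", "TI"),
                 overlap="O_MBAR.pdf", breakdown=True, forwrev=10)
    print(workflow.summary)
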
diff --git a/src/alchemlyb/workflows/base.py b/src/alchemlyb/workflows/base.py
index 1728f7f0..1b4cbe41 100644
--- a/src/alchemlyb/workflows/base.py
+++ b/src/alchemlyb/workflows/base.py
@@ -2,7 +2,8 @@
import pandas as pd
-class WorkflowBase():
+
+class WorkflowBase:
"""The base class for the Workflow.
This is the base class for the creation of new Workflow. The
@@ -37,9 +38,10 @@ class WorkflowBase():
.. versionadded:: 0.7.0
"""
- def __init__(self, units='kT', software='Gromacs', T=298, out='./', *args,
- **kwargs):
+ def __init__(
+ self, units="kT", software="Gromacs", T=298, out="./", *args, **kwargs
+ ):
self.T = T
self.software = software
self.unit = units
@@ -47,7 +49,7 @@ def __init__(self, units='kT', software='Gromacs', T=298, out='./', *args,
self.out = out
def run(self, *args, **kwargs):
- """ Run the workflow in an automatic fashion.
+ """Run the workflow in an automatic fashion.
This method would execute the
:func:`~alchemlyb.workflows.WorkflowBase.read`,
@@ -88,7 +90,7 @@ def run(self, *args, **kwargs):
self.plot(*args, **kwargs)
def read(self, *args, **kwargs):
- """ The function that reads the files in `file_list` and parse them
+         """The function that reads the files in `file_list` and parses them
into u_nk and dHdl files.
Attributes
@@ -104,7 +106,7 @@ def read(self, *args, **kwargs):
self.dHdl_list = []
def preprocess(self, *args, **kwargs):
- """ The function that subsample the u_nk and dHdl in `u_nk_list` and
+         """The function that subsamples the u_nk and dHdl in `u_nk_list` and
`dHdl_list`.
Attributes
@@ -120,7 +122,7 @@ def preprocess(self, *args, **kwargs):
self.u_nk_sample_list = []
def estimate(self, *args, **kwargs):
- """ The function that runs the estimator based on `u_nk_sample_list`
+ """The function that runs the estimator based on `u_nk_sample_list`
and `dHdl_sample_list`.
Attributes
@@ -133,7 +135,7 @@ def estimate(self, *args, **kwargs):
self.result = pd.DataFrame()
def check_convergence(self, *args, **kwargs):
- """ The function for doing convergence analysis.
+ """The function for doing convergence analysis.
Attributes
----------
@@ -145,7 +147,5 @@ def check_convergence(self, *args, **kwargs):
self.convergence = pd.DataFrame()
def plot(self, *args, **kwargs):
- """ The function for producing any plots.
-
- """
+ """The function for producing any plots."""
pass
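
``WorkflowBase`` only provides the hook methods that a concrete workflow overrides; the changes here are cosmetic. A minimal sketch of a custom subclass (hypothetical names, not part of alchemlyb)::

    import pandas as pd
    from alchemlyb.workflows.base import WorkflowBase

    class MyWorkflow(WorkflowBase):
        def read(self, *args, **kwargs):
            # parse the input files into u_nk / dHdl DataFrames here
            self.u_nk_list, self.dHdl_list = [], []

        def estimate(self, *args, **kwargs):
            # fill self.result with the free energy estimate
            self.result = pd.DataFrame()

    wf = MyWorkflow(units="kT", software="Gromacs", T=298, out="./")
    wf.run()  # read -> preprocess -> estimate -> check_convergence -> plot
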