From f1a8ae214b0064b53952f8e6128900ead29cbaa5 Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 14 Aug 2022 20:56:26 +0100 Subject: [PATCH 01/21] remove iloc --- src/alchemlyb/tests/test_convergence.py | 38 +++++++------------------ 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/src/alchemlyb/tests/test_convergence.py b/src/alchemlyb/tests/test_convergence.py index 7013d119..55ff832c 100644 --- a/src/alchemlyb/tests/test_convergence.py +++ b/src/alchemlyb/tests/test_convergence.py @@ -14,37 +14,21 @@ def test_convergence_ti(gmx_benzene): dHdl, u_nk = gmx_benzene convergence = forward_backward_convergence(dHdl, 'TI') assert convergence.shape == (10, 5) - assert convergence.iloc[0, 0] == pytest.approx(3.07, 0.01) - assert convergence.iloc[0, 2] == pytest.approx(3.11, 0.01) - assert convergence.iloc[-1, 0] == pytest.approx(3.09, 0.01) - assert convergence.iloc[-1, 2] == pytest.approx(3.09, 0.01) -def test_convergence_mbar(gmx_benzene): - dHdl, u_nk = gmx_benzene - convergence = forward_backward_convergence(u_nk, 'MBAR') - assert convergence.shape == (10, 5) - assert convergence.iloc[0, 0] == pytest.approx(3.02, 0.01) - assert convergence.iloc[0, 2] == pytest.approx(3.06, 0.01) - assert convergence.iloc[-1, 0] == pytest.approx(3.05, 0.01) - assert convergence.iloc[-1, 2] == pytest.approx(3.04, 0.01) - -def test_convergence_autombar(gmx_benzene): - dHdl, u_nk = gmx_benzene - convergence = forward_backward_convergence(u_nk, 'AutoMBAR') - assert convergence.shape == (10, 5) - assert convergence.iloc[0, 0] == pytest.approx(3.02, 0.01) - assert convergence.iloc[0, 2] == pytest.approx(3.06, 0.01) - assert convergence.iloc[-1, 0] == pytest.approx(3.05, 0.01) - assert convergence.iloc[-1, 2] == pytest.approx(3.04, 0.01) + assert convergence.loc[0, 'Forward'] == pytest.approx(3.07, 0.01) + assert convergence.loc[0, 'Backward'] == pytest.approx(3.11, 0.01) + assert convergence.loc[9, 'Forward'] == pytest.approx(3.09, 0.01) + assert convergence.loc[9, 'Backward'] == pytest.approx(3.09, 0.01) -def test_convergence_bar(gmx_benzene): +@pytest.mark.parametrize('estimator', ('MBAR', 'AutoMBAR', 'BAR')) +def test_convergence_fep(gmx_benzene, estimator): dHdl, u_nk = gmx_benzene - convergence = forward_backward_convergence(u_nk, 'BAR') + convergence = forward_backward_convergence(u_nk, estimator) assert convergence.shape == (10, 5) - assert convergence.iloc[0, 0] == pytest.approx(3.02, 0.01) - assert convergence.iloc[0, 2] == pytest.approx(3.06, 0.01) - assert convergence.iloc[-1, 0] == pytest.approx(3.05, 0.01) - assert convergence.iloc[-1, 2] == pytest.approx(3.04, 0.01) + assert convergence.loc[0, 'Forward'] == pytest.approx(3.02, 0.01) + assert convergence.loc[0, 'Backward'] == pytest.approx(3.06, 0.01) + assert convergence.loc[9, 'Forward'] == pytest.approx(3.05, 0.01) + assert convergence.loc[9, 'Backward'] == pytest.approx(3.04, 0.01) def test_convergence_wrong_estimator(gmx_benzene): dHdl, u_nk = gmx_benzene From afc54b879a84f5c862eb4cdaf6bcbf213ed7426b Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Wed, 17 Aug 2022 19:44:06 +0100 Subject: [PATCH 02/21] update loc --- src/alchemlyb/tests/test_visualisation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/alchemlyb/tests/test_visualisation.py b/src/alchemlyb/tests/test_visualisation.py index bcaf4a32..72f1a730 100644 --- a/src/alchemlyb/tests/test_visualisation.py +++ b/src/alchemlyb/tests/test_visualisation.py @@ -148,13 +148,13 @@ def test_plot_convergence(): slice = int(len(data_list[0])/num_points*i) u_nk_coul = alchemlyb.concat([data[:slice] for data in data_list]) estimate = MBAR().fit(u_nk_coul) - forward.append(estimate.delta_f_.iloc[0,-1]) - forward_error.append(estimate.d_delta_f_.iloc[0,-1]) + forward.append(estimate.delta_f_.loc[0.0,1.0]) + forward_error.append(estimate.d_delta_f_.loc[0.0,1.0]) # Do the backward u_nk_coul = alchemlyb.concat([data[-slice:] for data in data_list]) estimate = MBAR().fit(u_nk_coul) - backward.append(estimate.delta_f_.iloc[0,-1]) - backward_error.append(estimate.d_delta_f_.iloc[0,-1]) + backward.append(estimate.delta_f_.loc[0.0,1.0]) + backward_error.append(estimate.d_delta_f_.loc[0.0,1.0]) ax = plot_convergence(forward, forward_error, backward, backward_error) assert isinstance(ax, matplotlib.axes.Axes) From 3fd100581aa2ac673e81d63242d67c9dfcb340df Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 16 Oct 2022 20:57:04 +0100 Subject: [PATCH 03/21] update --- src/alchemlyb/tests/parsing/test_gmx.py | 11 ++++++----- src/alchemlyb/tests/test_convergence.py | 2 +- src/alchemlyb/tests/test_ti_estimators.py | 4 ++-- src/alchemlyb/tests/test_units.py | 6 +++--- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/alchemlyb/tests/parsing/test_gmx.py b/src/alchemlyb/tests/parsing/test_gmx.py index a5c6e0bc..d85ad1bf 100644 --- a/src/alchemlyb/tests/parsing/test_gmx.py +++ b/src/alchemlyb/tests/parsing/test_gmx.py @@ -124,7 +124,7 @@ def test_u_nk_with_total_energy(): # Check one specific value in the dataframe assert_almost_equal( - extract_u_nk(dataset['data']['AllStates'][0], T=300).iloc[0][0], + extract_u_nk(dataset['data']['AllStates'][0], T=300).loc[0][(0.0,0.0)].values[0], -11211.577658852531, decimal=6 ) @@ -142,7 +142,7 @@ def test_u_nk_with_potential_energy(): # Check one specific value in the dataframe assert_almost_equal( - extract_u_nk(dataset['data']['AllStates'][0], T=300).iloc[0][0], + extract_u_nk(dataset['data']['AllStates'][0], T=300).loc[0][(0.0,0.0)].values[0], -15656.557252200757, decimal=6 ) @@ -161,7 +161,7 @@ def test_u_nk_without_energy(): # Check one specific value in the dataframe assert_almost_equal( - extract_u_nk(dataset['data']['AllStates'][0], T=300).iloc[0][0], + extract_u_nk(dataset['data']['AllStates'][0], T=300).loc[0][(0.0,0.0)].values[0], 0.0, decimal=6 ) @@ -180,8 +180,9 @@ def _diag_sum(dataset): u_nk = extract_u_nk(filename, T=300) # Calculate the sum of diagonal elements: - for i in range(len(dataset['data'][leg])): - ds += u_nk.iloc[i][i] + for i, lambda_ in enumerate(u_nk.columns): + #18.6 is the time step + ds += u_nk.loc[i*186/10][lambda_].values[0] return ds diff --git a/src/alchemlyb/tests/test_convergence.py b/src/alchemlyb/tests/test_convergence.py index 7a363426..582cf9d5 100644 --- a/src/alchemlyb/tests/test_convergence.py +++ b/src/alchemlyb/tests/test_convergence.py @@ -22,7 +22,7 @@ def test_convergence_ti(gmx_benzene): def test_convergence_fep(gmx_benzene): dHdl, u_nk = gmx_benzene - convergence = forward_backward_convergence(u_nk, estimator) + convergence = forward_backward_convergence(u_nk, 'MBAR') assert convergence.shape == (10, 5) assert convergence.loc[0, 'Forward'] == pytest.approx(3.02, 0.01) assert convergence.loc[0, 'Backward'] == pytest.approx(3.06, 0.01) diff --git a/src/alchemlyb/tests/test_ti_estimators.py b/src/alchemlyb/tests/test_ti_estimators.py index 01b52de4..aed93d64 100644 --- a/src/alchemlyb/tests/test_ti_estimators.py +++ b/src/alchemlyb/tests/test_ti_estimators.py @@ -109,8 +109,8 @@ class TIestimatorMixin: def test_get_delta_f(self, X_delta_f): dHdl, E, dE = X_delta_f est = self.cls().fit(dHdl) - delta_f = est.delta_f_.iloc[0, -1] - d_delta_f = est.d_delta_f_.iloc[0, -1] + delta_f = est.delta_f_.loc[(0.0,1.0)] + d_delta_f = est.d_delta_f_.loc[(0.0,1.0)] assert E == pytest.approx(delta_f, rel=1e-3) assert dE == pytest.approx(d_delta_f, rel=1e-3) diff --git a/src/alchemlyb/tests/test_units.py b/src/alchemlyb/tests/test_units.py index db9c613e..2d2fd76b 100644 --- a/src/alchemlyb/tests/test_units.py +++ b/src/alchemlyb/tests/test_units.py @@ -56,7 +56,7 @@ def dhdl(): def test_kt2kt_number(self, dhdl): new_dhdl = to_kT(dhdl) - assert 12.9 == pytest.approx(new_dhdl.iloc[0, 0], 0.1) + assert 12.9 == pytest.approx(new_dhdl.loc[(0.0,0.0)], 0.1) def test_kt2kt_unit(self, dhdl): new_dhdl = to_kT(dhdl) @@ -70,7 +70,7 @@ def test_kj2kt_unit(self, dhdl): def test_kj2kt_number(self, dhdl): dhdl.attrs['energy_unit'] = 'kJ/mol' new_dhdl = to_kT(dhdl) - assert 5.0 == pytest.approx(new_dhdl.iloc[0, 0], 0.1) + assert 5.0 == pytest.approx(new_dhdl.loc[(0.0,0.0)], 0.1) def test_kcal2kt_unit(self, dhdl): dhdl.attrs['energy_unit'] = 'kcal/mol' @@ -80,7 +80,7 @@ def test_kcal2kt_unit(self, dhdl): def test_kcal2kt_number(self, dhdl): dhdl.attrs['energy_unit'] = 'kcal/mol' new_dhdl = to_kT(dhdl) - assert 21.0 == pytest.approx(new_dhdl.iloc[0, 0], 0.1) + assert 21.0 == pytest.approx(new_dhdl.loc[(0.0,0.0)], 0.1) def test_unknown2kt(self, dhdl): dhdl.attrs['energy_unit'] = 'ddd' From 7609dc26a8e6970319ab76f21ceaec28ab8b8ed4 Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 16 Oct 2022 21:32:18 +0100 Subject: [PATCH 04/21] update --- src/alchemlyb/tests/test_fep_estimators.py | 2 +- src/alchemlyb/tests/test_preprocessing.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/alchemlyb/tests/test_fep_estimators.py b/src/alchemlyb/tests/test_fep_estimators.py index e6c340c7..7135c255 100644 --- a/src/alchemlyb/tests/test_fep_estimators.py +++ b/src/alchemlyb/tests/test_fep_estimators.py @@ -181,7 +181,7 @@ def test_failback_adaptive(self, n_uk_list): # The hybr will fail on this while adaptive will work mbar = AutoMBAR().fit(alchemlyb.concat([n_uk[:2] for n_uk in n_uk_list])) - assert np.isclose(mbar.d_delta_f_.iloc[0, -1], 1.76832, 0.1) + assert np.isclose(mbar.d_delta_f_[(0.0, 0.0, 0.0)][(1.0, 1.0, 1.0)], 1.76832, 0.1) def test_AutoMBAR_BGFS(): # A case where only BFGS would work diff --git a/src/alchemlyb/tests/test_preprocessing.py b/src/alchemlyb/tests/test_preprocessing.py index 595eda90..c189df28 100644 --- a/src/alchemlyb/tests/test_preprocessing.py +++ b/src/alchemlyb/tests/test_preprocessing.py @@ -135,10 +135,10 @@ def test_disordered_exception(self, data): """Test that a shuffled DataFrame yields a KeyError. """ - indices = np.arange(len(data)) + indices = data.index.values np.random.shuffle(indices) - df = data.iloc[indices] + df = data.loc[indices] with pytest.raises(KeyError): self.slicer(df, lower=200) From 50bf8aa790040d4a098d1d3eb63b67ee471e5d97 Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 16 Oct 2022 22:00:19 +0100 Subject: [PATCH 05/21] ifx test --- src/alchemlyb/tests/test_ti_estimators.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/alchemlyb/tests/test_ti_estimators.py b/src/alchemlyb/tests/test_ti_estimators.py index aed93d64..01b52de4 100644 --- a/src/alchemlyb/tests/test_ti_estimators.py +++ b/src/alchemlyb/tests/test_ti_estimators.py @@ -109,8 +109,8 @@ class TIestimatorMixin: def test_get_delta_f(self, X_delta_f): dHdl, E, dE = X_delta_f est = self.cls().fit(dHdl) - delta_f = est.delta_f_.loc[(0.0,1.0)] - d_delta_f = est.d_delta_f_.loc[(0.0,1.0)] + delta_f = est.delta_f_.iloc[0, -1] + d_delta_f = est.d_delta_f_.iloc[0, -1] assert E == pytest.approx(delta_f, rel=1e-3) assert dE == pytest.approx(d_delta_f, rel=1e-3) From e4c0a4d6532427730f9fac34a8b9b4a7623843c7 Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 4 Dec 2022 10:48:49 +0000 Subject: [PATCH 06/21] update --- CHANGES | 1 + src/alchemlyb/tests/test_convergence.py | 5 +++-- src/alchemlyb/tests/test_preprocessing.py | 4 ++-- src/alchemlyb/tests/test_ti_estimators.py | 2 ++ 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/CHANGES b/CHANGES index 4bd7e8cc..45d1de88 100644 --- a/CHANGES +++ b/CHANGES @@ -18,6 +18,7 @@ The rules for this file: * 1.0.1 Fixes + - Remove most of the iloc in the tests (issue #202, PR #254). - AMBER parser now raises ValueError when the initial simulation time is not found (issue #272, PR #273). - The regex in the AMBER parser now reads also 'field=value' pairs where diff --git a/src/alchemlyb/tests/test_convergence.py b/src/alchemlyb/tests/test_convergence.py index f385a7b0..bae2b350 100644 --- a/src/alchemlyb/tests/test_convergence.py +++ b/src/alchemlyb/tests/test_convergence.py @@ -24,9 +24,10 @@ def test_convergence_ti(gmx_benzene): assert convergence.loc[9, 'Forward'] == pytest.approx(3.09, 0.01) assert convergence.loc[9, 'Backward'] == pytest.approx(3.09, 0.01) -def test_convergence_fep(gmx_benzene): +@pytest.mark.parametrize('estimator', ['MBAR', 'BAR']) +def test_convergence_fep(gmx_benzene, estimator): dHdl, u_nk = gmx_benzene - convergence = forward_backward_convergence(u_nk, 'MBAR') + convergence = forward_backward_convergence(u_nk, estimator) assert convergence.shape == (10, 5) assert convergence.loc[0, 'Forward'] == pytest.approx(3.02, 0.01) assert convergence.loc[0, 'Backward'] == pytest.approx(3.06, 0.01) diff --git a/src/alchemlyb/tests/test_preprocessing.py b/src/alchemlyb/tests/test_preprocessing.py index 05048ad8..a124c1aa 100644 --- a/src/alchemlyb/tests/test_preprocessing.py +++ b/src/alchemlyb/tests/test_preprocessing.py @@ -233,7 +233,7 @@ def test_subsampling(self, data, size): """Basic test for execution; resulting size of dataset sensitive to machine and depends on algorithm. """ - assert len(self.slicer(data, series=data.iloc[:, 0])) <= size + assert len(self.slicer(data, series=data.loc[:, data.columns[0]])) <= size @pytest.mark.parametrize('data', [gmx_benzene_dHdl(), gmx_benzene_u_nk()]) @@ -260,7 +260,7 @@ def slicer(self, *args, **kwargs): (False, gmx_benzene_u_nk(), 3571), ]) def test_conservative(self, data, size, conservative): - sliced = self.slicer(data, series=data.iloc[:, 0], conservative=conservative) + sliced = self.slicer(data, series=data.loc[:, data.columns[0]], conservative=conservative) # results can vary slightly with different machines # so possibly do # delta = 10 diff --git a/src/alchemlyb/tests/test_ti_estimators.py b/src/alchemlyb/tests/test_ti_estimators.py index 01b52de4..38b623a8 100644 --- a/src/alchemlyb/tests/test_ti_estimators.py +++ b/src/alchemlyb/tests/test_ti_estimators.py @@ -109,6 +109,8 @@ class TIestimatorMixin: def test_get_delta_f(self, X_delta_f): dHdl, E, dE = X_delta_f est = self.cls().fit(dHdl) + # Use .iloc[0, -1] as we want to cater for both + # delta_f_.loc[0.0, 1.0] and delta_f_.loc[(0.0, 0.0), (0.0, 1.0)] delta_f = est.delta_f_.iloc[0, -1] d_delta_f = est.d_delta_f_.iloc[0, -1] From 6839b8bd17f174342e6b396f7ce3d32957c2a280 Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 4 Dec 2022 10:49:48 +0000 Subject: [PATCH 07/21] update --- src/alchemlyb/tests/test_fep_estimators.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/alchemlyb/tests/test_fep_estimators.py b/src/alchemlyb/tests/test_fep_estimators.py index 52c93e88..f61a1a43 100644 --- a/src/alchemlyb/tests/test_fep_estimators.py +++ b/src/alchemlyb/tests/test_fep_estimators.py @@ -139,6 +139,8 @@ def compare_delta_f(self, X_delta_f): assert X_delta_f[2] == pytest.approx(d_delta_f, rel=1e-3) def get_delta_f(self, est): + # Use .iloc[0, -1] as we want to cater for both + # delta_f_.loc[0.0, 1.0] and delta_f_.loc[(0.0, 0.0), (0.0, 1.0)] return est.delta_f_.iloc[0, -1], est.d_delta_f_.iloc[0, -1] @@ -227,6 +229,8 @@ def get_delta_f(self, est): for i in range(len(est.d_delta_f_) - 1): ee += est.d_delta_f_.values[i][i+1]**2 + # Use .iloc[0, -1] as we want to cater for both + # delta_f_.loc[0.0, 1.0] and delta_f_.loc[(0.0, 0.0), (0.0, 1.0)] return est.delta_f_.iloc[0, -1], ee**0.5 class Test_Units(): From 6d8ec6b6759260539963959a4a5871ab4e8ec857 Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 4 Dec 2022 12:58:32 +0000 Subject: [PATCH 08/21] update --- src/alchemlyb/__init__.py | 37 +- src/alchemlyb/tests/conftest.py | 169 ++++++ src/alchemlyb/tests/test_convergence.py | 97 ++-- src/alchemlyb/tests/test_fep_estimators.py | 266 +++------ src/alchemlyb/tests/test_preprocessing.py | 643 ++++++++++++--------- 5 files changed, 676 insertions(+), 536 deletions(-) create mode 100644 src/alchemlyb/tests/conftest.py diff --git a/src/alchemlyb/__init__.py b/src/alchemlyb/__init__.py index 7cfae4b9..c6d6c9ed 100644 --- a/src/alchemlyb/__init__.py +++ b/src/alchemlyb/__init__.py @@ -1,27 +1,32 @@ -import pandas as pd from functools import wraps +import pandas as pd + from ._version import get_versions -__version__ = get_versions()['version'] + +__version__ = get_versions()["version"] del get_versions + def pass_attrs(func): - '''Pass the attrs from the first positional argument to the output + """Pass the attrs from the first positional argument to the output dataframe. - - + + .. versionadded:: 0.5.0 - ''' + """ @wraps(func) - def wrapper(input_dataframe, *args,**kwargs): - dataframe = func(input_dataframe, *args,**kwargs) + def wrapper(input_dataframe, *args, **kwargs): + dataframe = func(input_dataframe, *args, **kwargs) dataframe.attrs = input_dataframe.attrs return dataframe + return wrapper + def concat(objs, *args, **kwargs): - '''Concatenate pandas objects while persevering the attrs. + """Concatenate pandas objects while persevering the attrs. Concatenate pandas objects along a particular axis with optional set logic along the other axes. If all pandas objects have the same attrs @@ -46,16 +51,18 @@ def concat(objs, *args, **kwargs): See Also -------- pandas.concat - - - .. versionadded:: 0.5.0''' + + + .. versionadded:: 0.5.0""" + if isinstance(objs, pd.DataFrame): + return objs # Sanity check try: attrs = objs[0].attrs - except IndexError: # except empty list as input - raise ValueError('No objects to concatenate') + except IndexError: # except empty list as input + raise ValueError("No objects to concatenate") for obj in objs: if attrs != obj.attrs: - raise ValueError('All pandas objects should have the same attrs.') + raise ValueError("All pandas objects should have the same attrs.") return pd.concat(objs, *args, **kwargs) diff --git a/src/alchemlyb/tests/conftest.py b/src/alchemlyb/tests/conftest.py new file mode 100644 index 00000000..ab3ec2f6 --- /dev/null +++ b/src/alchemlyb/tests/conftest.py @@ -0,0 +1,169 @@ +import pytest +from alchemtest.amber import load_bace_example +from alchemtest.gmx import ( + load_benzene, + load_expanded_ensemble_case_1, + load_expanded_ensemble_case_2, + load_expanded_ensemble_case_3, + load_water_particle_with_total_energy, + load_water_particle_with_potential_energy, + load_water_particle_without_energy, + load_ABFE, +) +from alchemtest.gomc import load_benzene as gomc_load_benzene +from alchemtest.namd import ( + load_tyr2ala, + load_idws, + load_restarted, + load_restarted_reversed, +) + +from alchemlyb.parsing import gmx, amber, gomc, namd + + +@pytest.fixture +def gmx_benzene(): + dataset = load_benzene() + return dataset["data"] + + +@pytest.fixture +def gmx_benzene_Coulomb_dHdl(gmx_benzene): + return [gmx.extract_dHdl(file, T=300) for file in gmx_benzene["Coulomb"]] + + +@pytest.fixture +def gmx_benzene_Coulomb_u_nk(gmx_benzene): + return [gmx.extract_u_nk(file, T=300) for file in gmx_benzene["Coulomb"]] + + +@pytest.fixture +def gmx_benzene_VDW_u_nk(gmx_benzene): + return [gmx.extract_u_nk(file, T=300) for file in gmx_benzene["VDW"]] + + +@pytest.fixture +def gmx_ABFE(): + dataset = load_ABFE() + return dataset["data"] + + +@pytest.fixture +def gmx_ABFE_complex_n_uk(gmx_ABFE): + return [gmx.extract_u_nk(file, T=300) for file in gmx_ABFE["complex"]] + + +@pytest.fixture +def gmx_ABFE_complex_dHdl(gmx_ABFE): + return [gmx.extract_dHdl(file, T=300) for file in gmx_ABFE["complex"]] + + +@pytest.fixture +def gmx_expanded_ensemble_case_1(): + dataset = load_expanded_ensemble_case_1() + + return [ + gmx.extract_u_nk(filename, T=300, filter=False) + for filename in dataset["data"]["AllStates"] + ] + + +@pytest.fixture +def gmx_expanded_ensemble_case_2(): + dataset = load_expanded_ensemble_case_2() + + return [ + gmx.extract_u_nk(filename, T=300, filter=False) + for filename in dataset["data"]["AllStates"] + ] + + +@pytest.fixture +def gmx_expanded_ensemble_case_3(): + dataset = load_expanded_ensemble_case_3() + + return [ + gmx.extract_u_nk(filename, T=300, filter=False) + for filename in dataset["data"]["AllStates"] + ] + + +@pytest.fixture +def gmx_water_particle_with_total_energy(): + dataset = load_water_particle_with_total_energy() + + return [ + gmx.extract_u_nk(filename, T=300) for filename in dataset["data"]["AllStates"] + ] + + +@pytest.fixture +def gmx_water_particle_with_potential_energy(): + dataset = load_water_particle_with_potential_energy() + + return [ + gmx.extract_u_nk(filename, T=300) for filename in dataset["data"]["AllStates"] + ] + + +@pytest.fixture +def gmx_water_particle_without_energy(): + dataset = load_water_particle_without_energy() + + return [ + gmx.extract_u_nk(filename, T=300) for filename in dataset["data"]["AllStates"] + ] + + +@pytest.fixture +def amber_bace_example_complex_vdw(): + dataset = load_bace_example() + + return [ + amber.extract_u_nk(filename, T=298.0) + for filename in dataset["data"]["complex"]["vdw"] + ] + + +@pytest.fixture +def gomc_benzene_u_nk(): + dataset = gomc_load_benzene() + + return [gomc.extract_u_nk(filename, T=298) for filename in dataset["data"]] + + +@pytest.fixture +def namd_tyr2ala(): + dataset = load_tyr2ala() + u_nk1 = namd.extract_u_nk(dataset["data"]["forward"][0], T=300) + u_nk2 = namd.extract_u_nk(dataset["data"]["backward"][0], T=300) + + # combine dataframes of fwd and rev directions + u_nk1[u_nk1.isna()] = u_nk2 + u_nk = u_nk1.sort_index(level=u_nk1.index.names[1:]) + + return u_nk + + +@pytest.fixture +def namd_idws(): + dataset = load_idws() + u_nk = namd.extract_u_nk(dataset["data"]["forward"], T=300) + + return u_nk + + +@pytest.fixture +def namd_idws_restarted(): + dataset = load_restarted() + u_nk = namd.extract_u_nk(dataset["data"]["both"], T=300) + + return u_nk + + +@pytest.fixture +def namd_idws_restarted_reversed(): + dataset = load_restarted_reversed() + u_nk = namd.extract_u_nk(dataset["data"]["both"], T=300) + + return u_nk diff --git a/src/alchemlyb/tests/test_convergence.py b/src/alchemlyb/tests/test_convergence.py index bae2b350..d0ffb2c2 100644 --- a/src/alchemlyb/tests/test_convergence.py +++ b/src/alchemlyb/tests/test_convergence.py @@ -2,92 +2,93 @@ import pandas as pd import pytest -from alchemtest.gmx import load_benzene -from alchemlyb.parsing import gmx -from alchemlyb.convergence import forward_backward_convergence, fwdrev_cumavg_Rc, A_c +from alchemlyb.convergence import forward_backward_convergence, \ + fwdrev_cumavg_Rc, A_c from alchemlyb.convergence.convergence import _cummean -@pytest.fixture() -def gmx_benzene(): - dataset = load_benzene() - return [gmx.extract_dHdl(dhdl, T=300) for dhdl in dataset['data']['Coulomb']], \ - [gmx.extract_u_nk(dhdl, T=300) for dhdl in dataset['data']['Coulomb']] - -def test_convergence_ti(gmx_benzene): - dHdl, u_nk = gmx_benzene - convergence = forward_backward_convergence(dHdl, 'TI') +def test_convergence_ti(gmx_benzene_Coulomb_dHdl): + convergence = forward_backward_convergence(gmx_benzene_Coulomb_dHdl, "TI") assert convergence.shape == (10, 5) - assert convergence.loc[0, 'Forward'] == pytest.approx(3.07, 0.01) - assert convergence.loc[0, 'Backward'] == pytest.approx(3.11, 0.01) - assert convergence.loc[9, 'Forward'] == pytest.approx(3.09, 0.01) - assert convergence.loc[9, 'Backward'] == pytest.approx(3.09, 0.01) + assert convergence.loc[0, "Forward"] == pytest.approx(3.07, 0.01) + assert convergence.loc[0, "Backward"] == pytest.approx(3.11, 0.01) + assert convergence.loc[9, "Forward"] == pytest.approx(3.09, 0.01) + assert convergence.loc[9, "Backward"] == pytest.approx(3.09, 0.01) + -@pytest.mark.parametrize('estimator', ['MBAR', 'BAR']) -def test_convergence_fep(gmx_benzene, estimator): - dHdl, u_nk = gmx_benzene - convergence = forward_backward_convergence(u_nk, estimator) +@pytest.mark.parametrize("estimator", ["MBAR", "BAR"]) +def test_convergence_fep(gmx_benzene_Coulomb_u_nk, estimator): + convergence = forward_backward_convergence(gmx_benzene_Coulomb_u_nk, estimator) assert convergence.shape == (10, 5) - assert convergence.loc[0, 'Forward'] == pytest.approx(3.02, 0.01) - assert convergence.loc[0, 'Backward'] == pytest.approx(3.06, 0.01) - assert convergence.loc[9, 'Forward'] == pytest.approx(3.05, 0.01) - assert convergence.loc[9, 'Backward'] == pytest.approx(3.04, 0.01) + assert convergence.loc[0, "Forward"] == pytest.approx(3.02, 0.01) + assert convergence.loc[0, "Backward"] == pytest.approx(3.06, 0.01) + assert convergence.loc[9, "Forward"] == pytest.approx(3.05, 0.01) + assert convergence.loc[9, "Backward"] == pytest.approx(3.04, 0.01) + -def test_convergence_wrong_estimator(gmx_benzene): - dHdl, u_nk = gmx_benzene +def test_convergence_wrong_estimator(gmx_benzene_Coulomb_dHdl): with pytest.raises(ValueError, match="is not available in"): - forward_backward_convergence(u_nk, 'WWW') + forward_backward_convergence(gmx_benzene_Coulomb_dHdl, "WWW") -def test_convergence_wrong_cases(gmx_benzene): - dHdl, u_nk = gmx_benzene + +def test_convergence_wrong_cases(gmx_benzene_Coulomb_u_nk): with pytest.warns(DeprecationWarning, match="Using lower-case strings for"): - forward_backward_convergence(u_nk, 'mbar') + forward_backward_convergence(gmx_benzene_Coulomb_u_nk, "mbar") + -def test_convergence_method(gmx_benzene): - dHdl, u_nk = gmx_benzene - convergence = forward_backward_convergence(u_nk, 'MBAR', num=2, method='adaptive') +def test_convergence_method(gmx_benzene_Coulomb_u_nk): + convergence = forward_backward_convergence( + gmx_benzene_Coulomb_u_nk, "MBAR", num=2, method="adaptive" + ) assert len(convergence) == 2 + def test_cummean_short(): - '''Test the case where the input is shorter than the expected output''' + """Test the case where the input is shorter than the expected output""" value = _cummean(np.empty(10), 100) assert len(value) == 10 + def test_cummean_long(): - '''Test the case where the input is longer than the expected output''' + """Test the case where the input is longer than the expected output""" value = _cummean(np.empty(20), 10) assert len(value) == 10 + def test_cummean_long_none_integter(): - '''Test the case where the input is not a integer multiple of the expected output''' + """Test the case where the input is not a integer multiple of the expected output""" value = _cummean(np.empty(25), 10) assert len(value) == 10 + def test_R_c_converged(): - data = pd.Series(data=[0,]*100) - data.attrs['temperature'] = 310 - data.attrs['energy_unit'] = 'kcal/mol' + data = pd.Series(data=[0] * 100) + data.attrs["temperature"] = 310 + data.attrs["energy_unit"] = "kcal/mol" value, running_average = fwdrev_cumavg_Rc(data) np.testing.assert_allclose(value, 0.0) + def test_R_c_notconverged(): data = pd.Series(data=range(21)) - data.attrs['temperature'] = 310 - data.attrs['energy_unit'] = 'kcal/mol' + data.attrs["temperature"] = 310 + data.attrs["energy_unit"] = "kcal/mol" value, running_average = fwdrev_cumavg_Rc(data, tol=0.1, precision=0.05) np.testing.assert_allclose(value, 1.0) + def test_R_c_real(): - data = pd.Series(data=np.hstack((range(10), [4.5,]*10))) - data.attrs['temperature'] = 310 - data.attrs['energy_unit'] = 'kcal/mol' + data = pd.Series(data=np.hstack((range(10), [4.5] * 10))) + data.attrs["temperature"] = 310 + data.attrs["energy_unit"] = "kcal/mol" value, running_average = fwdrev_cumavg_Rc(data, tol=2.0) np.testing.assert_allclose(value, 0.35) + def test_A_c_real(): - data = pd.Series(data=np.hstack((range(10), [4.5,]*10))) - data.attrs['temperature'] = 310 - data.attrs['energy_unit'] = 'kcal/mol' - value = A_c([data, ] * 2, tol=2.0) + data = pd.Series(data=np.hstack((range(10), [4.5] * 10))) + data.attrs["temperature"] = 310 + data.attrs["energy_unit"] = "kcal/mol" + value = A_c([data] * 2, tol=2.0) np.testing.assert_allclose(value, 0.65) diff --git a/src/alchemlyb/tests/test_fep_estimators.py b/src/alchemlyb/tests/test_fep_estimators.py index f61a1a43..05a2dd24 100644 --- a/src/alchemlyb/tests/test_fep_estimators.py +++ b/src/alchemlyb/tests/test_fep_estimators.py @@ -1,137 +1,19 @@ """Tests for all FEP-based estimators in ``alchemlyb``. """ -import pytest - import numpy as np -import pandas as pd +import pytest +from alchemtest.generic import load_MBAR_BGFS import alchemlyb -from alchemlyb.parsing import gmx, amber, namd, gomc from alchemlyb.estimators import MBAR, BAR, AutoMBAR -import alchemtest.gmx -import alchemtest.amber -import alchemtest.gomc -import alchemtest.namd -from alchemtest.gmx import load_benzene, load_ABFE -from alchemlyb.parsing.gmx import extract_u_nk -from alchemtest.generic import load_MBAR_BGFS - -def gmx_benzene_coul_u_nk(): - dataset = alchemtest.gmx.load_benzene() - - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300) - for filename in dataset['data']['Coulomb']]) - - return u_nk - -def gmx_benzene_vdw_u_nk(): - dataset = alchemtest.gmx.load_benzene() - - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300) - for filename in dataset['data']['VDW']]) - - return u_nk - -def gmx_expanded_ensemble_case_1(): - dataset = alchemtest.gmx.load_expanded_ensemble_case_1() - - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300, filter=False) - for filename in dataset['data']['AllStates']]) - - return u_nk - -def gmx_expanded_ensemble_case_2(): - dataset = alchemtest.gmx.load_expanded_ensemble_case_2() - - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300, filter=False) - for filename in dataset['data']['AllStates']]) - - return u_nk - -def gmx_expanded_ensemble_case_3(): - dataset = alchemtest.gmx.load_expanded_ensemble_case_3() - - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300, filter=False) - for filename in dataset['data']['AllStates']]) - - return u_nk - -def gmx_water_particle_with_total_energy(): - dataset = alchemtest.gmx.load_water_particle_with_total_energy() - - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300) - for filename in dataset['data']['AllStates']]) - - return u_nk - -def gmx_water_particle_with_potential_energy(): - dataset = alchemtest.gmx.load_water_particle_with_potential_energy() - - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300) - for filename in dataset['data']['AllStates']]) - - return u_nk - -def gmx_water_particle_without_energy(): - dataset = alchemtest.gmx.load_water_particle_without_energy() - - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300) - for filename in dataset['data']['AllStates']]) - - return u_nk - -def amber_bace_example_complex_vdw(): - dataset = alchemtest.amber.load_bace_example() - - u_nk = alchemlyb.concat([amber.extract_u_nk(filename, T=298.0) - for filename in dataset['data']['complex']['vdw']]) - return u_nk - -def gomc_benzene_u_nk(): - dataset = alchemtest.gomc.load_benzene() - - u_nk = alchemlyb.concat([gomc.extract_u_nk(filename, T=298) - for filename in dataset['data']]) - - return u_nk - -def namd_tyr2ala(): - dataset = alchemtest.namd.load_tyr2ala() - u_nk1 = namd.extract_u_nk(dataset['data']['forward'][0], T=300) - u_nk2 = namd.extract_u_nk(dataset['data']['backward'][0], T=300) - - # combine dataframes of fwd and rev directions - u_nk1[u_nk1.isna()] = u_nk2 - u_nk = u_nk1.sort_index(level=u_nk1.index.names[1:]) - - return u_nk - -def namd_idws(): - dataset = alchemtest.namd.load_idws() - u_nk = namd.extract_u_nk(dataset['data']['forward'], T=300) - - return u_nk - -def namd_idws_restarted(): - dataset = alchemtest.namd.load_restarted() - u_nk = namd.extract_u_nk(dataset['data']['both'], T=300) - - return u_nk - -def namd_idws_restarted_reversed(): - dataset = alchemtest.namd.load_restarted_reversed() - u_nk = namd.extract_u_nk(dataset['data']['both'], T=300) - - return u_nk class FEPestimatorMixin: - """Mixin for all FEP Estimator test classes. - - """ + """Mixin for all FEP Estimator test classes.""" def compare_delta_f(self, X_delta_f): + est = self.cls().fit(X_delta_f[0]) delta_f, d_delta_f = self.get_delta_f(est) @@ -145,81 +27,84 @@ def get_delta_f(self, est): class TestMBAR(FEPestimatorMixin): - """Tests for MBAR. + """Tests for MBAR.""" - """ cls = MBAR - @pytest.fixture(scope="class", - params=[(gmx_benzene_coul_u_nk, 3.041, 0.02088), - (gmx_benzene_vdw_u_nk, -3.007, 0.04519), - (gmx_expanded_ensemble_case_1, 75.923, 0.14124), - (gmx_expanded_ensemble_case_2, 75.915, 0.14372), - (gmx_expanded_ensemble_case_3, 76.173, 0.11345), - (gmx_water_particle_with_total_energy, -11.680, 0.083655), - (gmx_water_particle_with_potential_energy, -11.675, 0.083589), - (gmx_water_particle_without_energy, -11.654, 0.083415), - (amber_bace_example_complex_vdw, 2.41149, 0.0620658), - (gomc_benzene_u_nk, -0.79994, 0.091579), - ]) + @pytest.fixture( + params=[ + ("gmx_benzene_Coulomb_u_nk", 3.041, 0.02088), + ("gmx_benzene_VDW_u_nk", -3.007, 0.04519), + ("gmx_expanded_ensemble_case_1", 75.923, 0.14124), + ("gmx_expanded_ensemble_case_2", 75.915, 0.14372), + ("gmx_expanded_ensemble_case_3", 76.173, 0.11345), + ("gmx_water_particle_with_total_energy", -11.680, 0.083655), + ("gmx_water_particle_with_potential_energy", -11.675, 0.083589), + ("gmx_water_particle_without_energy", -11.654, 0.083415), + ("amber_bace_example_complex_vdw", 2.41149, 0.0620658), + ("gomc_benzene_u_nk", -0.79994, 0.091579), + ], + ) def X_delta_f(self, request): get_unk, E, dE = request.param - return get_unk(), E, dE + return alchemlyb.concat(request.getfixturevalue(get_unk)), E, dE def test_mbar(self, X_delta_f): self.compare_delta_f(X_delta_f) + class TestAutoMBAR(TestMBAR): cls = AutoMBAR -class TestMBAR_fail(): - @pytest.fixture(scope="class") - def n_uk_list(self): - n_uk_list = [gmx.extract_u_nk(dhdl, T=300) for dhdl in - load_ABFE()['data']['complex']] - return n_uk_list - def test_failback_adaptive(self, n_uk_list): +class TestMBAR_fail: + def test_failback_adaptive(self, gmx_ABFE_complex_n_uk): # The hybr will fail on this while adaptive will work - mbar = AutoMBAR().fit(alchemlyb.concat([n_uk[:2] for n_uk in - n_uk_list])) - assert np.isclose(mbar.d_delta_f_[(0.0, 0.0, 0.0)][(1.0, 1.0, 1.0)], 1.76832, 0.1) + mbar = AutoMBAR().fit( + alchemlyb.concat([n_uk[:2] for n_uk in gmx_ABFE_complex_n_uk]) + ) + assert np.isclose( + mbar.d_delta_f_[(0.0, 0.0, 0.0)][(1.0, 1.0, 1.0)], 1.76832, 0.1 + ) + def test_AutoMBAR_BGFS(): # A case where only BFGS would work mbar = AutoMBAR() - u_nk = np.load(load_MBAR_BGFS()['data']['u_nk']) - N_k = np.load(load_MBAR_BGFS()['data']['N_k']) - solver_options = {"maximum_iterations": 10000,"verbose": False} + u_nk = np.load(load_MBAR_BGFS()["data"]["u_nk"]) + N_k = np.load(load_MBAR_BGFS()["data"]["N_k"]) + solver_options = {"maximum_iterations": 10000, "verbose": False} solver_protocol = {"method": None, "options": solver_options} mbar, out = mbar._do_MBAR(u_nk.T, N_k, solver_protocol) assert np.isclose(out[0][1][0], 12.552409, 0.1) + class TestBAR(FEPestimatorMixin): - """Tests for BAR. + """Tests for BAR.""" - """ cls = BAR - @pytest.fixture(scope="class", - params = [(gmx_benzene_coul_u_nk, 3.044, 0.01640), - (gmx_benzene_vdw_u_nk, -3.033, 0.03438), - (gmx_expanded_ensemble_case_1, 75.993, 0.11056), - (gmx_expanded_ensemble_case_2, 76.009, 0.11220), - (gmx_expanded_ensemble_case_3, 76.219, 0.08886), - (gmx_water_particle_with_total_energy, -11.675, 0.065055), - (gmx_water_particle_with_potential_energy, -11.724, 0.064964), - (gmx_water_particle_without_energy, -11.660, 0.064914), - (amber_bace_example_complex_vdw, 2.39294, 0.051192), - (namd_tyr2ala, 11.0044, 0.10235), - (namd_idws, 0.221147, 0.041003), - (namd_idws_restarted, 7.081127, 0.0344211), - (namd_idws_restarted_reversed, -4.18405, 0.03457), - (gomc_benzene_u_nk, -0.87095, 0.071263), - ]) + @pytest.fixture( + params=[ + ("gmx_benzene_Coulomb_u_nk", 3.044, 0.01640), + ("gmx_benzene_VDW_u_nk", -3.033, 0.03438), + ("gmx_expanded_ensemble_case_1", 75.993, 0.11056), + ("gmx_expanded_ensemble_case_2", 76.009, 0.11220), + ("gmx_expanded_ensemble_case_3", 76.219, 0.08886), + ("gmx_water_particle_with_total_energy", -11.675, 0.065055), + ("gmx_water_particle_with_potential_energy", -11.724, 0.064964), + ("gmx_water_particle_without_energy", -11.660, 0.064914), + ("amber_bace_example_complex_vdw", 2.39294, 0.051192), + ("namd_tyr2ala", 11.0044, 0.10235), + ("namd_idws", 0.221147, 0.041003), + ("namd_idws_restarted", 7.081127, 0.0344211), + ("namd_idws_restarted_reversed", -4.18405, 0.03457), + ("gomc_benzene_u_nk", -0.87095, 0.071263), + ], + ) def X_delta_f(self, request): get_unk, E, dE = request.param - return get_unk(), E, dE + return alchemlyb.concat(request.getfixturevalue(get_unk)), E, dE def test_bar(self, X_delta_f): self.compare_delta_f(X_delta_f) @@ -228,40 +113,39 @@ def get_delta_f(self, est): ee = 0.0 for i in range(len(est.d_delta_f_) - 1): - ee += est.d_delta_f_.values[i][i+1]**2 + ee += est.d_delta_f_.values[i][i + 1] ** 2 # Use .iloc[0, -1] as we want to cater for both # delta_f_.loc[0.0, 1.0] and delta_f_.loc[(0.0, 0.0), (0.0, 1.0)] return est.delta_f_.iloc[0, -1], ee**0.5 -class Test_Units(): - '''Test the units.''' + +class Test_Units: + """Test the units.""" @staticmethod - @pytest.fixture(scope='class') - def u_nk(): - bz = load_benzene().data - u_nk_coul = alchemlyb.concat( - [extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']]) - u_nk_coul.attrs = extract_u_nk(load_benzene().data['Coulomb'][0], T=300).attrs - return u_nk_coul + @pytest.fixture() + def u_nk(gmx_benzene_Coulomb_u_nk): + return alchemlyb.concat(gmx_benzene_Coulomb_u_nk) def test_bar(self, u_nk): bar = BAR().fit(u_nk) - assert bar.delta_f_.attrs['temperature'] == 300 - assert bar.delta_f_.attrs['energy_unit'] == 'kT' - assert bar.d_delta_f_.attrs['temperature'] == 300 - assert bar.d_delta_f_.attrs['energy_unit'] == 'kT' + assert bar.delta_f_.attrs["temperature"] == 300 + assert bar.delta_f_.attrs["energy_unit"] == "kT" + assert bar.d_delta_f_.attrs["temperature"] == 300 + assert bar.d_delta_f_.attrs["energy_unit"] == "kT" def test_mbar(self, u_nk): mbar = MBAR().fit(u_nk) - assert mbar.delta_f_.attrs['temperature'] == 300 - assert mbar.delta_f_.attrs['energy_unit'] == 'kT' - assert mbar.d_delta_f_.attrs['temperature'] == 300 - assert mbar.d_delta_f_.attrs['energy_unit'] == 'kT' - -class TestEstimatorMixOut(): - '''Ensure that the attribute d_delta_f_, delta_f_, states_ cannot be - modified. ''' + assert mbar.delta_f_.attrs["temperature"] == 300 + assert mbar.delta_f_.attrs["energy_unit"] == "kT" + assert mbar.d_delta_f_.attrs["temperature"] == 300 + assert mbar.d_delta_f_.attrs["energy_unit"] == "kT" + + +class TestEstimatorMixOut: + """Ensure that the attribute d_delta_f_, delta_f_, states_ cannot be + modified.""" + @pytest.mark.parametrize("estimator", [MBAR, BAR]) def test_d_delta_f_(self, estimator): _estimator = estimator() diff --git a/src/alchemlyb/tests/test_preprocessing.py b/src/alchemlyb/tests/test_preprocessing.py index a124c1aa..05036fe4 100644 --- a/src/alchemlyb/tests/test_preprocessing.py +++ b/src/alchemlyb/tests/test_preprocessing.py @@ -1,96 +1,120 @@ """Tests for preprocessing functions. """ -import pytest - import numpy as np +import pytest +from alchemtest.gmx import load_benzene from numpy.testing import assert_allclose import alchemlyb -from alchemlyb.parsing import gmx -from alchemlyb.preprocessing import (slicing, statistical_inefficiency, - equilibrium_detection, - decorrelate_u_nk, decorrelate_dhdl, - u_nk2series, dhdl2series) from alchemlyb.parsing.gmx import extract_u_nk, extract_dHdl -from alchemtest.gmx import load_benzene, load_ABFE - -import alchemtest.gmx +from alchemlyb.preprocessing import ( + slicing, + statistical_inefficiency, + equilibrium_detection, + decorrelate_u_nk, + decorrelate_dhdl, + u_nk2series, + dhdl2series, +) + + +# def gmx_benzene_dHdl(): +# dataset = alchemtest.gmx.load_benzene() +# return gmx.extract_dHdl(dataset['data']['Coulomb'][0], T=300) +# +# # When issue #206 is addressed make the gmx_benzene_dHdl() function the +# # fixture, remove the wrapper below, and replace +# # gmx_benzene_dHdl_fixture --> gmx_benzene_dHdl +# @pytest.fixture() +# def gmx_benzene_dHdl_fixture(): +# return gmx_benzene_dHdl() +# +# @pytest.fixture() +# def gmx_ABFE(): +# dataset = alchemtest.gmx.load_ABFE() +# return gmx.extract_u_nk(dataset['data']['complex'][0], T=300) +# +# @pytest.fixture() +# def gmx_ABFE_dhdl(): +# dataset = alchemtest.gmx.load_ABFE() +# return gmx.extract_dHdl(dataset['data']['complex'][0], T=300) +# +# @pytest.fixture() +# def gmx_ABFE_u_nk(): +# dataset = alchemtest.gmx.load_ABFE() +# return gmx.extract_u_nk(dataset['data']['complex'][-1], T=300) +# +# @pytest.fixture() +# def gmx_benzene_u_nk_fixture(): +# dataset = alchemtest.gmx.load_benzene() +# return gmx.extract_u_nk(dataset['data']['Coulomb'][0], T=300) +# +# def gmx_benzene_u_nk(): +# dataset = alchemtest.gmx.load_benzene() +# return gmx.extract_u_nk(dataset['data']['Coulomb'][0], T=300) +# +# +# def gmx_benzene_dHdl_full(): +# dataset = alchemtest.gmx.load_benzene() +# return alchemlyb.concat([gmx.extract_dHdl(i, T=300) for i in dataset['data']['Coulomb']]) +# +# +# def gmx_benzene_u_nk_full(): +# dataset = alchemtest.gmx.load_benzene() +# return alchemlyb.concat([gmx.extract_u_nk(i, T=300) for i in dataset['data']['Coulomb']]) -def gmx_benzene_dHdl(): - dataset = alchemtest.gmx.load_benzene() - return gmx.extract_dHdl(dataset['data']['Coulomb'][0], T=300) -# When issue #206 is addressed make the gmx_benzene_dHdl() function the -# fixture, remove the wrapper below, and replace -# gmx_benzene_dHdl_fixture --> gmx_benzene_dHdl -@pytest.fixture() -def gmx_benzene_dHdl_fixture(): - return gmx_benzene_dHdl() +def _check_data_is_outside_bounds(data, lower, upper): + """ + Helper function to make sure that `data` has entries that are + below the `lower` bound, and above the `upper` bound. + This is used by slicing tests to make sure that the data + provided is appropriate for the tests. + """ + assert any(data.reset_index()["time"] < lower) + assert any(data.reset_index()["time"] > upper) -@pytest.fixture() -def gmx_ABFE(): - dataset = alchemtest.gmx.load_ABFE() - return gmx.extract_u_nk(dataset['data']['complex'][0], T=300) @pytest.fixture() -def gmx_ABFE_dhdl(): - dataset = alchemtest.gmx.load_ABFE() - return gmx.extract_dHdl(dataset['data']['complex'][0], T=300) +def dHdl(gmx_benzene_Coulomb_dHdl): + return gmx_benzene_Coulomb_dHdl[0] -@pytest.fixture() -def gmx_ABFE_u_nk(): - dataset = alchemtest.gmx.load_ABFE() - return gmx.extract_u_nk(dataset['data']['complex'][-1], T=300) @pytest.fixture() -def gmx_benzene_u_nk_fixture(): - dataset = alchemtest.gmx.load_benzene() - return gmx.extract_u_nk(dataset['data']['Coulomb'][0], T=300) - -def gmx_benzene_u_nk(): - dataset = alchemtest.gmx.load_benzene() - return gmx.extract_u_nk(dataset['data']['Coulomb'][0], T=300) +def u_nk(gmx_benzene_Coulomb_u_nk): + return gmx_benzene_Coulomb_u_nk[0] -def gmx_benzene_dHdl_full(): - dataset = alchemtest.gmx.load_benzene() - return alchemlyb.concat([gmx.extract_dHdl(i, T=300) for i in dataset['data']['Coulomb']]) - - -def gmx_benzene_u_nk_full(): - dataset = alchemtest.gmx.load_benzene() - return alchemlyb.concat([gmx.extract_u_nk(i, T=300) for i in dataset['data']['Coulomb']]) - - -def _check_data_is_outside_bounds(data, lower, upper): - """ - Helper function to make sure that `data` has entries that are - below the `lower` bound, and above the `upper` bound. - This is used by slicing tests to make sure that the data - provided is appropriate for the tests. - """ - assert any(data.reset_index()['time'] < lower) - assert any(data.reset_index()['time'] > upper) +@pytest.fixture() +def multi_index_u_nk(gmx_ABFE_complex_n_uk): + return gmx_ABFE_complex_n_uk[0] class TestSlicing: - """Test slicing functionality. + """Test slicing functionality.""" - """ def slicer(self, *args, **kwargs): return slicing(*args, **kwargs) - @pytest.mark.parametrize(('data', 'size'), [(gmx_benzene_dHdl(), 661), - (gmx_benzene_u_nk(), 661)]) - def test_basic_slicing(self, data, size): - assert len(self.slicer(data, lower=1000, upper=34000, step=5)) == size - - @pytest.mark.parametrize(('dataloader', 'lower', 'upper'), - [ - ('gmx_benzene_dHdl_fixture', 1000, 34000), - ('gmx_benzene_u_nk_fixture', 1000, 34000), - ]) + @pytest.mark.parametrize(("data", "size"), [("dHdl", 661), ("u_nk", 661)]) + def test_basic_slicing(self, data, size, request): + assert ( + len( + self.slicer( + request.getfixturevalue(data)[0], lower=1000, upper=34000, step=5 + ) + ) + == size + ) + + @pytest.mark.parametrize( + ("dataloader", "lower", "upper"), + [ + ("dHdl", 1000, 34000), + ("u_nk", 1000, 34000), + ], + ) def test_data_is_unchanged(self, dataloader, lower, upper, request): """ Test that slicing does not change the underlying data @@ -102,17 +126,16 @@ def test_data_is_unchanged(self, dataloader, lower, upper, request): # Slice data, and test that we didn't change the input data original_length = len(data) - sliced = self.slicer(data, - lower=lower, - upper=upper, - step=5) + sliced = self.slicer(data, lower=lower, upper=upper, step=5) assert len(data) == original_length - @pytest.mark.parametrize(('dataloader', 'lower', 'upper'), - [ - ('gmx_benzene_dHdl_fixture', 1000, 34000), - ('gmx_benzene_u_nk_fixture', 1000, 34000), - ]) + @pytest.mark.parametrize( + ("dataloader", "lower", "upper"), + [ + ("dHdl", 1000, 34000), + ("u_nk", 1000, 34000), + ], + ) def test_lower_and_upper_bound(self, dataloader, lower, upper, request): """ Test that the lower and upper time is respected @@ -124,19 +147,14 @@ def test_lower_and_upper_bound(self, dataloader, lower, upper, request): # Slice data, and test that we don't observe times outside # the prescribed range - sliced = self.slicer(data, - lower=lower, - upper=upper, - step=5) - assert all(sliced.reset_index()['time'] >= lower) - assert all(sliced.reset_index()['time'] <= upper) - - @pytest.mark.parametrize('data', [gmx_benzene_dHdl(), - gmx_benzene_u_nk()]) - def test_disordered_exception(self, data): - """Test that a shuffled DataFrame yields a KeyError. + sliced = self.slicer(data, lower=lower, upper=upper, step=5) + assert all(sliced.reset_index()["time"] >= lower) + assert all(sliced.reset_index()["time"] <= upper) - """ + @pytest.mark.parametrize("dataloader", ["dHdl", "u_nk"]) + def test_disordered_exception(self, dataloader, request): + """Test that a shuffled DataFrame yields a KeyError.""" + data = request.getfixturevalue(dataloader) indices = data.index.values np.random.shuffle(indices) @@ -145,102 +163,88 @@ def test_disordered_exception(self, data): with pytest.raises(KeyError): self.slicer(df, lower=200) - @pytest.mark.parametrize('data', [gmx_benzene_dHdl_full(), - gmx_benzene_u_nk_full()]) - def test_duplicated_exception(self, data): - """Test that a DataFrame with duplicate times yields a KeyError. - - """ + @pytest.mark.parametrize("dataloader", ["dHdl", "u_nk"]) + def test_duplicated_exception(self, dataloader, request): + """Test that a DataFrame with duplicate times yields a KeyError.""" + data = request.getfixturevalue(dataloader) with pytest.raises(KeyError): self.slicer(data.sort_index(axis=0), lower=200) - def test_subsample_bounds_and_step(self, gmx_ABFE): - """Make sure that slicing the series also works - """ - subsample = statistical_inefficiency(gmx_ABFE, - gmx_ABFE.sum(axis=1), - lower=100, - upper=400, - step=2) + def test_subsample_bounds_and_step(self, multi_index_u_nk): + """Make sure that slicing the series also works""" + subsample = statistical_inefficiency( + multi_index_u_nk, multi_index_u_nk.sum(axis=1), lower=100, upper=400, step=2 + ) assert len(subsample) == 76 - def test_multiindex_duplicated(self, gmx_ABFE): - subsample = statistical_inefficiency(gmx_ABFE, - gmx_ABFE.sum(axis=1)) + def test_multiindex_duplicated(self, multi_index_u_nk): + subsample = statistical_inefficiency( + multi_index_u_nk, multi_index_u_nk.sum(axis=1) + ) assert len(subsample) == 501 - def test_sort_off(self, gmx_ABFE): - unsorted = alchemlyb.concat([gmx_ABFE[-500:], gmx_ABFE[:500]]) + def test_sort_off(self, multi_index_u_nk): + unsorted = alchemlyb.concat([multi_index_u_nk[-500:], multi_index_u_nk[:500]]) with pytest.raises(KeyError): - statistical_inefficiency(unsorted, - unsorted.sum(axis=1), - sort=False) - - def test_sort_on(self, gmx_ABFE): - unsorted = alchemlyb.concat([gmx_ABFE[-500:], gmx_ABFE[:500]]) - subsample = statistical_inefficiency(unsorted, - unsorted.sum(axis=1), - sort=True) - assert subsample.reset_index(0)['time'].is_monotonic_increasing - - def test_sort_on_noseries(self, gmx_ABFE): - unsorted = alchemlyb.concat([gmx_ABFE[-500:], gmx_ABFE[:500]]) - subsample = statistical_inefficiency(unsorted, - None, - sort=True) - assert subsample.reset_index(0)['time'].is_monotonic_increasing - - def test_duplication_off(self, gmx_ABFE): - duplicated = alchemlyb.concat([gmx_ABFE, gmx_ABFE]) + statistical_inefficiency(unsorted, unsorted.sum(axis=1), sort=False) + + def test_sort_on(self, multi_index_u_nk): + unsorted = alchemlyb.concat([multi_index_u_nk[-500:], multi_index_u_nk[:500]]) + subsample = statistical_inefficiency(unsorted, unsorted.sum(axis=1), sort=True) + assert subsample.reset_index(0)["time"].is_monotonic_increasing + + def test_sort_on_noseries(self, multi_index_u_nk): + unsorted = alchemlyb.concat([multi_index_u_nk[-500:], multi_index_u_nk[:500]]) + subsample = statistical_inefficiency(unsorted, None, sort=True) + assert subsample.reset_index(0)["time"].is_monotonic_increasing + + def test_duplication_off(self, multi_index_u_nk): + duplicated = alchemlyb.concat([multi_index_u_nk, multi_index_u_nk]) with pytest.raises(KeyError): - statistical_inefficiency(duplicated, - duplicated.sum(axis=1), - drop_duplicates=False) - - def test_duplication_on_dataframe(self, gmx_ABFE): - duplicated = alchemlyb.concat([gmx_ABFE, gmx_ABFE]) - subsample = statistical_inefficiency(duplicated, - duplicated.sum(axis=1), - drop_duplicates=True) + statistical_inefficiency( + duplicated, duplicated.sum(axis=1), drop_duplicates=False + ) + + def test_duplication_on_dataframe(self, multi_index_u_nk): + duplicated = alchemlyb.concat([multi_index_u_nk, multi_index_u_nk]) + subsample = statistical_inefficiency( + duplicated, duplicated.sum(axis=1), drop_duplicates=True + ) assert len(subsample) < 1000 - def test_duplication_on_dataframe_noseries(self, gmx_ABFE): - duplicated = alchemlyb.concat([gmx_ABFE, gmx_ABFE]) - subsample = statistical_inefficiency(duplicated, - None, - drop_duplicates=True) + def test_duplication_on_dataframe_noseries(self, multi_index_u_nk): + duplicated = alchemlyb.concat([multi_index_u_nk, multi_index_u_nk]) + subsample = statistical_inefficiency(duplicated, None, drop_duplicates=True) assert len(subsample) == 1001 - def test_duplication_on_series(self, gmx_ABFE): - duplicated = alchemlyb.concat([gmx_ABFE, gmx_ABFE]) - subsample = statistical_inefficiency(duplicated.sum(axis=1), - duplicated.sum(axis=1), - drop_duplicates=True) + def test_duplication_on_series(self, multi_index_u_nk): + duplicated = alchemlyb.concat([multi_index_u_nk, multi_index_u_nk]) + subsample = statistical_inefficiency( + duplicated.sum(axis=1), duplicated.sum(axis=1), drop_duplicates=True + ) assert len(subsample) < 1000 - def test_duplication_on_series_noseries(self, gmx_ABFE): - duplicated = alchemlyb.concat([gmx_ABFE, gmx_ABFE]) - subsample = statistical_inefficiency(duplicated.sum(axis=1), - None, - drop_duplicates=True) + def test_duplication_on_series_noseries(self, multi_index_u_nk): + duplicated = alchemlyb.concat([multi_index_u_nk, multi_index_u_nk]) + subsample = statistical_inefficiency( + duplicated.sum(axis=1), None, drop_duplicates=True + ) assert len(subsample) == 1001 -class CorrelatedPreprocessors: - @pytest.mark.parametrize(('data', 'size'), [(gmx_benzene_dHdl(), 4001), - (gmx_benzene_u_nk(), 4001)]) - def test_subsampling(self, data, size): +class CorrelatedPreprocessors: + @pytest.mark.parametrize(("dataloader", "size"), [("dHdl", 4001), ("u_nk", 4001)]) + def test_subsampling(self, dataloader, size, request): """Basic test for execution; resulting size of dataset sensitive to machine and depends on algorithm. """ + data = request.getfixturevalue(dataloader) assert len(self.slicer(data, series=data.loc[:, data.columns[0]])) <= size - @pytest.mark.parametrize('data', [gmx_benzene_dHdl(), - gmx_benzene_u_nk()]) - def test_no_series(self, data): - """Check that we get the same result as simple slicing with no Series. - - """ + @pytest.mark.parametrize("dataloader", ["dHdl", "u_nk"]) + def test_no_series(self, dataloader, request): + """Check that we get the same result as simple slicing with no Series.""" + data = request.getfixturevalue(dataloader) df_sub = self.slicer(data, lower=200, upper=5000, step=2) df_sliced = slicing(data, lower=200, upper=5000, step=2) @@ -248,43 +252,52 @@ def test_no_series(self, data): class TestStatisticalInefficiency(TestSlicing, CorrelatedPreprocessors): - def slicer(self, *args, **kwargs): return statistical_inefficiency(*args, **kwargs) - @pytest.mark.parametrize(('conservative', 'data', 'size'), - [ - (True, gmx_benzene_dHdl(), 2001), # 0.00: g = 1.0559445620585415 - (True, gmx_benzene_u_nk(), 2001), # 'fep': g = 1.0560203916559594 - (False, gmx_benzene_dHdl(), 3789), - (False, gmx_benzene_u_nk(), 3571), - ]) - def test_conservative(self, data, size, conservative): - sliced = self.slicer(data, series=data.loc[:, data.columns[0]], conservative=conservative) + @pytest.mark.parametrize( + ("conservative", "dataloader", "size"), + [ + (True, "dHdl", 2001), # 0.00: g = 1.0559445620585415 + (True, "u_nk", 2001), # 'fep': g = 1.0560203916559594 + (False, "dHdl", 3789), + (False, "u_nk", 3571), + ], + ) + def test_conservative(self, dataloader, size, conservative, request): + data = request.getfixturevalue(dataloader) + sliced = self.slicer( + data, series=data.loc[:, data.columns[0]], conservative=conservative + ) # results can vary slightly with different machines # so possibly do # delta = 10 # assert size - delta < len(sliced) < size + delta assert len(sliced) == size - @pytest.mark.parametrize('series', [ - gmx_benzene_dHdl()['fep'][:20], # wrong length - gmx_benzene_dHdl()['fep'][::-1], # wrong time stamps (reversed) - ]) + @pytest.mark.parametrize( + "dataloader,end,step", + [ + gmx_benzene_dHdl()["fep"][:20], # wrong length + gmx_benzene_dHdl()["fep"][::-1], # wrong time stamps (reversed) + ], + ) def test_raise_ValueError_for_mismatched_data(self, series): data = gmx_benzene_dHdl() with pytest.raises(ValueError): self.slicer(data, series=series) - @pytest.mark.parametrize(('dataloader', 'lower', 'upper'), - [ - ('gmx_benzene_dHdl_fixture', 1000, 34000), - ('gmx_benzene_u_nk_fixture', 1000, 34000), - ]) - @pytest.mark.parametrize('use_series', [True, False]) - @pytest.mark.parametrize('conservative', [True, False]) + @pytest.mark.parametrize( + ("dataloader", "lower", "upper"), + [ + ("gmx_benzene_dHdl_fixture", 1000, 34000), + ("gmx_benzene_u_nk_fixture", 1000, 34000), + ], + ) + @pytest.mark.parametrize("use_series", [True, False]) + @pytest.mark.parametrize("conservative", [True, False]) def test_data_is_unchanged( - self, dataloader, use_series, lower, upper, conservative, request + self, dataloader, use_series, lower, upper, conservative, request ): """ Test that using statistical_inefficiency does not change the underlying data @@ -304,23 +317,27 @@ def test_data_is_unchanged( # Slice data, and test that we didn't change the input data original_length = len(data) - self.slicer(data, - series=series, - lower=lower, - upper=upper, - step=5, - conservative=conservative) + self.slicer( + data, + series=series, + lower=lower, + upper=upper, + step=5, + conservative=conservative, + ) assert len(data) == original_length - @pytest.mark.parametrize(('dataloader', 'lower', 'upper'), - [ - ('gmx_benzene_dHdl_fixture', 1000, 34000), - ('gmx_benzene_u_nk_fixture', 1000, 34000), - ]) - @pytest.mark.parametrize('use_series', [True, False]) - @pytest.mark.parametrize('conservative', [True, False]) + @pytest.mark.parametrize( + ("dataloader", "lower", "upper"), + [ + ("gmx_benzene_dHdl_fixture", 1000, 34000), + ("gmx_benzene_u_nk_fixture", 1000, 34000), + ], + ) + @pytest.mark.parametrize("use_series", [True, False]) + @pytest.mark.parametrize("conservative", [True, False]) def test_lower_and_upper_bound_slicer( - self, dataloader, use_series, lower, upper, conservative, request + self, dataloader, use_series, lower, upper, conservative, request ): """ Test that the lower and upper time is respected when using statistical_inefficiency @@ -340,23 +357,27 @@ def test_lower_and_upper_bound_slicer( # Slice data, and test that we don't observe times outside # the prescribed range - sliced = self.slicer(data, - series=series, - lower=lower, - upper=upper, - step=5, - conservative=conservative) - assert all(sliced.reset_index()['time'] >= lower) - assert all(sliced.reset_index()['time'] <= upper) - - @pytest.mark.parametrize(('dataloader', 'lower', 'upper'), - [ - ('gmx_benzene_dHdl_fixture', 1000, 34000), - ('gmx_benzene_u_nk_fixture', 1000, 34000), - ]) - @pytest.mark.parametrize('conservative', [True, False]) + sliced = self.slicer( + data, + series=series, + lower=lower, + upper=upper, + step=5, + conservative=conservative, + ) + assert all(sliced.reset_index()["time"] >= lower) + assert all(sliced.reset_index()["time"] <= upper) + + @pytest.mark.parametrize( + ("dataloader", "lower", "upper"), + [ + ("gmx_benzene_dHdl_fixture", 1000, 34000), + ("gmx_benzene_u_nk_fixture", 1000, 34000), + ], + ) + @pytest.mark.parametrize("conservative", [True, False]) def test_slicing_inefficiency_equivalence( - self, dataloader, lower, upper, conservative, request + self, dataloader, lower, upper, conservative, request ): """ Test that first slicing the data frame, then subsampling is equivalent to @@ -369,143 +390,201 @@ def test_slicing_inefficiency_equivalence( # Slice dataframe, then subsample it based on the sum of its components sliced_data = slicing(data, lower=lower, upper=upper) - subsampled_sliced_data = self.slicer(sliced_data, - series=sliced_data.sum(axis=1), - conservative=conservative) + subsampled_sliced_data = self.slicer( + sliced_data, series=sliced_data.sum(axis=1), conservative=conservative + ) # Subsample the dataframe based on the sum of its components while # also specifying the slicing range - subsampled_data = self.slicer(data, - series=data.sum(axis=1), - lower=lower, - upper=upper, - conservative=conservative) + subsampled_data = self.slicer( + data, + series=data.sum(axis=1), + lower=lower, + upper=upper, + conservative=conservative, + ) assert (subsampled_sliced_data == subsampled_data).all(axis=None) class TestEquilibriumDetection(TestSlicing, CorrelatedPreprocessors): - def slicer(self, *args, **kwargs): return equilibrium_detection(*args, **kwargs) -class Test_Units(): - '''Test the preprocessing module.''' + +class Test_Units: + """Test the preprocessing module.""" + @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def dhdl(): dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) return dhdl def test_slicing(self, dhdl): - '''Test if extract_u_nk assign the attr correctly''' + """Test if extract_u_nk assign the attr correctly""" dataset = load_benzene() - u_nk = extract_u_nk(dataset['data']['Coulomb'][0], 310) + u_nk = extract_u_nk(dataset["data"]["Coulomb"][0], 310) new_u_nk = slicing(u_nk) - assert new_u_nk.attrs['temperature'] == 310 - assert new_u_nk.attrs['energy_unit'] == 'kT' + assert new_u_nk.attrs["temperature"] == 310 + assert new_u_nk.attrs["energy_unit"] == "kT" def test_statistical_inefficiency(self, dhdl): - '''Test if extract_u_nk assign the attr correctly''' + """Test if extract_u_nk assign the attr correctly""" dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) new_dhdl = statistical_inefficiency(dhdl) - assert new_dhdl.attrs['temperature'] == 310 - assert new_dhdl.attrs['energy_unit'] == 'kT' + assert new_dhdl.attrs["temperature"] == 310 + assert new_dhdl.attrs["energy_unit"] == "kT" def test_equilibrium_detection(self, dhdl): - '''Test if extract_u_nk assign the attr correctly''' + """Test if extract_u_nk assign the attr correctly""" dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) new_dhdl = equilibrium_detection(dhdl) - assert new_dhdl.attrs['temperature'] == 310 - assert new_dhdl.attrs['energy_unit'] == 'kT' + assert new_dhdl.attrs["temperature"] == 310 + assert new_dhdl.attrs["energy_unit"] == "kT" -@pytest.mark.parametrize(('method', 'size'), [('all', 2001), - ('dE', 2001)]) + +@pytest.mark.parametrize(("method", "size"), [("all", 2001), ("dE", 2001)]) def test_decorrelate_u_nk_single_l(gmx_benzene_u_nk_fixture, method, size): - assert len(decorrelate_u_nk(gmx_benzene_u_nk_fixture, method=method, - drop_duplicates=True, - sort=True)) == size + assert ( + len( + decorrelate_u_nk( + gmx_benzene_u_nk_fixture, method=method, drop_duplicates=True, sort=True + ) + ) + == size + ) + def test_decorrelate_u_nk_burnin(gmx_benzene_u_nk_fixture): - assert len(decorrelate_u_nk(gmx_benzene_u_nk_fixture, method='dE', - drop_duplicates=True, - sort=True, remove_burnin=True)) == 2849 + assert ( + len( + decorrelate_u_nk( + gmx_benzene_u_nk_fixture, + method="dE", + drop_duplicates=True, + sort=True, + remove_burnin=True, + ) + ) + == 2849 + ) -def test_decorrelate_dhdl_burnin(gmx_benzene_dHdl_fixture): - assert len(decorrelate_dhdl(gmx_benzene_dHdl_fixture, - drop_duplicates=True, - sort=True, remove_burnin=True)) == 2848 -@pytest.mark.parametrize(('method', 'size'), [('all', 1001), - ('dE', 334)]) +def test_decorrelate_dhdl_burnin(gmx_benzene_dHdl_fixture): + assert ( + len( + decorrelate_dhdl( + gmx_benzene_dHdl_fixture, + drop_duplicates=True, + sort=True, + remove_burnin=True, + ) + ) + == 2848 + ) + + +@pytest.mark.parametrize(("method", "size"), [("all", 1001), ("dE", 334)]) def test_decorrelate_u_nk_multiple_l(gmx_ABFE_u_nk, method, size): - assert len(decorrelate_u_nk(gmx_ABFE_u_nk, method=method,)) == size + assert ( + len( + decorrelate_u_nk( + gmx_ABFE_u_nk, + method=method, + ) + ) + == size + ) + def test_decorrelate_dhdl_single_l(gmx_benzene_u_nk_fixture): - assert len(decorrelate_dhdl(gmx_benzene_u_nk_fixture, drop_duplicates=True, - sort=True)) == 2001 + assert ( + len(decorrelate_dhdl(gmx_benzene_u_nk_fixture, drop_duplicates=True, sort=True)) + == 2001 + ) + def test_decorrelate_dhdl_multiple_l(gmx_ABFE_dhdl): - assert len(decorrelate_dhdl(gmx_ABFE_dhdl,)) == 501 + assert ( + len( + decorrelate_dhdl( + gmx_ABFE_dhdl, + ) + ) + == 501 + ) + def test_raise_non_uk(gmx_ABFE_dhdl): with pytest.raises(ValueError): - decorrelate_u_nk(gmx_ABFE_dhdl, ) + decorrelate_u_nk( + gmx_ABFE_dhdl, + ) -class TestDhdl2series(): + +class TestDhdl2series: @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def dhdl(): dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 300) + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 300) return dhdl - @pytest.mark.parametrize("methodargs", [{}, {'method': 'all'}]) + @pytest.mark.parametrize("methodargs", [{}, {"method": "all"}]) def test_dhdl2series(self, dhdl, methodargs): series = dhdl2series(dhdl, **methodargs) assert len(series) == len(dhdl) assert_allclose(series, dhdl.sum(axis=1)) def test_other_method_ValueError(self, dhdl): - with pytest.raises(ValueError, - match="Only method='all' is supported for dhdl2series()."): + with pytest.raises( + ValueError, match="Only method='all' is supported for dhdl2series()." + ): dhdl2series(dhdl, method="dE") -class TestU_nk2series(): + +class TestU_nk2series: @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def u_nk(): dataset = load_benzene() - u_nk = extract_u_nk(dataset['data']['Coulomb'][0], 300) + u_nk = extract_u_nk(dataset["data"]["Coulomb"][0], 300) return u_nk - @pytest.mark.parametrize("methodargs,reference", # reference = sum - [({}, 9207.80229000283), - ({'method': 'all'}, 85982.34668751864), - ({'method': 'dE'}, 9207.80229000283), - ]) + @pytest.mark.parametrize( + "methodargs,reference", # reference = sum + [ + ({}, 9207.80229000283), + ({"method": "all"}, 85982.34668751864), + ({"method": "dE"}, 9207.80229000283), + ], + ) def test_u_nk2series(self, u_nk, methodargs, reference): series = u_nk2series(u_nk, **methodargs) assert len(series) == len(u_nk) assert_allclose(series.sum(), reference) - @pytest.mark.parametrize("methodargs,reference", # reference = sum - [({'method': 'dhdl_all'}, 85982.34668751864), - ({'method': 'dhdl'}, 9207.80229000283), - ]) + @pytest.mark.parametrize( + "methodargs,reference", # reference = sum + [ + ({"method": "dhdl_all"}, 85982.34668751864), + ({"method": "dhdl"}, 9207.80229000283), + ], + ) def test_u_nk2series_deprecated(self, u_nk, methodargs, reference): - with pytest.warns(DeprecationWarning, - match=r"Method 'dhdl.*' has been deprecated, using '.*' instead\. " - r"'dhdl.*' will be removed in alchemlyb 3\.0\.0\."): + with pytest.warns( + DeprecationWarning, + match=r"Method 'dhdl.*' has been deprecated, using '.*' instead\. " + r"'dhdl.*' will be removed in alchemlyb 3\.0\.0\.", + ): series = u_nk2series(u_nk, **methodargs) assert len(series) == len(u_nk) assert_allclose(series.sum(), reference) - def test_other_method_ValueError(self, u_nk): - with pytest.raises(ValueError, - match='Decorrelation method bogus not found.'): + with pytest.raises(ValueError, match="Decorrelation method bogus not found."): u_nk2series(u_nk, method="bogus") From 2a2d11f32a6473257a393e25f025367b8e575a66 Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 4 Dec 2022 18:27:49 +0000 Subject: [PATCH 09/21] update --- src/alchemlyb/tests/conftest.py | 100 +++++++++- src/alchemlyb/tests/test_preprocessing.py | 176 +++++------------ src/alchemlyb/tests/test_ti_estimators.py | 222 ++++++++++------------ src/alchemlyb/tests/test_units.py | 156 ++++++++------- 4 files changed, 331 insertions(+), 323 deletions(-) diff --git a/src/alchemlyb/tests/conftest.py b/src/alchemlyb/tests/conftest.py index ab3ec2f6..5e4c8a1c 100644 --- a/src/alchemlyb/tests/conftest.py +++ b/src/alchemlyb/tests/conftest.py @@ -1,5 +1,5 @@ import pytest -from alchemtest.amber import load_bace_example +from alchemtest.amber import load_bace_example, load_simplesolvated from alchemtest.gmx import ( load_benzene, load_expanded_ensemble_case_1, @@ -32,6 +32,11 @@ def gmx_benzene_Coulomb_dHdl(gmx_benzene): return [gmx.extract_dHdl(file, T=300) for file in gmx_benzene["Coulomb"]] +@pytest.fixture +def gmx_benzene_VDW_dHdl(gmx_benzene): + return [gmx.extract_dHdl(file, T=300) for file in gmx_benzene["VDW"]] + + @pytest.fixture def gmx_benzene_Coulomb_u_nk(gmx_benzene): return [gmx.extract_u_nk(file, T=300) for file in gmx_benzene["Coulomb"]] @@ -68,6 +73,16 @@ def gmx_expanded_ensemble_case_1(): ] +@pytest.fixture +def gmx_expanded_ensemble_case_1_dHdl(): + dataset = load_expanded_ensemble_case_1() + + return [ + gmx.extract_dHdl(filename, T=300, filter=False) + for filename in dataset["data"]["AllStates"] + ] + + @pytest.fixture def gmx_expanded_ensemble_case_2(): dataset = load_expanded_ensemble_case_2() @@ -78,6 +93,16 @@ def gmx_expanded_ensemble_case_2(): ] +@pytest.fixture +def gmx_expanded_ensemble_case_2_dHdl(): + dataset = load_expanded_ensemble_case_2() + + return [ + gmx.extract_dHdl(filename, T=300, filter=False) + for filename in dataset["data"]["AllStates"] + ] + + @pytest.fixture def gmx_expanded_ensemble_case_3(): dataset = load_expanded_ensemble_case_3() @@ -88,6 +113,16 @@ def gmx_expanded_ensemble_case_3(): ] +@pytest.fixture +def gmx_expanded_ensemble_case_3_dHdl(): + dataset = load_expanded_ensemble_case_3() + + return [ + gmx.extract_dHdl(filename, T=300, filter=False) + for filename in dataset["data"]["AllStates"] + ] + + @pytest.fixture def gmx_water_particle_with_total_energy(): dataset = load_water_particle_with_total_energy() @@ -97,6 +132,15 @@ def gmx_water_particle_with_total_energy(): ] +@pytest.fixture +def gmx_water_particle_with_total_energy_dHdl(): + dataset = load_water_particle_with_total_energy() + + return [ + gmx.extract_dHdl(filename, T=300) for filename in dataset["data"]["AllStates"] + ] + + @pytest.fixture def gmx_water_particle_with_potential_energy(): dataset = load_water_particle_with_potential_energy() @@ -106,6 +150,15 @@ def gmx_water_particle_with_potential_energy(): ] +@pytest.fixture +def gmx_water_particle_with_potential_energy_dHdl(): + dataset = load_water_particle_with_potential_energy() + + return [ + gmx.extract_dHdl(filename, T=300) for filename in dataset["data"]["AllStates"] + ] + + @pytest.fixture def gmx_water_particle_without_energy(): dataset = load_water_particle_without_energy() @@ -115,6 +168,38 @@ def gmx_water_particle_without_energy(): ] +@pytest.fixture +def gmx_water_particle_without_energy_dHdl(): + dataset = load_water_particle_without_energy() + + return [ + gmx.extract_dHdl(filename, T=300) for filename in dataset["data"]["AllStates"] + ] + + +@pytest.fixture +def amber_simplesolvated(): + dataset = load_simplesolvated() + return dataset["data"] + + +@pytest.fixture +def amber_simplesolvated_charge_dHdl(amber_simplesolvated): + return [ + amber.extract_dHdl(filename, T=298.0) + for filename in amber_simplesolvated["charge"] + ] + + +@pytest.fixture +def amber_simplesolvated_vdw_dHdl(amber_simplesolvated): + + return [ + amber.extract_dHdl(filename, T=298.0) + for filename in amber_simplesolvated["vdw"] + ] + + @pytest.fixture def amber_bace_example_complex_vdw(): dataset = load_bace_example() @@ -126,10 +211,19 @@ def amber_bace_example_complex_vdw(): @pytest.fixture -def gomc_benzene_u_nk(): +def gomc_benzene(): dataset = gomc_load_benzene() + return dataset["data"] + - return [gomc.extract_u_nk(filename, T=298) for filename in dataset["data"]] +@pytest.fixture +def gomc_benzene_u_nk(gomc_benzene): + return [gomc.extract_u_nk(filename, T=298) for filename in gomc_benzene] + + +@pytest.fixture +def gomc_benzene_dHdl(gomc_benzene): + return [gomc.extract_dHdl(filename, T=298) for filename in gomc_benzene] @pytest.fixture diff --git a/src/alchemlyb/tests/test_preprocessing.py b/src/alchemlyb/tests/test_preprocessing.py index 05036fe4..66ca634a 100644 --- a/src/alchemlyb/tests/test_preprocessing.py +++ b/src/alchemlyb/tests/test_preprocessing.py @@ -3,11 +3,9 @@ """ import numpy as np import pytest -from alchemtest.gmx import load_benzene from numpy.testing import assert_allclose import alchemlyb -from alchemlyb.parsing.gmx import extract_u_nk, extract_dHdl from alchemlyb.preprocessing import ( slicing, statistical_inefficiency, @@ -19,52 +17,6 @@ ) -# def gmx_benzene_dHdl(): -# dataset = alchemtest.gmx.load_benzene() -# return gmx.extract_dHdl(dataset['data']['Coulomb'][0], T=300) -# -# # When issue #206 is addressed make the gmx_benzene_dHdl() function the -# # fixture, remove the wrapper below, and replace -# # gmx_benzene_dHdl_fixture --> gmx_benzene_dHdl -# @pytest.fixture() -# def gmx_benzene_dHdl_fixture(): -# return gmx_benzene_dHdl() -# -# @pytest.fixture() -# def gmx_ABFE(): -# dataset = alchemtest.gmx.load_ABFE() -# return gmx.extract_u_nk(dataset['data']['complex'][0], T=300) -# -# @pytest.fixture() -# def gmx_ABFE_dhdl(): -# dataset = alchemtest.gmx.load_ABFE() -# return gmx.extract_dHdl(dataset['data']['complex'][0], T=300) -# -# @pytest.fixture() -# def gmx_ABFE_u_nk(): -# dataset = alchemtest.gmx.load_ABFE() -# return gmx.extract_u_nk(dataset['data']['complex'][-1], T=300) -# -# @pytest.fixture() -# def gmx_benzene_u_nk_fixture(): -# dataset = alchemtest.gmx.load_benzene() -# return gmx.extract_u_nk(dataset['data']['Coulomb'][0], T=300) -# -# def gmx_benzene_u_nk(): -# dataset = alchemtest.gmx.load_benzene() -# return gmx.extract_u_nk(dataset['data']['Coulomb'][0], T=300) -# -# -# def gmx_benzene_dHdl_full(): -# dataset = alchemtest.gmx.load_benzene() -# return alchemlyb.concat([gmx.extract_dHdl(i, T=300) for i in dataset['data']['Coulomb']]) -# -# -# def gmx_benzene_u_nk_full(): -# dataset = alchemtest.gmx.load_benzene() -# return alchemlyb.concat([gmx.extract_u_nk(i, T=300) for i in dataset['data']['Coulomb']]) - - def _check_data_is_outside_bounds(data, lower, upper): """ Helper function to make sure that `data` has entries that are @@ -91,6 +43,11 @@ def multi_index_u_nk(gmx_ABFE_complex_n_uk): return gmx_ABFE_complex_n_uk[0] +@pytest.fixture() +def multi_index_dHdl(gmx_ABFE_complex_dHdl): + return gmx_ABFE_complex_dHdl[0] + + class TestSlicing: """Test slicing functionality.""" @@ -102,7 +59,7 @@ def test_basic_slicing(self, data, size, request): assert ( len( self.slicer( - request.getfixturevalue(data)[0], lower=1000, upper=34000, step=5 + request.getfixturevalue(data), lower=1000, upper=34000, step=5 ) ) == size @@ -163,10 +120,12 @@ def test_disordered_exception(self, dataloader, request): with pytest.raises(KeyError): self.slicer(df, lower=200) - @pytest.mark.parametrize("dataloader", ["dHdl", "u_nk"]) + @pytest.mark.parametrize( + "dataloader", ["gmx_benzene_Coulomb_dHdl", "gmx_benzene_Coulomb_u_nk"] + ) def test_duplicated_exception(self, dataloader, request): """Test that a DataFrame with duplicate times yields a KeyError.""" - data = request.getfixturevalue(dataloader) + data = alchemlyb.concat(request.getfixturevalue(dataloader)) with pytest.raises(KeyError): self.slicer(data.sort_index(axis=0), lower=200) @@ -278,20 +237,21 @@ def test_conservative(self, dataloader, size, conservative, request): @pytest.mark.parametrize( "dataloader,end,step", [ - gmx_benzene_dHdl()["fep"][:20], # wrong length - gmx_benzene_dHdl()["fep"][::-1], # wrong time stamps (reversed) + ("dHdl", 20, None), # wrong length + ("dHdl", None, -1), # wrong time stamps (reversed) ], ) - def test_raise_ValueError_for_mismatched_data(self, series): - data = gmx_benzene_dHdl() + def test_raise_ValueError_for_mismatched_data(self, dataloader, end, step, request): + + data = request.getfixturevalue(dataloader) with pytest.raises(ValueError): - self.slicer(data, series=series) + self.slicer(data, series=data[:end:step]) @pytest.mark.parametrize( ("dataloader", "lower", "upper"), [ - ("gmx_benzene_dHdl_fixture", 1000, 34000), - ("gmx_benzene_u_nk_fixture", 1000, 34000), + ("dHdl", 1000, 34000), + ("u_nk", 1000, 34000), ], ) @pytest.mark.parametrize("use_series", [True, False]) @@ -330,8 +290,8 @@ def test_data_is_unchanged( @pytest.mark.parametrize( ("dataloader", "lower", "upper"), [ - ("gmx_benzene_dHdl_fixture", 1000, 34000), - ("gmx_benzene_u_nk_fixture", 1000, 34000), + ("dHdl", 1000, 34000), + ("u_nk", 1000, 34000), ], ) @pytest.mark.parametrize("use_series", [True, False]) @@ -371,8 +331,8 @@ def test_lower_and_upper_bound_slicer( @pytest.mark.parametrize( ("dataloader", "lower", "upper"), [ - ("gmx_benzene_dHdl_fixture", 1000, 34000), - ("gmx_benzene_u_nk_fixture", 1000, 34000), + ("dHdl", 1000, 34000), + ("u_nk", 1000, 34000), ], ) @pytest.mark.parametrize("conservative", [True, False]) @@ -415,55 +375,38 @@ def slicer(self, *args, **kwargs): class Test_Units: """Test the preprocessing module.""" - @staticmethod - @pytest.fixture(scope="class") - def dhdl(): - dataset = load_benzene() - dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) - return dhdl - - def test_slicing(self, dhdl): + def test_slicing(self, u_nk): """Test if extract_u_nk assign the attr correctly""" - dataset = load_benzene() - u_nk = extract_u_nk(dataset["data"]["Coulomb"][0], 310) new_u_nk = slicing(u_nk) - assert new_u_nk.attrs["temperature"] == 310 + assert new_u_nk.attrs["temperature"] == 300 assert new_u_nk.attrs["energy_unit"] == "kT" - def test_statistical_inefficiency(self, dhdl): + def test_statistical_inefficiency(self, dHdl): """Test if extract_u_nk assign the attr correctly""" - dataset = load_benzene() - dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) - new_dhdl = statistical_inefficiency(dhdl) - assert new_dhdl.attrs["temperature"] == 310 + new_dhdl = statistical_inefficiency(dHdl) + assert new_dhdl.attrs["temperature"] == 300 assert new_dhdl.attrs["energy_unit"] == "kT" - def test_equilibrium_detection(self, dhdl): + def test_equilibrium_detection(self, dHdl): """Test if extract_u_nk assign the attr correctly""" - dataset = load_benzene() - dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) - new_dhdl = equilibrium_detection(dhdl) - assert new_dhdl.attrs["temperature"] == 310 + new_dhdl = equilibrium_detection(dHdl) + assert new_dhdl.attrs["temperature"] == 300 assert new_dhdl.attrs["energy_unit"] == "kT" @pytest.mark.parametrize(("method", "size"), [("all", 2001), ("dE", 2001)]) -def test_decorrelate_u_nk_single_l(gmx_benzene_u_nk_fixture, method, size): +def test_decorrelate_u_nk_single_l(u_nk, method, size): assert ( - len( - decorrelate_u_nk( - gmx_benzene_u_nk_fixture, method=method, drop_duplicates=True, sort=True - ) - ) + len(decorrelate_u_nk(u_nk, method=method, drop_duplicates=True, sort=True)) == size ) -def test_decorrelate_u_nk_burnin(gmx_benzene_u_nk_fixture): +def test_decorrelate_u_nk_burnin(u_nk): assert ( len( decorrelate_u_nk( - gmx_benzene_u_nk_fixture, + u_nk, method="dE", drop_duplicates=True, sort=True, @@ -474,11 +417,11 @@ def test_decorrelate_u_nk_burnin(gmx_benzene_u_nk_fixture): ) -def test_decorrelate_dhdl_burnin(gmx_benzene_dHdl_fixture): +def test_decorrelate_dhdl_burnin(dHdl): assert ( len( decorrelate_dhdl( - gmx_benzene_dHdl_fixture, + dHdl, drop_duplicates=True, sort=True, remove_burnin=True, @@ -488,12 +431,12 @@ def test_decorrelate_dhdl_burnin(gmx_benzene_dHdl_fixture): ) -@pytest.mark.parametrize(("method", "size"), [("all", 1001), ("dE", 334)]) -def test_decorrelate_u_nk_multiple_l(gmx_ABFE_u_nk, method, size): +@pytest.mark.parametrize(("method", "size"), [("all", 501), ("dE", 501)]) +def test_decorrelate_u_nk_multiple_l(multi_index_u_nk, method, size): assert ( len( decorrelate_u_nk( - gmx_ABFE_u_nk, + multi_index_u_nk, method=method, ) ) @@ -501,60 +444,43 @@ def test_decorrelate_u_nk_multiple_l(gmx_ABFE_u_nk, method, size): ) -def test_decorrelate_dhdl_single_l(gmx_benzene_u_nk_fixture): - assert ( - len(decorrelate_dhdl(gmx_benzene_u_nk_fixture, drop_duplicates=True, sort=True)) - == 2001 - ) +def test_decorrelate_dhdl_single_l(u_nk): + assert len(decorrelate_dhdl(u_nk, drop_duplicates=True, sort=True)) == 2001 -def test_decorrelate_dhdl_multiple_l(gmx_ABFE_dhdl): +def test_decorrelate_dhdl_multiple_l(multi_index_dHdl): assert ( len( decorrelate_dhdl( - gmx_ABFE_dhdl, + multi_index_dHdl, ) ) == 501 ) -def test_raise_non_uk(gmx_ABFE_dhdl): +def test_raise_non_uk(multi_index_dHdl): with pytest.raises(ValueError): decorrelate_u_nk( - gmx_ABFE_dhdl, + multi_index_dHdl, ) class TestDhdl2series: - @staticmethod - @pytest.fixture(scope="class") - def dhdl(): - dataset = load_benzene() - dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 300) - return dhdl - @pytest.mark.parametrize("methodargs", [{}, {"method": "all"}]) - def test_dhdl2series(self, dhdl, methodargs): - series = dhdl2series(dhdl, **methodargs) - assert len(series) == len(dhdl) - assert_allclose(series, dhdl.sum(axis=1)) + def test_dhdl2series(self, dHdl, methodargs): + series = dhdl2series(dHdl, **methodargs) + assert len(series) == len(dHdl) + assert_allclose(series, dHdl.sum(axis=1)) - def test_other_method_ValueError(self, dhdl): + def test_other_method_ValueError(self, dHdl): with pytest.raises( ValueError, match="Only method='all' is supported for dhdl2series()." ): - dhdl2series(dhdl, method="dE") + dhdl2series(dHdl, method="dE") class TestU_nk2series: - @staticmethod - @pytest.fixture(scope="class") - def u_nk(): - dataset = load_benzene() - u_nk = extract_u_nk(dataset["data"]["Coulomb"][0], 300) - return u_nk - @pytest.mark.parametrize( "methodargs,reference", # reference = sum [ diff --git a/src/alchemlyb/tests/test_ti_estimators.py b/src/alchemlyb/tests/test_ti_estimators.py index 38b623a8..5c644547 100644 --- a/src/alchemlyb/tests/test_ti_estimators.py +++ b/src/alchemlyb/tests/test_ti_estimators.py @@ -1,111 +1,83 @@ """Tests for all TI-based estimators in ``alchemlyb``. """ -import pytest - import pandas as pd +import pytest import alchemlyb -from alchemlyb.parsing import gmx, amber, gomc from alchemlyb.estimators import TI -import alchemtest.gmx -import alchemtest.amber -import alchemtest.gomc -from alchemtest.gmx import load_benzene, load_ABFE -from alchemlyb.parsing.gmx import extract_dHdl - +from alchemlyb.parsing import amber -def gmx_benzene_coul_dHdl(): - dataset = alchemtest.gmx.load_benzene() - - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300) - for filename in dataset['data']['Coulomb']]) +@pytest.fixture +def Coulomb(gmx_benzene_Coulomb_dHdl): + dHdl = alchemlyb.concat(gmx_benzene_Coulomb_dHdl) return dHdl -def gmx_benzene_vdw_dHdl(): - dataset = alchemtest.gmx.load_benzene() - - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300) - for filename in dataset['data']['VDW']]) +@pytest.fixture +def VDW(gmx_benzene_VDW_dHdl): + dHdl = alchemlyb.concat(gmx_benzene_VDW_dHdl) return dHdl -def gmx_expanded_ensemble_case_1_dHdl(): - dataset = alchemtest.gmx.load_expanded_ensemble_case_1() - - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300, filter=False) - for filename in dataset['data']['AllStates']]) +@pytest.fixture +def expanded_ensemble_case_1(gmx_expanded_ensemble_case_1_dHdl): + dHdl = alchemlyb.concat(gmx_expanded_ensemble_case_1_dHdl) return dHdl -def gmx_expanded_ensemble_case_2_dHdl(): - dataset = alchemtest.gmx.load_expanded_ensemble_case_2() - - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300, filter=False) - for filename in dataset['data']['AllStates']]) +@pytest.fixture +def expanded_ensemble_case_2(gmx_expanded_ensemble_case_2_dHdl): + dHdl = alchemlyb.concat(gmx_expanded_ensemble_case_2_dHdl) return dHdl -def gmx_expanded_ensemble_case_3_dHdl(): - dataset = alchemtest.gmx.load_expanded_ensemble_case_3() - - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300, filter=False) - for filename in dataset['data']['AllStates']]) +@pytest.fixture +def expanded_ensemble_case_3(gmx_expanded_ensemble_case_3_dHdl): + dHdl = alchemlyb.concat(gmx_expanded_ensemble_case_3_dHdl) return dHdl -def gmx_water_particle_with_total_energy_dHdl(): - dataset = alchemtest.gmx.load_water_particle_with_total_energy() - - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300) - for filename in dataset['data']['AllStates']]) +@pytest.fixture +def water_particle_with_total_energy(gmx_water_particle_with_total_energy_dHdl): + dHdl = alchemlyb.concat(gmx_water_particle_with_total_energy_dHdl) return dHdl -def gmx_water_particle_with_potential_energy_dHdl(): - dataset = alchemtest.gmx.load_water_particle_with_potential_energy() - - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300) - for filename in dataset['data']['AllStates']]) +@pytest.fixture +def water_particle_with_potential_energy( + gmx_water_particle_with_potential_energy_dHdl, +): + dHdl = alchemlyb.concat(gmx_water_particle_with_potential_energy_dHdl) return dHdl -def gmx_water_particle_without_energy_dHdl(): - dataset = alchemtest.gmx.load_water_particle_without_energy() - - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300) - for filename in dataset['data']['AllStates']]) +@pytest.fixture +def water_particle_without_energy(gmx_water_particle_without_energy_dHdl): + dHdl = alchemlyb.concat(gmx_water_particle_without_energy_dHdl) return dHdl -def amber_simplesolvated_charge_dHdl(): - dataset = alchemtest.amber.load_simplesolvated() - - dHdl = alchemlyb.concat([amber.extract_dHdl(filename, T=298.0) - for filename in dataset['data']['charge']]) +@pytest.fixture +def simplesolvated_charge(amber_simplesolvated_charge_dHdl): + dHdl = alchemlyb.concat(amber_simplesolvated_charge_dHdl) return dHdl -def amber_simplesolvated_vdw_dHdl(): - dataset = alchemtest.amber.load_simplesolvated() - - dHdl = alchemlyb.concat([amber.extract_dHdl(filename, T=298.0) - for filename in dataset['data']['vdw']]) +@pytest.fixture +def simplesolvated_vdw(amber_simplesolvated_vdw_dHdl): + dHdl = alchemlyb.concat(amber_simplesolvated_vdw_dHdl) return dHdl -def gomc_benzene_dHdl(): - dataset = alchemtest.gomc.load_benzene() - - dHdl = alchemlyb.concat([gomc.extract_dHdl(filename, T=298) - for filename in dataset['data']]) +@pytest.fixture +def benzene(gomc_benzene_dHdl): + dHdl = alchemlyb.concat(gomc_benzene_dHdl) return dHdl class TIestimatorMixin: - def test_get_delta_f(self, X_delta_f): dHdl, E, dE = X_delta_f est = self.cls().fit(dHdl) @@ -117,93 +89,93 @@ def test_get_delta_f(self, X_delta_f): assert E == pytest.approx(delta_f, rel=1e-3) assert dE == pytest.approx(d_delta_f, rel=1e-3) + class TestTI(TIestimatorMixin): - """Tests for TI. + """Tests for TI.""" - """ cls = TI T = 298.0 kT_amber = amber.k_b * T - @pytest.fixture(scope="class", - params = [(gmx_benzene_coul_dHdl, 3.089, 0.02157), - (gmx_benzene_vdw_dHdl, -3.056, 0.04863), - (gmx_expanded_ensemble_case_1_dHdl, 76.220, 0.15568), - (gmx_expanded_ensemble_case_2_dHdl, 76.247, 0.15889), - (gmx_expanded_ensemble_case_3_dHdl, 76.387, 0.12532), - (gmx_water_particle_with_total_energy_dHdl, -11.696, 0.091775), - (gmx_water_particle_with_potential_energy_dHdl, -11.751, 0.091149), - (gmx_water_particle_without_energy_dHdl, -11.687, 0.091604), - (amber_simplesolvated_charge_dHdl, -60.114/kT_amber, 0.08186/kT_amber), - (amber_simplesolvated_vdw_dHdl, 3.824/kT_amber, 0.13254/kT_amber), - ]) + @pytest.fixture( + params=[ + ("Coulomb", 3.089, 0.02157), + ("VDW", -3.056, 0.04863), + ("expanded_ensemble_case_1", 76.220, 0.15568), + ("expanded_ensemble_case_2", 76.247, 0.15889), + ("expanded_ensemble_case_3", 76.387, 0.12532), + ("water_particle_with_total_energy", -11.696, 0.091775), + ("water_particle_with_potential_energy", -11.751, 0.091149), + ("water_particle_without_energy", -11.687, 0.091604), + ("simplesolvated_charge", -60.114 / kT_amber, 0.08186 / kT_amber), + ("simplesolvated_vdw", 3.824 / kT_amber, 0.13254 / kT_amber), + ], + ) def X_delta_f(self, request): get_dHdl, E, dE = request.param - return get_dHdl(), E, dE + return request.getfixturevalue(get_dHdl), E, dE -def test_TI_separate_dhdl_multiple_column(): - dHdl = gomc_benzene_dHdl() + +def test_TI_separate_dhdl_multiple_column(benzene): + dHdl = benzene estimator = TI().fit(dHdl) assert all([isinstance(dhdl, pd.Series) for dhdl in estimator.separate_dhdl()]) assert sorted([len(dhdl) for dhdl in estimator.separate_dhdl()]) == [8, 16] -def test_TI_separate_dhdl_single_column(): - dHdl = gmx_benzene_coul_dHdl() + +def test_TI_separate_dhdl_single_column(Coulomb): + dHdl = Coulomb estimator = TI().fit(dHdl) assert all([isinstance(dhdl, pd.Series) for dhdl in estimator.separate_dhdl()]) - assert [len(dhdl) for dhdl in estimator.separate_dhdl()] == [5, ] - -def test_TI_separate_dhdl_no_pertubed(): - '''The test for the case where two lambda are there and one is not pertubed''' - dHdl = gmx_benzene_coul_dHdl() - dHdl.insert(1, 'bound-lambda', [1.0, ] * len(dHdl)) - dHdl.insert(1, 'bound', [1.0, ] * len(dHdl)) - dHdl.set_index('bound-lambda', append=True, inplace=True) + assert [len(dhdl) for dhdl in estimator.separate_dhdl()] == [ + 5, + ] + + +def test_TI_separate_dhdl_no_pertubed(Coulomb): + """The test for the case where two lambda are there and one is not pertubed""" + dHdl = Coulomb + dHdl.insert(1, "bound-lambda", [1.0] * len(dHdl)) + dHdl.insert(1, "bound", [1.0] * len(dHdl)) + dHdl.set_index("bound-lambda", append=True, inplace=True) estimator = TI().fit(dHdl) assert all([isinstance(dhdl, pd.Series) for dhdl in estimator.separate_dhdl()]) - assert [len(dhdl) for dhdl in estimator.separate_dhdl()] == [5, ] + assert [len(dhdl) for dhdl in estimator.separate_dhdl()] == [5] -class Test_Units(): - '''Test the units.''' - @staticmethod - @pytest.fixture(scope='class') - def dhdl(): - bz = load_benzene().data - dHdl_coul = alchemlyb.concat( - [extract_dHdl(xvg, T=300) for xvg in bz['Coulomb']]) - return dHdl_coul - - def test_ti(self, dhdl): - ti = TI().fit(dhdl) - assert ti.delta_f_.attrs['temperature'] == 300 - assert ti.delta_f_.attrs['energy_unit'] == 'kT' - assert ti.d_delta_f_.attrs['temperature'] == 300 - assert ti.d_delta_f_.attrs['energy_unit'] == 'kT' - assert ti.dhdl.attrs['temperature'] == 300 - assert ti.dhdl.attrs['energy_unit'] == 'kT' - def test_ti_separate_dhdl(self, dhdl): - ti = TI().fit(dhdl) +class Test_Units: + """Test the units.""" + + def test_ti(self, Coulomb): + ti = TI().fit(Coulomb) + assert ti.delta_f_.attrs["temperature"] == 300 + assert ti.delta_f_.attrs["energy_unit"] == "kT" + assert ti.d_delta_f_.attrs["temperature"] == 300 + assert ti.d_delta_f_.attrs["energy_unit"] == "kT" + assert ti.dhdl.attrs["temperature"] == 300 + assert ti.dhdl.attrs["energy_unit"] == "kT" + + def test_ti_separate_dhdl(self, Coulomb): + ti = TI().fit(Coulomb) dhdl_list = ti.separate_dhdl() for dhdl in dhdl_list: - assert dhdl.attrs['temperature'] == 300 - assert dhdl.attrs['energy_unit'] == 'kT' + assert dhdl.attrs["temperature"] == 300 + assert dhdl.attrs["energy_unit"] == "kT" + + +class Test_MultipleColumnUnits: + """Test the case where the index has multiple columns""" -class Test_MultipleColumnUnits(): - '''Test the case where the index has multiple columns''' @staticmethod - @pytest.fixture(scope='class') - def dhdl(): - data = load_ABFE()['data']['complex'] - dhdl = alchemlyb.concat( - [extract_dHdl(data[i], - 300) for i in range(30)]) + @pytest.fixture + def dhdl(gmx_ABFE_complex_dHdl): + dhdl = alchemlyb.concat(gmx_ABFE_complex_dHdl) return dhdl def test_ti_separate_dhdl(self, dhdl): ti = TI().fit(dhdl) dhdl_list = ti.separate_dhdl() for dhdl in dhdl_list: - assert dhdl.attrs['temperature'] == 300 - assert dhdl.attrs['energy_unit'] == 'kT' \ No newline at end of file + assert dhdl.attrs["temperature"] == 300 + assert dhdl.attrs["energy_unit"] == "kT" diff --git a/src/alchemlyb/tests/test_units.py b/src/alchemlyb/tests/test_units.py index 8dc059e9..62467dde 100644 --- a/src/alchemlyb/tests/test_units.py +++ b/src/alchemlyb/tests/test_units.py @@ -1,35 +1,46 @@ -import pytest import pandas as pd +import pytest +from alchemtest.gmx import load_benzene import alchemlyb from alchemlyb import pass_attrs -from alchemtest.gmx import load_benzene -from alchemlyb.parsing.gmx import extract_dHdl, extract_u_nk +from alchemlyb.parsing.gmx import extract_u_nk from alchemlyb.postprocessors.units import to_kT -from alchemlyb.preprocessing import (dhdl2series, u_nk2series, - decorrelate_u_nk, decorrelate_dhdl, - slicing, statistical_inefficiency, - equilibrium_detection) - -def test_noT(): - '''Test no temperature error''' - dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) - dhdl.attrs.pop('temperature', None) +from alchemlyb.preprocessing import ( + dhdl2series, + u_nk2series, + decorrelate_u_nk, + decorrelate_dhdl, + slicing, + statistical_inefficiency, + equilibrium_detection, +) + + +@pytest.fixture +def dHdl(gmx_benzene_Coulomb_dHdl): + return gmx_benzene_Coulomb_dHdl[0] + + +def test_noT(dHdl): + """Test no temperature error""" + dhdl = dHdl.copy() + dhdl.attrs.pop("temperature", None) with pytest.raises(TypeError): to_kT(dhdl) -def test_nounit(): - '''Test no unit error''' - dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) - dhdl.attrs.pop('energy_unit', None) + +def test_nounit(dHdl): + """Test no unit error""" + dhdl = dHdl.copy() + dhdl.attrs.pop("energy_unit", None) with pytest.raises(TypeError): to_kT(dhdl) + def test_concat(): - '''Test if different attrs could will give rise to error.''' - d = {'col1': [1, 2], 'col2': [3, 4]} + """Test if different attrs could will give rise to error.""" + d = {"col1": [1, 2], "col2": [3, 4]} df1 = pd.DataFrame(data=d) df1.attrs = {1: 1} df2 = pd.DataFrame(data=d) @@ -37,68 +48,71 @@ def test_concat(): with pytest.raises(ValueError): alchemlyb.concat([df1, df2]) + def test_concat_empty(): - '''Test if empty raise the right error.''' + """Test if empty raise the right error.""" with pytest.raises(ValueError): alchemlyb.concat([]) + def test_setT(): - '''Test setting temperature.''' - df = pd.DataFrame(data={'col1': [1, 2]}) - df.attrs = {'temperature': 300, 'energy_unit': 'kT'} + """Test setting temperature.""" + df = pd.DataFrame(data={"col1": [1, 2]}) + df.attrs = {"temperature": 300, "energy_unit": "kT"} new = to_kT(df, 310) - assert new.attrs['temperature'] == 310 + assert new.attrs["temperature"] == 310 -class Test_Conversion(): - '''Test the preprocessing module.''' - @staticmethod - @pytest.fixture(scope='class') - def dhdl(): - dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) - return dhdl + +class Test_Conversion: + """Test the preprocessing module.""" def test_kt2kt_number(self, dhdl): new_dhdl = to_kT(dhdl) - assert 12.9 == pytest.approx(new_dhdl.loc[(0.0,0.0)], 0.1) + assert 12.9 == pytest.approx(new_dhdl.loc[(0.0, 0.0)], 0.1) def test_kt2kt_unit(self, dhdl): new_dhdl = to_kT(dhdl) - assert new_dhdl.attrs['energy_unit'] == 'kT' + assert new_dhdl.attrs["energy_unit"] == "kT" def test_kj2kt_unit(self, dhdl): - dhdl.attrs['energy_unit'] = 'kJ/mol' + dhdl = dHdl.copy() + dhdl.attrs["energy_unit"] = "kJ/mol" new_dhdl = to_kT(dhdl) - assert new_dhdl.attrs['energy_unit'] == 'kT' + assert new_dhdl.attrs["energy_unit"] == "kT" def test_kj2kt_number(self, dhdl): - dhdl.attrs['energy_unit'] = 'kJ/mol' + dhdl = dHdl.copy() + dhdl.attrs["energy_unit"] = "kJ/mol" new_dhdl = to_kT(dhdl) - assert 5.0 == pytest.approx(new_dhdl.loc[(0.0,0.0)], 0.1) + assert 5.0 == pytest.approx(new_dhdl.loc[(0.0, 0.0)], 0.1) def test_kcal2kt_unit(self, dhdl): - dhdl.attrs['energy_unit'] = 'kcal/mol' + dhdl = dHdl.copy() + dhdl.attrs["energy_unit"] = "kcal/mol" new_dhdl = to_kT(dhdl) - assert new_dhdl.attrs['energy_unit'] == 'kT' + assert new_dhdl.attrs["energy_unit"] == "kT" def test_kcal2kt_number(self, dhdl): - dhdl.attrs['energy_unit'] = 'kcal/mol' + dhdl = dHdl.copy() + dhdl.attrs["energy_unit"] = "kcal/mol" new_dhdl = to_kT(dhdl) - assert 21.0 == pytest.approx(new_dhdl.loc[(0.0,0.0)], 0.1) + assert 21.0 == pytest.approx(new_dhdl.loc[(0.0, 0.0)], 0.1) def test_unknown2kt(self, dhdl): - dhdl.attrs['energy_unit'] = 'ddd' + dhdl = dHdl.copy() + dhdl.attrs["energy_unit"] = "ddd" with pytest.raises(ValueError): to_kT(dhdl) + def test_pd_concat(): - '''Test if concat will preserve the metadata. + """Test if concat will preserve the metadata. When this test is being made, the pd.concat will discard the attrs of the input dataframe. However, this should get fixed in the future. pandas-dev/pandas#28283 - ''' - d = {'col1': [1, 2], 'col2': [3, 4]} + """ + d = {"col1": [1, 2], "col2": [3, 4]} df1 = pd.DataFrame(data=d) df1.attrs = {1: 1} df2 = pd.DataFrame(data=d) @@ -106,8 +120,9 @@ def test_pd_concat(): df = pd.concat([df1, df2]) assert df.attrs == {1: 1} + def test_pass_attrs(): - d = {'col1': [1, 2], 'col2': [3, 4]} + d = {"col1": [1, 2], "col2": [3, 4]} df1 = pd.DataFrame(data=d) df1.attrs = {1: 1} df2 = pd.DataFrame(data=d) @@ -116,40 +131,41 @@ def test_pass_attrs(): @pass_attrs def concat(df1, df2): return pd.concat([df1, df2]) + assert concat(df1, df2).attrs == {1: 1} + def test_pd_slice(): - '''Test if slicing will preserve the metadata.''' - d = {'col1': [1, 2], 'col2': [3, 4]} + """Test if slicing will preserve the metadata.""" + d = {"col1": [1, 2], "col2": [3, 4]} df = pd.DataFrame(data=d) df.attrs = {1: 1} assert df[::2].attrs == {1: 1} -class TestRetainUnit(): - '''This test tests if the functions that should retain the unit would actually - retain the units.''' - @staticmethod - @pytest.fixture(scope='class') - def dhdl(): - dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) - return dhdl + +class TestRetainUnit: + """This test tests if the functions that should retain the unit would actually + retain the units.""" @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def u_nk(): dataset = load_benzene() - u_nk = extract_u_nk(dataset['data']['Coulomb'][0], 310) + u_nk = extract_u_nk(dataset["data"]["Coulomb"][0], 310) return u_nk - @pytest.mark.parametrize('func,fixture_in', - [(dhdl2series, 'dhdl'), - (u_nk2series, 'u_nk'), - (decorrelate_u_nk, 'u_nk'), - (decorrelate_dhdl, 'dhdl'), - (slicing, 'dhdl'), - (statistical_inefficiency, 'dhdl'), - (equilibrium_detection, 'dhdl')]) + @pytest.mark.parametrize( + "func,fixture_in", + [ + (dhdl2series, "dhdl"), + (u_nk2series, "u_nk"), + (decorrelate_u_nk, "u_nk"), + (decorrelate_dhdl, "dhdl"), + (slicing, "dhdl"), + (statistical_inefficiency, "dhdl"), + (equilibrium_detection, "dhdl"), + ], + ) def test_function(self, func, fixture_in, request): result = func(request.getfixturevalue(fixture_in)) - assert result.attrs['energy_unit'] is not None + assert result.attrs["energy_unit"] is not None From 91a39289cc6e36ea5ac3155c8171d31d854eb4cb Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 4 Dec 2022 18:32:04 +0000 Subject: [PATCH 10/21] update --- environment.yml | 4 +++- readthedocs.yml | 9 ++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/environment.yml b/environment.yml index e8c92a4a..61a4b37d 100644 --- a/environment.yml +++ b/environment.yml @@ -2,10 +2,12 @@ name: alchemlyb channels: - conda-forge dependencies: -- python +- python=3.8 - numpy - pandas - pymbar >=3.0.5,<4 - scipy - scikit-learn - matplotlib +- pip + - -e . diff --git a/readthedocs.yml b/readthedocs.yml index 35ffa343..7b0ff11a 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -7,11 +7,10 @@ sphinx: formats: - pdf -python: - version: 3.8 - install: - - method: pip - path: . +build: + os: "ubuntu-20.04" + tools: + python: "mambaforge-4.10" conda: environment: environment.yml From 0d81a47be728f10cdb94a0c026c8143c6d706c44 Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 4 Dec 2022 18:36:08 +0000 Subject: [PATCH 11/21] update --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 61a4b37d..6521b1f4 100644 --- a/environment.yml +++ b/environment.yml @@ -10,4 +10,4 @@ dependencies: - scikit-learn - matplotlib - pip - - -e . + - . From 1afd2b98817676cbd1e1d415b2c0960418bb5a4b Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 4 Dec 2022 18:37:13 +0000 Subject: [PATCH 12/21] update --- environment.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/environment.yml b/environment.yml index 6521b1f4..cdc9d1e9 100644 --- a/environment.yml +++ b/environment.yml @@ -9,5 +9,3 @@ dependencies: - scipy - scikit-learn - matplotlib -- pip - - . From 1ab2c67ae9ee1c21168a451971a87bcc227c7680 Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 4 Dec 2022 20:03:58 +0000 Subject: [PATCH 13/21] update --- readthedocs.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/readthedocs.yml b/readthedocs.yml index 7b0ff11a..99e5cc16 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -14,3 +14,8 @@ build: conda: environment: environment.yml + +python: + install: + - method: pip + path: . \ No newline at end of file From c8799240f427f89c3998a636924f3949c57ffc5a Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 4 Dec 2022 20:19:31 +0000 Subject: [PATCH 14/21] update --- readthedocs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/readthedocs.yml b/readthedocs.yml index 99e5cc16..71851b19 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -18,4 +18,4 @@ conda: python: install: - method: pip - path: . \ No newline at end of file + path: . From ded5d1bc2c7fec534d23dfcd5590ed75d49e264e Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Sun, 4 Dec 2022 20:34:41 +0000 Subject: [PATCH 15/21] update --- src/alchemlyb/tests/conftest.py | 5 + src/alchemlyb/tests/test_units.py | 77 ++++---- src/alchemlyb/tests/test_visualisation.py | 211 ++++++++++++---------- 3 files changed, 157 insertions(+), 136 deletions(-) diff --git a/src/alchemlyb/tests/conftest.py b/src/alchemlyb/tests/conftest.py index 5e4c8a1c..b5b485e9 100644 --- a/src/alchemlyb/tests/conftest.py +++ b/src/alchemlyb/tests/conftest.py @@ -47,6 +47,11 @@ def gmx_benzene_VDW_u_nk(gmx_benzene): return [gmx.extract_u_nk(file, T=300) for file in gmx_benzene["VDW"]] +@pytest.fixture +def gmx_benzene_VDW_dHdl(gmx_benzene): + return [gmx.extract_dHdl(file, T=300) for file in gmx_benzene["VDW"]] + + @pytest.fixture def gmx_ABFE(): dataset = load_ABFE() diff --git a/src/alchemlyb/tests/test_units.py b/src/alchemlyb/tests/test_units.py index 62467dde..e3c11ea4 100644 --- a/src/alchemlyb/tests/test_units.py +++ b/src/alchemlyb/tests/test_units.py @@ -1,10 +1,8 @@ import pandas as pd import pytest -from alchemtest.gmx import load_benzene import alchemlyb from alchemlyb import pass_attrs -from alchemlyb.parsing.gmx import extract_u_nk from alchemlyb.postprocessors.units import to_kT from alchemlyb.preprocessing import ( dhdl2series, @@ -22,20 +20,23 @@ def dHdl(gmx_benzene_Coulomb_dHdl): return gmx_benzene_Coulomb_dHdl[0] +@pytest.fixture +def u_nk(gmx_benzene_Coulomb_u_nk): + return gmx_benzene_Coulomb_u_nk[0] + + def test_noT(dHdl): """Test no temperature error""" - dhdl = dHdl.copy() - dhdl.attrs.pop("temperature", None) + dHdl.attrs.pop("temperature", None) with pytest.raises(TypeError): - to_kT(dhdl) + to_kT(dHdl) def test_nounit(dHdl): """Test no unit error""" - dhdl = dHdl.copy() - dhdl.attrs.pop("energy_unit", None) + dHdl.attrs.pop("energy_unit", None) with pytest.raises(TypeError): - to_kT(dhdl) + to_kT(dHdl) def test_concat(): @@ -66,43 +67,38 @@ def test_setT(): class Test_Conversion: """Test the preprocessing module.""" - def test_kt2kt_number(self, dhdl): - new_dhdl = to_kT(dhdl) + def test_kt2kt_number(self, dHdl): + new_dhdl = to_kT(dHdl) assert 12.9 == pytest.approx(new_dhdl.loc[(0.0, 0.0)], 0.1) - def test_kt2kt_unit(self, dhdl): - new_dhdl = to_kT(dhdl) + def test_kt2kt_unit(self, dHdl): + new_dhdl = to_kT(dHdl) assert new_dhdl.attrs["energy_unit"] == "kT" - def test_kj2kt_unit(self, dhdl): - dhdl = dHdl.copy() - dhdl.attrs["energy_unit"] = "kJ/mol" - new_dhdl = to_kT(dhdl) + def test_kj2kt_unit(self, dHdl): + dHdl.attrs["energy_unit"] = "kJ/mol" + new_dhdl = to_kT(dHdl) assert new_dhdl.attrs["energy_unit"] == "kT" - def test_kj2kt_number(self, dhdl): - dhdl = dHdl.copy() - dhdl.attrs["energy_unit"] = "kJ/mol" - new_dhdl = to_kT(dhdl) + def test_kj2kt_number(self, dHdl): + dHdl.attrs["energy_unit"] = "kJ/mol" + new_dhdl = to_kT(dHdl) assert 5.0 == pytest.approx(new_dhdl.loc[(0.0, 0.0)], 0.1) - def test_kcal2kt_unit(self, dhdl): - dhdl = dHdl.copy() - dhdl.attrs["energy_unit"] = "kcal/mol" - new_dhdl = to_kT(dhdl) + def test_kcal2kt_unit(self, dHdl): + dHdl.attrs["energy_unit"] = "kcal/mol" + new_dhdl = to_kT(dHdl) assert new_dhdl.attrs["energy_unit"] == "kT" - def test_kcal2kt_number(self, dhdl): - dhdl = dHdl.copy() - dhdl.attrs["energy_unit"] = "kcal/mol" - new_dhdl = to_kT(dhdl) + def test_kcal2kt_number(self, dHdl): + dHdl.attrs["energy_unit"] = "kcal/mol" + new_dhdl = to_kT(dHdl) assert 21.0 == pytest.approx(new_dhdl.loc[(0.0, 0.0)], 0.1) - def test_unknown2kt(self, dhdl): - dhdl = dHdl.copy() - dhdl.attrs["energy_unit"] = "ddd" + def test_unknown2kt(self, dHdl): + dHdl.attrs["energy_unit"] = "ddd" with pytest.raises(ValueError): - to_kT(dhdl) + to_kT(dHdl) def test_pd_concat(): @@ -147,23 +143,16 @@ class TestRetainUnit: """This test tests if the functions that should retain the unit would actually retain the units.""" - @staticmethod - @pytest.fixture(scope="class") - def u_nk(): - dataset = load_benzene() - u_nk = extract_u_nk(dataset["data"]["Coulomb"][0], 310) - return u_nk - @pytest.mark.parametrize( "func,fixture_in", [ - (dhdl2series, "dhdl"), + (dhdl2series, "dHdl"), (u_nk2series, "u_nk"), (decorrelate_u_nk, "u_nk"), - (decorrelate_dhdl, "dhdl"), - (slicing, "dhdl"), - (statistical_inefficiency, "dhdl"), - (equilibrium_detection, "dhdl"), + (decorrelate_dhdl, "dHdl"), + (slicing, "dHdl"), + (statistical_inefficiency, "dHdl"), + (equilibrium_detection, "dHdl"), ], ) def test_function(self, func, fixture_in, request): diff --git a/src/alchemlyb/tests/test_visualisation.py b/src/alchemlyb/tests/test_visualisation.py index 509ecffd..c8f522f9 100644 --- a/src/alchemlyb/tests/test_visualisation.py +++ b/src/alchemlyb/tests/test_visualisation.py @@ -1,42 +1,48 @@ import matplotlib import matplotlib.pyplot as plt -import pandas as pd import numpy as np +import pandas as pd import pytest +from alchemtest.gmx import load_benzene import alchemlyb -from alchemtest.gmx import load_benzene -from alchemlyb.parsing.gmx import extract_u_nk, extract_dHdl +from alchemlyb.convergence import forward_backward_convergence from alchemlyb.estimators import MBAR, TI, BAR +from alchemlyb.visualisation import plot_convergence +from alchemlyb.visualisation.dF_state import plot_dF_state from alchemlyb.visualisation.mbar_matrix import plot_mbar_overlap_matrix from alchemlyb.visualisation.ti_dhdl import plot_ti_dhdl -from alchemlyb.visualisation.dF_state import plot_dF_state -from alchemlyb.visualisation import plot_convergence -from alchemlyb.convergence import forward_backward_convergence -def test_plot_mbar_omatrix(): - '''Just test if the plot runs''' - bz = load_benzene().data - u_nk_coul = alchemlyb.concat([extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']]) + +def test_plot_mbar_omatrix(gmx_benzene_Coulomb_u_nk): + """Just test if the plot runs""" + u_nk_coul = alchemlyb.concat(gmx_benzene_Coulomb_u_nk) mbar_coul = MBAR() mbar_coul.fit(u_nk_coul) - assert isinstance(plot_mbar_overlap_matrix(mbar_coul.overlap_matrix), - matplotlib.axes.Axes) - assert isinstance(plot_mbar_overlap_matrix(mbar_coul.overlap_matrix, [1,]), - matplotlib.axes.Axes) + assert isinstance( + plot_mbar_overlap_matrix(mbar_coul.overlap_matrix), matplotlib.axes.Axes + ) + assert isinstance( + plot_mbar_overlap_matrix( + mbar_coul.overlap_matrix, + [ + 1, + ], + ), + matplotlib.axes.Axes, + ) # Bump up coverage overlap_maxtrix = mbar_coul.overlap_matrix - overlap_maxtrix[0,0] = 0.0025 + overlap_maxtrix[0, 0] = 0.0025 overlap_maxtrix[-1, -1] = 0.9975 - assert isinstance(plot_mbar_overlap_matrix(overlap_maxtrix), - matplotlib.axes.Axes) + assert isinstance(plot_mbar_overlap_matrix(overlap_maxtrix), matplotlib.axes.Axes) -def test_plot_ti_dhdl(): - '''Just test if the plot runs''' - bz = load_benzene().data - dHdl_coul = alchemlyb.concat([extract_dHdl(xvg, T=300) for xvg in bz['Coulomb']]) + +def test_plot_ti_dhdl(gmx_benzene_Coulomb_dHdl, gmx_benzene_VDW_dHdl): + """Just test if the plot runs""" + dHdl_coul = alchemlyb.concat(gmx_benzene_Coulomb_dHdl) ti_coul = TI() ti_coul.fit(dHdl_coul) @@ -45,36 +51,40 @@ def test_plot_ti_dhdl(): plt.close(ax.figure) fig, ax = plt.subplots(figsize=(8, 6)) - assert isinstance(plot_ti_dhdl(ti_coul, ax=ax), - matplotlib.axes.Axes) - assert isinstance(plot_ti_dhdl(ti_coul, labels=['Coul']), - matplotlib.axes.Axes) - assert isinstance(plot_ti_dhdl(ti_coul, labels=['Coul'], colors=['r']), - matplotlib.axes.Axes) + assert isinstance(plot_ti_dhdl(ti_coul, ax=ax), matplotlib.axes.Axes) + assert isinstance(plot_ti_dhdl(ti_coul, labels=["Coul"]), matplotlib.axes.Axes) + assert isinstance( + plot_ti_dhdl(ti_coul, labels=["Coul"], colors=["r"]), matplotlib.axes.Axes + ) plt.close(fig) - dHdl_vdw = alchemlyb.concat([extract_dHdl(xvg, T=300) for xvg in bz['VDW']]) + dHdl_vdw = alchemlyb.concat(gmx_benzene_VDW_dHdl) ti_vdw = TI().fit(dHdl_vdw) ax = plot_ti_dhdl([ti_coul, ti_vdw]) assert isinstance(ax, matplotlib.axes.Axes) plt.close(ax.figure) ti_coul.dhdl = pd.DataFrame.from_dict( - {'fep': range(100)}, - orient='index', - columns=np.arange(100)/100).T + {"fep": range(100)}, orient="index", columns=np.arange(100) / 100 + ).T ti_coul.dhdl.attrs = dHdl_vdw.attrs ax = plot_ti_dhdl(ti_coul) assert isinstance(ax, matplotlib.axes.Axes) plt.close(ax.figure) -def test_plot_dF_state(): - '''Just test if the plot runs''' + +def test_plot_dF_state( + gmx_benzene_Coulomb_dHdl, + gmx_benzene_Coulomb_u_nk, + gmx_benzene_VDW_u_nk, + gmx_benzene_VDW_dHdl, +): + """Just test if the plot runs""" bz = load_benzene().data - u_nk_coul = alchemlyb.concat([extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']]) - dHdl_coul = alchemlyb.concat([extract_dHdl(xvg, T=300) for xvg in bz['Coulomb']]) - u_nk_vdw = alchemlyb.concat([extract_u_nk(xvg, T=300) for xvg in bz['VDW']]) - dHdl_vdw = alchemlyb.concat([extract_dHdl(xvg, T=300) for xvg in bz['VDW']]) + u_nk_coul = alchemlyb.concat(gmx_benzene_Coulomb_u_nk) + dHdl_coul = alchemlyb.concat(gmx_benzene_Coulomb_dHdl) + u_nk_vdw = alchemlyb.concat(gmx_benzene_VDW_u_nk) + dHdl_vdw = alchemlyb.concat(gmx_benzene_VDW_dHdl) ti_coul = TI().fit(dHdl_coul) ti_vdw = TI().fit(dHdl_vdw) @@ -83,39 +93,47 @@ def test_plot_dF_state(): mbar_coul = MBAR().fit(u_nk_coul) mbar_vdw = MBAR().fit(u_nk_vdw) - dhdl_data = [(ti_coul, ti_vdw), - (bar_coul, bar_vdw), - (mbar_coul, mbar_vdw), ] - fig = plot_dF_state(dhdl_data, orientation='portrait') + dhdl_data = [ + (ti_coul, ti_vdw), + (bar_coul, bar_vdw), + (mbar_coul, mbar_vdw), + ] + fig = plot_dF_state(dhdl_data, orientation="portrait") assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) - fig = plot_dF_state(dhdl_data, orientation='landscape') + fig = plot_dF_state(dhdl_data, orientation="landscape") assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) - fig = plot_dF_state(dhdl_data, labels=['MBAR', 'TI', 'BAR']) + fig = plot_dF_state(dhdl_data, labels=["MBAR", "TI", "BAR"]) assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) with pytest.raises(ValueError): - fig = plot_dF_state(dhdl_data, labels=['MBAR', 'TI',]) - - fig = plot_dF_state(dhdl_data, colors=['#C45AEC', '#33CC33', '#F87431']) + fig = plot_dF_state( + dhdl_data, + labels=[ + "MBAR", + "TI", + ], + ) + + fig = plot_dF_state(dhdl_data, colors=["#C45AEC", "#33CC33", "#F87431"]) assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) with pytest.raises(ValueError): - fig = plot_dF_state(dhdl_data, colors=['#C45AEC', '#33CC33']) + fig = plot_dF_state(dhdl_data, colors=["#C45AEC", "#33CC33"]) with pytest.raises(ValueError): - fig = plot_dF_state(dhdl_data, orientation='xxx') + fig = plot_dF_state(dhdl_data, orientation="xxx") - fig = plot_dF_state(ti_coul, orientation='landscape') + fig = plot_dF_state(ti_coul, orientation="landscape") assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) - fig = plot_dF_state(ti_coul, orientation='portrait') + fig = plot_dF_state(ti_coul, orientation="portrait") assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) @@ -127,80 +145,89 @@ def test_plot_dF_state(): assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) -def test_plot_convergence_dataframe(): - bz = load_benzene().data - data_list = [extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']] - df = forward_backward_convergence(data_list, 'MBAR') + +def test_plot_convergence_dataframe(gmx_benzene_Coulomb_u_nk): + df = forward_backward_convergence(gmx_benzene_Coulomb_u_nk, "MBAR") ax = plot_convergence(df) assert isinstance(ax, matplotlib.axes.Axes) plt.close(ax.figure) + def test_plot_convergence_dataframe_noerr(): # Test the input from R_c - data = pd.DataFrame(data={'Forward': range(100), - 'Backward': range(100), - 'data_fraction': np.linspace(0,1,100)}) - data.attrs = {'temperature': 300, 'energy_unit': 'kT'} + data = pd.DataFrame( + data={ + "Forward": range(100), + "Backward": range(100), + "data_fraction": np.linspace(0, 1, 100), + } + ) + data.attrs = {"temperature": 300, "energy_unit": "kT"} ax = plot_convergence(data, final_error=2) assert isinstance(ax, matplotlib.axes.Axes) plt.close(ax.figure) -def test_plot_convergence(): - bz = load_benzene().data - data_list = [extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']] + +def test_plot_convergence(gmx_benzene_Coulomb_u_nk): + data_list = gmx_benzene_Coulomb_u_nk forward = [] forward_error = [] backward = [] backward_error = [] num_points = 10 - for i in range(1, num_points+1): + for i in range(1, num_points + 1): # Do the forward - slice = int(len(data_list[0])/num_points*i) + slice = int(len(data_list[0]) / num_points * i) u_nk_coul = alchemlyb.concat([data[:slice] for data in data_list]) estimate = MBAR().fit(u_nk_coul) - forward.append(estimate.delta_f_.loc[0.0,1.0]) - forward_error.append(estimate.d_delta_f_.loc[0.0,1.0]) + forward.append(estimate.delta_f_.loc[0.0, 1.0]) + forward_error.append(estimate.d_delta_f_.loc[0.0, 1.0]) # Do the backward u_nk_coul = alchemlyb.concat([data[-slice:] for data in data_list]) estimate = MBAR().fit(u_nk_coul) - backward.append(estimate.delta_f_.loc[0.0,1.0]) - backward_error.append(estimate.d_delta_f_.loc[0.0,1.0]) - - df = pd.DataFrame(data={'Forward': forward, - 'Forward_Error': forward_error, - 'Backward': backward, - 'Backward_Error': backward_error}) + backward.append(estimate.delta_f_.loc[0.0, 1.0]) + backward_error.append(estimate.d_delta_f_.loc[0.0, 1.0]) + + df = pd.DataFrame( + data={ + "Forward": forward, + "Forward_Error": forward_error, + "Backward": backward, + "Backward_Error": backward_error, + } + ) df.attrs = estimate.delta_f_.attrs ax = plot_convergence(df) assert isinstance(ax, matplotlib.axes.Axes) plt.close(ax.figure) -class Test_Units(): + +class Test_Units: @staticmethod - @pytest.fixture(scope='class') - def estimaters(): - bz = load_benzene().data - dHdl_coul = alchemlyb.concat( - [extract_dHdl(xvg, T=300) for xvg in bz['Coulomb']]) + @pytest.fixture() + def estimaters(gmx_benzene_Coulomb_dHdl, gmx_benzene_Coulomb_u_nk): + dHdl_coul = alchemlyb.concat(gmx_benzene_Coulomb_dHdl) ti = TI().fit(dHdl_coul) - - u_nk_coul = alchemlyb.concat( - [extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']]) + u_nk_coul = alchemlyb.concat(gmx_benzene_Coulomb_u_nk) mbar = MBAR().fit(u_nk_coul) return ti, mbar @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def convergence(): - df = pd.DataFrame(data={'Forward': range(10), - 'Forward_Error': range(10), - 'Backward': range(10), - 'Backward_Error': range(10)}) - df.attrs = {'temperature': 300, 'energy_unit': 'kT'} + df = pd.DataFrame( + data={ + "Forward": range(10), + "Forward_Error": range(10), + "Backward": range(10), + "Backward_Error": range(10), + } + ) + df.attrs = {"temperature": 300, "energy_unit": "kT"} return df - @pytest.mark.parametrize('units', [None, 'kT', 'kJ/mol', 'kcal/mol']) + @pytest.mark.parametrize("units", [None, "kT", "kJ/mol", "kcal/mol"]) def test_plot_dF_state(self, estimaters, units): fig = plot_dF_state(estimaters, units=units) assert isinstance(fig, matplotlib.figure.Figure) @@ -208,9 +235,9 @@ def test_plot_dF_state(self, estimaters, units): def test_plot_dF_state_unknown(self, estimaters): with pytest.raises(ValueError): - fig = plot_dF_state(estimaters, units='ddd') + fig = plot_dF_state(estimaters, units="ddd") - @pytest.mark.parametrize('units', [None, 'kT', 'kJ/mol', 'kcal/mol']) + @pytest.mark.parametrize("units", [None, "kT", "kJ/mol", "kcal/mol"]) def test_plot_ti_dhdl(self, estimaters, units): ti, mbar = estimaters ax = plot_ti_dhdl(ti, units=units) @@ -220,9 +247,9 @@ def test_plot_ti_dhdl(self, estimaters, units): def test_plot_ti_dhdl_unknown(self, estimaters): ti, mbar = estimaters with pytest.raises(ValueError): - fig = plot_ti_dhdl(ti, units='ddd') + fig = plot_ti_dhdl(ti, units="ddd") - @pytest.mark.parametrize('units', [None, 'kT', 'kJ/mol', 'kcal/mol']) + @pytest.mark.parametrize("units", [None, "kT", "kJ/mol", "kcal/mol"]) def test_plot_convergence(self, convergence, units): ax = plot_convergence(convergence) assert isinstance(ax, matplotlib.axes.Axes) From 06c1de02100962be91a8190666e1214d83d4203a Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Mon, 5 Dec 2022 17:27:41 +0000 Subject: [PATCH 16/21] update --- src/alchemlyb/tests/test_preprocessing.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/alchemlyb/tests/test_preprocessing.py b/src/alchemlyb/tests/test_preprocessing.py index a3b15f83..5e2cc611 100644 --- a/src/alchemlyb/tests/test_preprocessing.py +++ b/src/alchemlyb/tests/test_preprocessing.py @@ -65,14 +65,11 @@ def test_basic_slicing(self, data, size, request): == size ) - def test_unchanged(self): + def test_unchanged(self, namd_idws): # NAMD energy files only have dE for adjacent lambdas, this ensures # that the slicer will not drop these rows as they have NaN values. - file = load_idws().data['forward'][0] - u_nk = namd.extract_u_nk(file, 298) - # Do the pre-processing as the u_nk are from all lambdas - groups = u_nk.groupby('fep-lambda') + groups = namd_idws.groupby('fep-lambda') for key, group in groups: group = group[~group.index.duplicated(keep='first')] df = self.slicer(group, None, None, None) From 58653848aed25bb5e05760fa3faece743cfdc6e4 Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Tue, 6 Dec 2022 10:18:07 +0000 Subject: [PATCH 17/21] black --- src/alchemlyb/__init__.py | 35 +- src/alchemlyb/convergence/convergence.py | 130 +++--- src/alchemlyb/estimators/__init__.py | 2 +- src/alchemlyb/estimators/bar_.py | 52 ++- src/alchemlyb/estimators/base.py | 9 +- src/alchemlyb/estimators/mbar_.py | 106 +++-- src/alchemlyb/estimators/ti_.py | 47 +- src/alchemlyb/parsing/__init__.py | 21 +- src/alchemlyb/parsing/amber.py | 186 ++++---- src/alchemlyb/parsing/gmx.py | 156 ++++--- src/alchemlyb/parsing/gomc.py | 68 +-- src/alchemlyb/parsing/namd.py | 170 ++++--- src/alchemlyb/parsing/util.py | 28 +- src/alchemlyb/postprocessors/__init__.py | 2 +- src/alchemlyb/postprocessors/units.py | 53 +-- src/alchemlyb/preprocessing/__init__.py | 24 +- src/alchemlyb/preprocessing/subsampling.py | 173 ++++--- src/alchemlyb/tests/parsing/test_amber.py | 100 +++-- src/alchemlyb/tests/parsing/test_gmx.py | 252 ++++++----- src/alchemlyb/tests/parsing/test_gomc.py | 35 +- src/alchemlyb/tests/parsing/test_namd.py | 247 ++++++---- src/alchemlyb/tests/parsing/test_util.py | 85 ++-- src/alchemlyb/tests/test_convergence.py | 108 +++-- src/alchemlyb/tests/test_fep_estimators.py | 240 ++++++---- src/alchemlyb/tests/test_import.py | 3 +- src/alchemlyb/tests/test_preprocessing.py | 497 ++++++++++++--------- src/alchemlyb/tests/test_ti_estimators.py | 200 ++++++--- src/alchemlyb/tests/test_units.py | 134 +++--- src/alchemlyb/tests/test_version.py | 4 +- src/alchemlyb/tests/test_visualisation.py | 190 ++++---- src/alchemlyb/tests/test_workflow.py | 17 +- src/alchemlyb/tests/test_workflow_ABFE.py | 456 +++++++++++-------- src/alchemlyb/visualisation/__init__.py | 4 +- src/alchemlyb/visualisation/convergence.py | 201 +++++---- src/alchemlyb/visualisation/dF_state.py | 171 ++++--- src/alchemlyb/visualisation/mbar_matrix.py | 99 ++-- src/alchemlyb/visualisation/ti_dhdl.py | 103 +++-- src/alchemlyb/workflows/__init__.py | 3 +- src/alchemlyb/workflows/abfe.py | 495 +++++++++++--------- src/alchemlyb/workflows/base.py | 22 +- 40 files changed, 2895 insertions(+), 2033 deletions(-) diff --git a/src/alchemlyb/__init__.py b/src/alchemlyb/__init__.py index 7cfae4b9..d73b1822 100644 --- a/src/alchemlyb/__init__.py +++ b/src/alchemlyb/__init__.py @@ -1,27 +1,32 @@ -import pandas as pd from functools import wraps +import pandas as pd + from ._version import get_versions -__version__ = get_versions()['version'] + +__version__ = get_versions()["version"] del get_versions + def pass_attrs(func): - '''Pass the attrs from the first positional argument to the output + """Pass the attrs from the first positional argument to the output dataframe. - - + + .. versionadded:: 0.5.0 - ''' + """ @wraps(func) - def wrapper(input_dataframe, *args,**kwargs): - dataframe = func(input_dataframe, *args,**kwargs) + def wrapper(input_dataframe, *args, **kwargs): + dataframe = func(input_dataframe, *args, **kwargs) dataframe.attrs = input_dataframe.attrs return dataframe + return wrapper + def concat(objs, *args, **kwargs): - '''Concatenate pandas objects while persevering the attrs. + """Concatenate pandas objects while persevering the attrs. Concatenate pandas objects along a particular axis with optional set logic along the other axes. If all pandas objects have the same attrs @@ -46,16 +51,16 @@ def concat(objs, *args, **kwargs): See Also -------- pandas.concat - - - .. versionadded:: 0.5.0''' + + + .. versionadded:: 0.5.0""" # Sanity check try: attrs = objs[0].attrs - except IndexError: # except empty list as input - raise ValueError('No objects to concatenate') + except IndexError: # except empty list as input + raise ValueError("No objects to concatenate") for obj in objs: if attrs != obj.attrs: - raise ValueError('All pandas objects should have the same attrs.') + raise ValueError("All pandas objects should have the same attrs.") return pd.concat(objs, *args, **kwargs) diff --git a/src/alchemlyb/convergence/convergence.py b/src/alchemlyb/convergence/convergence.py index 372bd176..ea02afa3 100644 --- a/src/alchemlyb/convergence/convergence.py +++ b/src/alchemlyb/convergence/convergence.py @@ -3,17 +3,19 @@ import logging from warnings import warn -import pandas as pd import numpy as np +import pandas as pd -from ..estimators import BAR, TI, FEP_ESTIMATORS, TI_ESTIMATORS -from ..estimators import AutoMBAR as MBAR from .. import concat +from ..estimators import AutoMBAR as MBAR +from ..estimators import BAR, TI, FEP_ESTIMATORS, TI_ESTIMATORS from ..postprocessors.units import to_kT +estimators_dispatch = {"BAR": BAR, "TI": TI, "MBAR": MBAR} + -def forward_backward_convergence(df_list, estimator='MBAR', num=10, **kwargs): - '''Forward and backward convergence of the free energy estimate. +def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs): + """Forward and backward convergence of the free energy estimate. Generate the free energy estimate as a function of time in both directions, with the specified number of equally spaced points in the time @@ -69,16 +71,17 @@ def forward_backward_convergence(df_list, estimator='MBAR', num=10, **kwargs): The default for using ``estimator='MBAR'`` was changed from :class:`~alchemlyb.estimators.MBAR` to :class:`~alchemlyb.estimators.AutoMBAR`. - ''' - logger = logging.getLogger('alchemlyb.convergence.' - 'forward_backward_convergence') - logger.info('Start convergence analysis.') - logger.info('Check data availability.') + """ + logger = logging.getLogger("alchemlyb.convergence." "forward_backward_convergence") + logger.info("Start convergence analysis.") + logger.info("Check data availability.") if estimator.upper() != estimator: - warn("Using lower-case strings for the 'estimator' kwarg in " - "convergence.forward_backward_convergence() is deprecated in " - "1.0.0 and only upper case will be accepted in 2.0.0", - DeprecationWarning) + warn( + "Using lower-case strings for the 'estimator' kwarg in " + "convergence.forward_backward_convergence() is deprecated in " + "1.0.0 and only upper case will be accepted in 2.0.0", + DeprecationWarning, + ) estimator = estimator.upper() if estimator not in (FEP_ESTIMATORS + TI_ESTIMATORS): @@ -87,62 +90,78 @@ def forward_backward_convergence(df_list, estimator='MBAR', num=10, **kwargs): raise ValueError(msg) else: # select estimator class by name - estimator_fit = globals()[estimator](**kwargs).fit - logger.info(f'Use {estimator} estimator for convergence analysis.') + estimator_fit = estimators_dispatch[estimator](**kwargs).fit + logger.info(f"Use {estimator} estimator for convergence analysis.") - logger.info('Begin forward analysis') + logger.info("Begin forward analysis") forward_list = [] forward_error_list = [] for i in range(1, num + 1): - logger.info('Forward analysis: {:.2f}%'.format(100 * i / num)) + logger.info("Forward analysis: {:.2f}%".format(100 * i / num)) sample = [] for data in df_list: - sample.append(data[:len(data) // num * i]) + sample.append(data[: len(data) // num * i]) sample = concat(sample) result = estimator_fit(sample) forward_list.append(result.delta_f_.iloc[0, -1]) - if estimator.lower() == 'bar': - error = np.sqrt(sum( - [result.d_delta_f_.iloc[i, i + 1] ** 2 - for i in range(len(result.d_delta_f_) - 1)])) + if estimator.lower() == "bar": + error = np.sqrt( + sum( + [ + result.d_delta_f_.iloc[i, i + 1] ** 2 + for i in range(len(result.d_delta_f_) - 1) + ] + ) + ) forward_error_list.append(error) else: forward_error_list.append(result.d_delta_f_.iloc[0, -1]) - logger.info('{:.2f} +/- {:.2f} kT'.format(forward_list[-1], - forward_error_list[-1])) + logger.info( + "{:.2f} +/- {:.2f} kT".format(forward_list[-1], forward_error_list[-1]) + ) - logger.info('Begin backward analysis') + logger.info("Begin backward analysis") backward_list = [] backward_error_list = [] for i in range(1, num + 1): - logger.info('Backward analysis: {:.2f}%'.format(100 * i / num)) + logger.info("Backward analysis: {:.2f}%".format(100 * i / num)) sample = [] for data in df_list: - sample.append(data[-len(data) // num * i:]) + sample.append(data[-len(data) // num * i :]) sample = concat(sample) result = estimator_fit(sample) backward_list.append(result.delta_f_.iloc[0, -1]) - if estimator.lower() == 'bar': - error = np.sqrt(sum( - [result.d_delta_f_.iloc[i, i + 1] ** 2 - for i in range(len(result.d_delta_f_) - 1)])) + if estimator.lower() == "bar": + error = np.sqrt( + sum( + [ + result.d_delta_f_.iloc[i, i + 1] ** 2 + for i in range(len(result.d_delta_f_) - 1) + ] + ) + ) backward_error_list.append(error) else: backward_error_list.append(result.d_delta_f_.iloc[0, -1]) - logger.info('{:.2f} +/- {:.2f} kT'.format(backward_list[-1], - backward_error_list[-1])) + logger.info( + "{:.2f} +/- {:.2f} kT".format(backward_list[-1], backward_error_list[-1]) + ) convergence = pd.DataFrame( - {'Forward': forward_list, - 'Forward_Error': forward_error_list, - 'Backward': backward_list, - 'Backward_Error': backward_error_list, - 'data_fraction': [i / num for i in range(1, num + 1)]}) + { + "Forward": forward_list, + "Forward_Error": forward_error_list, + "Backward": backward_list, + "Backward_Error": backward_error_list, + "data_fraction": [i / num for i in range(1, num + 1)], + } + ) convergence.attrs = df_list[0].attrs return convergence + def _cummean(vals, out_length): - '''The cumulative mean of an array. + """The cumulative mean of an array. This function computes the cumulative mean and shapes the result to the desired length. @@ -167,18 +186,19 @@ def _cummean(vals, out_length): .. versionadded:: 1.0.0 - ''' + """ in_length = len(vals) if in_length < out_length: out_length = in_length block = in_length // out_length - reshape = vals[: block*out_length].reshape(block, out_length) + reshape = vals[: block * out_length].reshape(block, out_length) mean = np.mean(reshape, axis=0) - result = np.cumsum(mean) / np.arange(1, out_length+1) + result = np.cumsum(mean) / np.arange(1, out_length + 1) return result + def fwdrev_cumavg_Rc(series, precision=0.01, tol=2): - '''Generate the convergence criteria :math:`R_c` for a single simulation. + """Generate the convergence criteria :math:`R_c` for a single simulation. The input will be :class:`pandas.Series` generated by :func:`~alchemlyb.preprocessing.subsampling.decorrelate_u_nk` or @@ -241,7 +261,7 @@ def fwdrev_cumavg_Rc(series, precision=0.01, tol=2): .. _`equation 16`: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8397498/#FD16 - ''' + """ series = to_kT(series) array = series.to_numpy() out_length = int(1 / precision) @@ -250,9 +270,12 @@ def fwdrev_cumavg_Rc(series, precision=0.01, tol=2): length = len(g_forward) convergence = pd.DataFrame( - {'Forward': g_forward, - 'Backward': g_backward, - 'data_fraction': [i / length for i in range(1, length + 1)]}) + { + "Forward": g_forward, + "Backward": g_backward, + "data_fraction": [i / length for i in range(1, length + 1)], + } + ) convergence.attrs = series.attrs # Final value @@ -270,8 +293,9 @@ def fwdrev_cumavg_Rc(series, precision=0.01, tol=2): # the same as this branch will be triggered. return 1.0, convergence + def A_c(series_list, precision=0.01, tol=2): - '''Generate the ensemble convergence criteria :math:`A_c` for a set of simulations. + """Generate the ensemble convergence criteria :math:`A_c` for a set of simulations. The input is a :class:`list` of :class:`pandas.Series` generated by :func:`~alchemlyb.preprocessing.subsampling.decorrelate_u_nk` or @@ -317,11 +341,11 @@ def A_c(series_list, precision=0.01, tol=2): .. _`equation 18`: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8397498/#FD18 - ''' - logger = logging.getLogger('alchemlyb.convergence.A_c') + """ + logger = logging.getLogger("alchemlyb.convergence.A_c") n_R_c = len(series_list) R_c_list = [fwdrev_cumavg_Rc(series, precision, tol)[0] for series in series_list] - logger.info(f'R_c list: {R_c_list}') + logger.info(f"R_c list: {R_c_list}") # Integrate the R_c_list <= R_c over the range of 0 to 1 array_01 = np.hstack((R_c_list, [0, 1])) sorted_array = np.sort(np.unique(array_01)) @@ -330,6 +354,6 @@ def A_c(series_list, precision=0.01, tol=2): if i == 0: continue else: - d_R_c = sorted_array[-i] - sorted_array[-i-1] + d_R_c = sorted_array[-i] - sorted_array[-i - 1] result += d_R_c * sum(R_c_list <= element) / n_R_c return result diff --git a/src/alchemlyb/estimators/__init__.py b/src/alchemlyb/estimators/__init__.py index ca48015b..4b4e7771 100644 --- a/src/alchemlyb/estimators/__init__.py +++ b/src/alchemlyb/estimators/__init__.py @@ -1,5 +1,5 @@ -from .mbar_ import MBAR, AutoMBAR from .bar_ import BAR +from .mbar_ import MBAR, AutoMBAR from .ti_ import TI FEP_ESTIMATORS = [MBAR.__name__, AutoMBAR.__name__, BAR.__name__] diff --git a/src/alchemlyb/estimators/bar_.py b/src/alchemlyb/estimators/bar_.py index 3a7150b2..7bf39bc7 100644 --- a/src/alchemlyb/estimators/bar_.py +++ b/src/alchemlyb/estimators/bar_.py @@ -1,11 +1,11 @@ import numpy as np import pandas as pd - -from sklearn.base import BaseEstimator from pymbar import BAR as BAR_ +from sklearn.base import BaseEstimator from .base import _EstimatorMixOut + class BAR(BaseEstimator, _EstimatorMixOut): """Bennett acceptance ratio (BAR). @@ -57,7 +57,13 @@ class BAR(BaseEstimator, _EstimatorMixOut): """ - def __init__(self, maximum_iterations=10000, relative_tolerance=1.0e-7, method='false-position', verbose=False): + def __init__( + self, + maximum_iterations=10000, + relative_tolerance=1.0e-7, + method="false-position", + verbose=False, + ): self.maximum_iterations = maximum_iterations self.relative_tolerance = relative_tolerance @@ -87,7 +93,10 @@ def fit(self, u_nk): # group u_nk by lambda states groups = u_nk.groupby(level=u_nk.index.names[1:]) - N_k = [(len(groups.get_group(i)) if i in groups.groups else 0) for i in u_nk.columns] + N_k = [ + (len(groups.get_group(i)) if i in groups.groups else 0) + for i in u_nk.columns + ] # Now get free energy differences and their uncertainties for each step deltas = np.array([]) @@ -96,19 +105,22 @@ def fit(self, u_nk): # get us from lambda step k uk = groups.get_group(self._states_[k]) # get w_F - w_f = uk.iloc[:, k+1] - uk.iloc[:, k] + w_f = uk.iloc[:, k + 1] - uk.iloc[:, k] # get us from lambda step k+1 - uk1 = groups.get_group(self._states_[k+1]) + uk1 = groups.get_group(self._states_[k + 1]) # get w_R - w_r = uk1.iloc[:, k] - uk1.iloc[:, k+1] + w_r = uk1.iloc[:, k] - uk1.iloc[:, k + 1] # now determine df and ddf using pymbar.BAR - df, ddf = BAR_(w_f, w_r, - method=self.method, - maximum_iterations=self.maximum_iterations, - relative_tolerance=self.relative_tolerance, - verbose=self.verbose) + df, ddf = BAR_( + w_f, + w_r, + method=self.method, + maximum_iterations=self.maximum_iterations, + relative_tolerance=self.relative_tolerance, + verbose=self.verbose, + ) deltas = np.append(deltas, df) d_deltas = np.append(d_deltas, ddf**2) @@ -121,14 +133,14 @@ def fit(self, u_nk): out = [] dout = [] for i in range(len(deltas) - j): - out.append(deltas[i:i + j + 1].sum()) + out.append(deltas[i : i + j + 1].sum()) # See https://github.com/alchemistry/alchemlyb/pull/60#issuecomment-430720742 # Error estimate generated by BAR ARE correlated # Use the BAR uncertainties between two neighbour states if j == 0: - dout.append(d_deltas[i:i + j + 1].sum()) + dout.append(d_deltas[i : i + j + 1].sum()) # Other uncertainties are unknown at this point else: dout.append(np.nan) @@ -137,14 +149,14 @@ def fit(self, u_nk): ad_delta += np.diagflat(np.array(dout), k=j + 1) # yield standard delta_f_ free energies between each state - self._delta_f_ = pd.DataFrame(adelta - adelta.T, - columns=self._states_, - index=self._states_) + self._delta_f_ = pd.DataFrame( + adelta - adelta.T, columns=self._states_, index=self._states_ + ) # yield standard deviation d_delta_f_ between each state - self._d_delta_f_ = pd.DataFrame(np.sqrt(ad_delta + ad_delta.T), - columns=self._states_, - index=self._states_) + self._d_delta_f_ = pd.DataFrame( + np.sqrt(ad_delta + ad_delta.T), columns=self._states_, index=self._states_ + ) self._delta_f_.attrs = u_nk.attrs self._d_delta_f_.attrs = u_nk.attrs diff --git a/src/alchemlyb/estimators/base.py b/src/alchemlyb/estimators/base.py index e6b1b8be..93f3da8a 100644 --- a/src/alchemlyb/estimators/base.py +++ b/src/alchemlyb/estimators/base.py @@ -1,9 +1,11 @@ -class _EstimatorMixOut(): - '''This class creates view for the d_delta_f_, delta_f_, states_ for the - estimator class to consume.''' +class _EstimatorMixOut: + """This class creates view for the d_delta_f_, delta_f_, states_ for the + estimator class to consume.""" + _d_delta_f_ = None _delta_f_ = None _states_ = None + @property def d_delta_f_(self): return self._d_delta_f_ @@ -15,4 +17,3 @@ def delta_f_(self): @property def states_(self): return self._states_ - \ No newline at end of file diff --git a/src/alchemlyb/estimators/mbar_.py b/src/alchemlyb/estimators/mbar_.py index d34434e9..4759d687 100644 --- a/src/alchemlyb/estimators/mbar_.py +++ b/src/alchemlyb/estimators/mbar_.py @@ -1,12 +1,12 @@ -import numpy as np -import pandas as pd import logging -from sklearn.base import BaseEstimator +import pandas as pd import pymbar +from sklearn.base import BaseEstimator from .base import _EstimatorMixOut + class MBAR(BaseEstimator, _EstimatorMixOut): """Multi-state Bennett acceptance ratio (MBAR). @@ -62,14 +62,20 @@ class MBAR(BaseEstimator, _EstimatorMixOut): `delta_f_`, `d_delta_f_`, `states_` are view of the original object. """ - def __init__(self, maximum_iterations=10000, relative_tolerance=1.0e-7, - initial_f_k=None, method='hybr', verbose=False): + def __init__( + self, + maximum_iterations=10000, + relative_tolerance=1.0e-7, + initial_f_k=None, + method="hybr", + verbose=False, + ): self.maximum_iterations = maximum_iterations self.relative_tolerance = relative_tolerance self.initial_f_k = initial_f_k self.method = method self.verbose = verbose - self.logger = logging.getLogger('alchemlyb.estimators.MBAR') + self.logger = logging.getLogger("alchemlyb.estimators.MBAR") # handle for pymbar.MBAR object self._mbar = None @@ -90,22 +96,24 @@ def fit(self, u_nk): u_nk = u_nk.sort_index(level=u_nk.index.names[1:]) groups = u_nk.groupby(level=u_nk.index.names[1:]) - N_k = [(len(groups.get_group(i)) if i in groups.groups else 0) for i in - u_nk.columns] + N_k = [ + (len(groups.get_group(i)) if i in groups.groups else 0) + for i in u_nk.columns + ] self._states_ = u_nk.columns.values.tolist() # Prepare the solver_protocol as stated in https://github.com/choderalab/pymbar/issues/419#issuecomment-803714103 - solver_options = {"maximum_iterations": self.maximum_iterations, - "verbose": self.verbose} - solver_protocol = {"method": self.method, - "options": solver_options} + solver_options = { + "maximum_iterations": self.maximum_iterations, + "verbose": self.verbose, + } + solver_protocol = {"method": self.method, "options": solver_options} self._mbar, out = self._do_MBAR(u_nk, N_k, solver_protocol) - free_energy_differences = [pd.DataFrame(i, - columns=self._states_, - index=self._states_) for i in - out] + free_energy_differences = [ + pd.DataFrame(i, columns=self._states_, index=self._states_) for i in out + ] (self._delta_f_, self._d_delta_f_, self.theta_) = free_energy_differences @@ -118,15 +126,20 @@ def predict(self, u_ln): pass def _do_MBAR(self, u_nk, N_k, solver_protocol): - mbar = pymbar.MBAR(u_nk.T, N_k, - relative_tolerance=self.relative_tolerance, - initial_f_k=self.initial_f_k, - solver_protocol=(solver_protocol,)) - self.logger.info("Solved MBAR equations with method %r and " - "maximum_iterations=%d, relative_tolerance=%g", - solver_protocol['method'], - solver_protocol['options']['maximum_iterations'], - self.relative_tolerance) + mbar = pymbar.MBAR( + u_nk.T, + N_k, + relative_tolerance=self.relative_tolerance, + initial_f_k=self.initial_f_k, + solver_protocol=(solver_protocol,), + ) + self.logger.info( + "Solved MBAR equations with method %r and " + "maximum_iterations=%d, relative_tolerance=%g", + solver_protocol["method"], + solver_protocol["options"]["maximum_iterations"], + self.relative_tolerance, + ) # set attributes out = mbar.getFreeEnergyDifferences(return_theta=True) return mbar, out @@ -145,7 +158,7 @@ def overlap_matrix(self): --------- pymbar.mbar.MBAR.computeOverlap """ - return self._mbar.computeOverlap()['matrix'] + return self._mbar.computeOverlap()["matrix"] class AutoMBAR(MBAR): @@ -188,31 +201,42 @@ class AutoMBAR(MBAR): .. versionchanged:: 1.0.0 AutoMBAR accepts the `method` argument. """ - def __init__(self, maximum_iterations=10000, relative_tolerance=1.0e-7, - initial_f_k=None, verbose=False, method=None): - super().__init__(maximum_iterations=maximum_iterations, - relative_tolerance=relative_tolerance, - initial_f_k=initial_f_k, - verbose=verbose, method=method) - self.logger = logging.getLogger('alchemlyb.estimators.AutoMBAR') + + def __init__( + self, + maximum_iterations=10000, + relative_tolerance=1.0e-7, + initial_f_k=None, + verbose=False, + method=None, + ): + super().__init__( + maximum_iterations=maximum_iterations, + relative_tolerance=relative_tolerance, + initial_f_k=initial_f_k, + verbose=verbose, + method=method, + ) + self.logger = logging.getLogger("alchemlyb.estimators.AutoMBAR") def _do_MBAR(self, u_nk, N_k, solver_protocol): if solver_protocol["method"] is None: - self.logger.info('Initialise the automatic routine of the MBAR ' - 'estimator.') + self.logger.info( + "Initialise the automatic routine of the MBAR " "estimator." + ) # Try the fastest method first try: - self.logger.info('Trying the hybr method.') - solver_protocol["method"] = 'hybr' + self.logger.info("Trying the hybr method.") + solver_protocol["method"] = "hybr" mbar, out = super()._do_MBAR(u_nk, N_k, solver_protocol) except pymbar.utils.ParameterError: try: - self.logger.info('Trying the adaptive method.') - solver_protocol["method"] = 'adaptive' + self.logger.info("Trying the adaptive method.") + solver_protocol["method"] = "adaptive" mbar, out = super()._do_MBAR(u_nk, N_k, solver_protocol) except pymbar.utils.ParameterError: - self.logger.info('Trying the BFGS method.') - solver_protocol["method"] = 'BFGS' + self.logger.info("Trying the BFGS method.") + solver_protocol["method"] = "BFGS" mbar, out = super()._do_MBAR(u_nk, N_k, solver_protocol) return mbar, out else: diff --git a/src/alchemlyb/estimators/ti_.py b/src/alchemlyb/estimators/ti_.py index e01b8f72..bef379cc 100644 --- a/src/alchemlyb/estimators/ti_.py +++ b/src/alchemlyb/estimators/ti_.py @@ -1,10 +1,10 @@ import numpy as np import pandas as pd - from sklearn.base import BaseEstimator from .base import _EstimatorMixOut + class TI(BaseEstimator, _EstimatorMixOut): """Thermodynamic integration (TI). @@ -71,43 +71,49 @@ def fit(self, dHdl): dl = means.reset_index()[means.index.names[:]].diff().iloc[1:].values # apply trapezoid rule to obtain DF between each adjacent state - deltas = (dl * (means.iloc[:-1].values + means.iloc[1:].values)/2).sum(axis=1) + deltas = (dl * (means.iloc[:-1].values + means.iloc[1:].values) / 2).sum(axis=1) # build matrix of deltas between each state - adelta = np.zeros((len(deltas)+1, len(deltas)+1)) + adelta = np.zeros((len(deltas) + 1, len(deltas) + 1)) ad_delta = np.zeros_like(adelta) for j in range(len(deltas)): out = [] dout = [] for i in range(len(deltas) - j): - out.append(deltas[i] + deltas[i+1:i+j+1].sum()) + out.append(deltas[i] + deltas[i + 1 : i + j + 1].sum()) # Define additional zero lambda a = [0.0] * len(l_types) # Define dl series' with additional zero lambda on the left and right - dll = np.insert(dl[i:i + j + 1], 0, [a], axis=0) - dlr = np.append(dl[i:i + j + 1], [a], axis=0) + dll = np.insert(dl[i : i + j + 1], 0, [a], axis=0) + dlr = np.append(dl[i : i + j + 1], [a], axis=0) # Get a series of the form: x1, x1 + x2, ..., x(n-1) + x(n), x(n) dllr = dll + dlr # Append deviation of free energy difference between state i and i+j+1 - dout.append((dllr ** 2 * variances.iloc[i:i + j + 2].values / 4).sum(axis=1).sum()) - adelta += np.diagflat(np.array(out), k=j+1) - ad_delta += np.diagflat(np.array(dout), k=j+1) + dout.append( + (dllr**2 * variances.iloc[i : i + j + 2].values / 4) + .sum(axis=1) + .sum() + ) + adelta += np.diagflat(np.array(out), k=j + 1) + ad_delta += np.diagflat(np.array(dout), k=j + 1) # yield standard delta_f_ free energies between each state - self._delta_f_ = pd.DataFrame(adelta - adelta.T, - columns=means.index.values, - index=means.index.values) + self._delta_f_ = pd.DataFrame( + adelta - adelta.T, columns=means.index.values, index=means.index.values + ) self.dhdl = means # yield standard deviation d_delta_f_ between each state - self._d_delta_f_ = pd.DataFrame(np.sqrt(ad_delta + ad_delta.T), - columns=variances.index.values, - index=variances.index.values) + self._d_delta_f_ = pd.DataFrame( + np.sqrt(ad_delta + ad_delta.T), + columns=variances.index.values, + index=variances.index.values, + ) self._states_ = means.index.values.tolist() @@ -135,7 +141,9 @@ def separate_dhdl(self): """ if len(self.dhdl.index.names) == 1: name = self.dhdl.columns[0] - return [self.dhdl[name], ] + return [ + self.dhdl[name], + ] dhdl_list = [] # get the lambda names l_types = self.dhdl.index.names @@ -143,14 +151,14 @@ def separate_dhdl(self): # Fix issue #148, where for pandas == 1.3.0 # lambdas = self.dhdl.reset_index()[list(l_types)] lambdas = self.dhdl.reset_index()[l_types] - diff = lambdas.diff().to_numpy(dtype='bool') + diff = lambdas.diff().to_numpy(dtype="bool") # diff will give the first row as NaN so need to fix that diff[0, :] = diff[1, :] # Make sure that the start point is set to true as well diff[:-1, :] = diff[:-1, :] | diff[1:, :] for i in range(len(l_types)): - if any(diff[:,i]): - new = self.dhdl.iloc[diff[:,i], i] + if any(diff[:, i]): + new = self.dhdl.iloc[diff[:, i], i] # drop all other index for l in l_types: if l != l_types[i]: @@ -158,4 +166,3 @@ def separate_dhdl(self): new.attrs = self.dhdl.attrs dhdl_list.append(new) return dhdl_list - diff --git a/src/alchemlyb/parsing/__init__.py b/src/alchemlyb/parsing/__init__.py index 60165ac9..dc048732 100644 --- a/src/alchemlyb/parsing/__init__.py +++ b/src/alchemlyb/parsing/__init__.py @@ -1,33 +1,38 @@ from functools import wraps + def _init_attrs(func): - '''Add temperature to the parsed dataframe. + """Add temperature to the parsed dataframe. The temperature is added to the dataframe as dataframe.attrs['temperature'] and the energy unit is initiated as dataframe.attrs['energy_unit'] = 'kT'. - ''' + """ + @wraps(func) def wrapper(outfile, T, *args, **kwargs): dataframe = func(outfile, T, *args, **kwargs) if dataframe is not None: - dataframe.attrs['temperature'] = T - dataframe.attrs['energy_unit'] = 'kT' + dataframe.attrs["temperature"] = T + dataframe.attrs["energy_unit"] = "kT" return dataframe + return wrapper def _init_attrs_dict(func): - '''Add temperature and energy units to the parsed dataframes. + """Add temperature and energy units to the parsed dataframes. The temperature is added to the dataframe as dataframe.attrs['temperature'] and the energy unit is initiated as dataframe.attrs['energy_unit'] = 'kT'. - ''' + """ + @wraps(func) def wrapper(outfile, T, *args, **kwargs): dict_with_df = func(outfile, T, *args, **kwargs) for k in dict_with_df.keys(): if dict_with_df[k] is not None: - dict_with_df[k].attrs['temperature'] = T - dict_with_df[k].attrs['energy_unit'] = 'kT' + dict_with_df[k].attrs["temperature"] = T + dict_with_df[k].attrs["energy_unit"] = "kT" return dict_with_df + return wrapper diff --git a/src/alchemlyb/parsing/amber.py b/src/alchemlyb/parsing/amber.py index 8e23ada1..d129a064 100644 --- a/src/alchemlyb/parsing/amber.py +++ b/src/alchemlyb/parsing/amber.py @@ -11,21 +11,21 @@ """ -import re import logging +import re -import pandas as pd import numpy as np +import pandas as pd -from .util import anyopen from . import _init_attrs_dict +from .util import anyopen from ..postprocessors.units import R_kJmol, kJ2kcal logger = logging.getLogger("alchemlyb.parsers.Amber") k_b = R_kJmol * kJ2kcal -_FP_RE = r'[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?' +_FP_RE = r"[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?" def convert_to_pandas(file_datum): @@ -39,10 +39,13 @@ def convert_to_pandas(file_datum): data_dic["lambdas"].append(file_datum.clambda) frame_time = file_datum.t0 + (frame_index + 1) * file_datum.dt * file_datum.ntpr data_dic["time"].append(frame_time) - df = pd.DataFrame(data_dic["dHdl"], columns=["dHdl"], - index=pd.Index(data_dic["time"], name='time', dtype='Float64')) + df = pd.DataFrame( + data_dic["dHdl"], + columns=["dHdl"], + index=pd.Index(data_dic["time"], name="time", dtype="Float64"), + ) df["lambdas"] = data_dic["lambdas"][0] - df = df.reset_index().set_index(['time'] + ['lambdas']) + df = df.reset_index().set_index(["time"] + ["lambdas"]) return df @@ -59,7 +62,7 @@ def _pre_gen(it, first): return -class SectionParser(): +class SectionParser: """ A simple parser to extract data values from sections. """ @@ -68,7 +71,7 @@ def __init__(self, filename): """Opens a file according to its file type.""" self.filename = filename try: - self.fileh = anyopen(self.filename, 'r') + self.fileh = anyopen(self.filename, "r") except: logger.exception("Cannot open file %s", filename) raise @@ -93,7 +96,7 @@ def skip_after(self, pattern): break return Found_pattern - def extract_section(self, start, end, fields, limit=None, extra=''): + def extract_section(self, start, end, fields, limit=None, extra=""): """ Extract data values (int, float) in fields from a section marked with start and end regexes. Do not read further than @@ -109,15 +112,15 @@ def extract_section(self, start, end, fields, limit=None, extra=''): if inside: if re.search(end, line): break - lines.append(line.rstrip('\n')) - line = ''.join(lines) + lines.append(line.rstrip("\n")) + line = "".join(lines) result = [] for field in fields: - match = re.search(fr' {field}\s*=\s*(\*+|{_FP_RE}|\d+)', line) + match = re.search(rf" {field}\s*=\s*(\*+|{_FP_RE}|\d+)", line) if match: value = match.group(1) - if '*' in value: # catch fortran format overflow - result.append(float('Inf')) + if "*" in value: # catch fortran format overflow + result.append(float("Inf")) else: try: result.append(int(value)) @@ -146,12 +149,21 @@ def __exit__(self, typ, value, traceback): self.close() -class FEData(): +class FEData: """A simple struct container to collect data from individual files.""" - __slots__ = ['clambda', 't0', 'dt', 'T', 'ntpr', 'gradients', - 'mbar_energies', - 'have_mbar', 'mbar_lambdas', 'mbar_lambda_idx'] + __slots__ = [ + "clambda", + "t0", + "dt", + "T", + "ntpr", + "gradients", + "mbar_energies", + "have_mbar", + "mbar_lambdas", + "mbar_lambda_idx", + ] def __init__(self): self.clambda = -1.0 @@ -170,7 +182,7 @@ def file_validation(outfile): """ Function that validate and parse an AMBER output file. :exc:`ValueError` are risen if inconsinstencies in the input file are found. - + Parameters ---------- outfile : str @@ -189,76 +201,81 @@ def file_validation(outfile): if not line: logger.error("The file %s does not contain any data, it's empty.", outfile) - raise ValueError(f'file {outfile} does not contain any data.') + raise ValueError(f"file {outfile} does not contain any data.") - if not secp.skip_after('^ 2. CONTROL DATA FOR THE RUN'): + if not secp.skip_after("^ 2. CONTROL DATA FOR THE RUN"): logger.error('No "CONTROL DATA" section found in file %s.', outfile) raise ValueError(f'no "CONTROL DATA" section found in file {outfile}') - ntpr, = secp.extract_section('^Nature and format of output:', '^$', - ['ntpr']) - nstlim, dt = secp.extract_section('Molecular dynamics:', '^$', - ['nstlim', 'dt']) - T, = secp.extract_section('temperature regulation:', '^$', - ['temp0']) + (ntpr,) = secp.extract_section("^Nature and format of output:", "^$", ["ntpr"]) + nstlim, dt = secp.extract_section("Molecular dynamics:", "^$", ["nstlim", "dt"]) + (T,) = secp.extract_section("temperature regulation:", "^$", ["temp0"]) if not T: logger.error('No valid "temp0" record found in file %s.', outfile) raise ValueError(f'no valid "temp0" record found in file {outfile}') - clambda, = secp.extract_section('^Free energy options:', '^$', - ['clambda'], '^---') + (clambda,) = secp.extract_section( + "^Free energy options:", "^$", ["clambda"], "^---" + ) if clambda is None: - logger.error('No free energy section found in file %s, "clambda" was None.', outfile) - raise ValueError(f'no free energy section found in file {outfile}') + logger.error( + 'No free energy section found in file %s, "clambda" was None.', outfile + ) + raise ValueError(f"no free energy section found in file {outfile}") mbar_ndata = 0 - have_mbar, mbar_ndata, mbar_states = secp.extract_section('^FEP MBAR options:', - '^$', - ['ifmbar', - 'bar_intervall', - 'mbar_states'], - '^---') + have_mbar, mbar_ndata, mbar_states = secp.extract_section( + "^FEP MBAR options:", + "^$", + ["ifmbar", "bar_intervall", "mbar_states"], + "^---", + ) if have_mbar: mbar_ndata = int(nstlim / mbar_ndata) mbar_lambdas = _process_mbar_lambdas(secp) file_datum.mbar_lambdas = mbar_lambdas - clambda_str = f'{clambda:6.4f}' + clambda_str = f"{clambda:6.4f}" if clambda_str not in mbar_lambdas: - logger.warning('WARNING: lamba %s not contained in set of ' - 'MBAR lambas: %s\nNot using MBAR.', - clambda_str, ', '.join(mbar_lambdas)) + logger.warning( + "WARNING: lamba %s not contained in set of " + "MBAR lambas: %s\nNot using MBAR.", + clambda_str, + ", ".join(mbar_lambdas), + ) have_mbar = False else: mbar_nlambda = len(mbar_lambdas) if mbar_nlambda != mbar_states: logger.error( - 'the number of lambda windows read (%s)' - 'is different from what expected (%d)', - ','.join(mbar_lambdas), mbar_states) + "the number of lambda windows read (%s)" + "is different from what expected (%d)", + ",".join(mbar_lambdas), + mbar_states, + ) raise ValueError( - f'the number of lambda windows read ({mbar_nlambda})' - f' is different from what expected ({mbar_states})') + f"the number of lambda windows read ({mbar_nlambda})" + f" is different from what expected ({mbar_states})" + ) mbar_lambda_idx = mbar_lambdas.index(clambda_str) file_datum.mbar_lambda_idx = mbar_lambda_idx for _ in range(mbar_nlambda): file_datum.mbar_energies.append([]) - if not secp.skip_after('^ 3. ATOMIC '): + if not secp.skip_after("^ 3. ATOMIC "): logger.error('No "ATOMIC" section found in the file %s.', outfile) raise ValueError(f'no "ATOMIC" section found in file {outfile}') - t0, = secp.extract_section('^ begin time', '^$', ['coords']) + (t0,) = secp.extract_section("^ begin time", "^$", ["coords"]) if t0 is None: - logger.error('No starting simulation time in file %s.', outfile) - raise ValueError(f'No starting simulation time in file {outfile}') + logger.error("No starting simulation time in file %s.", outfile) + raise ValueError(f"No starting simulation time in file {outfile}") - if not secp.skip_after('^ 4. RESULTS'): + if not secp.skip_after("^ 4. RESULTS"): logger.error('No "RESULTS" section found in the file %s.', outfile) raise ValueError(f'no "RESULTS" section found in file {outfile}') - file_datum.clambda = clambda file_datum.t0 = t0 file_datum.dt = dt @@ -293,13 +310,13 @@ def extract(outfile, T): """ - beta = 1/(k_b * T) + beta = 1 / (k_b * T) file_datum = file_validation(outfile) if not np.isclose(T, file_datum.T, atol=0.01): - msg = f'The temperature read from the input file ({file_datum.T:.2f} K)' - msg += f' is different from the temperature passed as parameter ({T:.2f} K)' + msg = f"The temperature read from the input file ({file_datum.T:.2f} K)" + msg += f" is different from the temperature passed as parameter ({T:.2f} K)" logger.error(msg) raise ValueError(msg) @@ -311,18 +328,19 @@ def extract(outfile, T): old_nstep = -1 for line in secp: if " A V E R A G E S O V E R" in line: - _ = secp.skip_after('^|=========================================') - elif line.startswith(' NSTEP'): - nstep, dvdl = secp.extract_section('^ NSTEP', '^ ---', - ['NSTEP', 'DV/DL'], - extra=line) + _ = secp.skip_after("^|=========================================") + elif line.startswith(" NSTEP"): + nstep, dvdl = secp.extract_section( + "^ NSTEP", "^ ---", ["NSTEP", "DV/DL"], extra=line + ) if nstep != old_nstep and dvdl is not None and nstep is not None: file_datum.gradients.append(dvdl) nensec += 1 old_nstep = nstep - elif line.startswith('MBAR Energy analysis') and file_datum.have_mbar: - mbar = secp.extract_section('^MBAR', '^ ---', file_datum.mbar_lambdas, - extra=line) + elif line.startswith("MBAR Energy analysis") and file_datum.have_mbar: + mbar = secp.extract_section( + "^MBAR", "^ ---", file_datum.mbar_lambdas, extra=line + ) if None in mbar: msg = "Something strange parsing the following MBAR section." @@ -335,40 +353,48 @@ def extract(outfile, T): if energy > 0.0: high_E_cnt += 1 - file_datum.mbar_energies[lmbda].append(beta * (energy - reference_energy)) - elif line == ' 5. TIMINGS\n': + file_datum.mbar_energies[lmbda].append( + beta * (energy - reference_energy) + ) + elif line == " 5. TIMINGS\n": finished = True break if high_E_cnt: - logger.warning('%i MBAR energ%s > 0.0 kcal/mol', - high_E_cnt, 'ies are' if high_E_cnt > 1 else 'y is') + logger.warning( + "%i MBAR energ%s > 0.0 kcal/mol", + high_E_cnt, + "ies are" if high_E_cnt > 1 else "y is", + ) if not finished: - logger.warning('WARNING: file %s is a prematurely terminated run', outfile) + logger.warning("WARNING: file %s is a prematurely terminated run", outfile) if file_datum.have_mbar: mbar_time = [ file_datum.t0 + (frame_index + 1) * file_datum.dt * file_datum.ntpr - for frame_index in range(len(file_datum.mbar_energies[0]))] + for frame_index in range(len(file_datum.mbar_energies[0])) + ] mbar_df = pd.DataFrame( file_datum.mbar_energies, index=np.array(file_datum.mbar_lambdas, dtype=np.float64), columns=pd.MultiIndex.from_arrays( - [mbar_time, np.repeat(file_datum.clambda, len(mbar_time))], names=['time', 'lambdas']) - ).T + [mbar_time, np.repeat(file_datum.clambda, len(mbar_time))], + names=["time", "lambdas"], + ), + ).T else: logger.info('WARNING: No MBAR energies found! "u_nk" entry will be None') mbar_df = None if not nensec: - logger.warning('WARNING: File %s does not contain any dV/dl data', outfile) + logger.warning("WARNING: File %s does not contain any dV/dl data", outfile) dHdl_df = None else: - logger.info('Read %s dV/dl data points in file %s', nensec, outfile) + logger.info("Read %s dV/dl data points in file %s", nensec, outfile) dHdl_df = convert_to_pandas(file_datum) - dHdl_df['dHdl'] *= beta + dHdl_df["dHdl"] *= beta return {"u_nk": mbar_df, "dHdl": dHdl_df} @@ -395,7 +421,7 @@ def extract_dHdl(outfile, T): """ extracted = extract(outfile, T) - return extracted['dHdl'] + return extracted["dHdl"] def extract_u_nk(outfile, T): @@ -421,7 +447,7 @@ def extract_u_nk(outfile, T): """ extracted = extract(outfile, T) - return extracted['u_nk'] + return extracted["u_nk"] def _process_mbar_lambdas(secp): @@ -441,15 +467,15 @@ def _process_mbar_lambdas(secp): mbar_lambdas = [] for line in secp: - if line.startswith(' MBAR - lambda values considered:'): + if line.startswith(" MBAR - lambda values considered:"): in_mbar = True continue if in_mbar: - if line.startswith(' Extra'): + if line.startswith(" Extra"): break - if 'total' in line: + if "total" in line: data = line.split() mbar_lambdas.extend(data[2:]) else: diff --git a/src/alchemlyb/parsing/gmx.py b/src/alchemlyb/parsing/gmx.py index 00267c66..a9f83498 100644 --- a/src/alchemlyb/parsing/gmx.py +++ b/src/alchemlyb/parsing/gmx.py @@ -1,15 +1,16 @@ """Parsers for extracting alchemical data from `Gromacs `_ output files. """ -import pandas as pd import numpy as np +import pandas as pd -from .util import anyopen from . import _init_attrs +from .util import anyopen from ..postprocessors.units import R_kJmol k_b = R_kJmol + @_init_attrs def extract_u_nk(xvg, T, filter=True): r"""Return reduced potentials `u_nk` from a Hamiltonian differences XVG file. @@ -60,9 +61,9 @@ def extract_u_nk(xvg, T, filter=True): """ h_col_match = r"\xD\f{}H \xl\f{}" - pv_col_match = 'pV' - u_col_match = ['Total Energy', 'Potential Energy'] - beta = 1/(k_b * T) + pv_col_match = "pV" + u_col_match = ["Total Energy", "Potential Energy"] + beta = 1 / (k_b * T) state, lambdas, statevec = _extract_state(xvg) @@ -82,7 +83,11 @@ def extract_u_nk(xvg, T, filter=True): pv = df[pv_cols[0]] # gromacs also gives us total/potential energy U directly; need this for reduced potential - u_cols = [col for col in df.columns if any(single_u_col_match in col for single_u_col_match in u_col_match)] + u_cols = [ + col + for col in df.columns + if any(single_u_col_match in col for single_u_col_match in u_col_match) + ] u = None if u_cols: u = df[u_cols[0]] @@ -90,7 +95,7 @@ def extract_u_nk(xvg, T, filter=True): u_k = dict() cols = list() for col in dH: - u_col = eval(col.split('to')[1]) + u_col = eval(col.split("to")[1]) # calculate reduced potential u_k = dH + pV + U u_k[u_col] = beta * dH[col].values if pv_cols: @@ -99,8 +104,9 @@ def extract_u_nk(xvg, T, filter=True): u_k[u_col] += beta * u.values cols.append(u_col) - u_k = pd.DataFrame(u_k, columns=cols, - index=pd.Index(times.values, name='time', dtype='Float64')) + u_k = pd.DataFrame( + u_k, columns=cols, index=pd.Index(times.values, name="time", dtype="Float64") + ) # create columns for each lambda, indicating state each row sampled from # if state is None run as expanded ensemble data or REX @@ -108,8 +114,8 @@ def extract_u_nk(xvg, T, filter=True): # if thermodynamic state is specified map thermodynamic # state data to lambda values, else (for REX) # define state based on the legend - if 'Thermodynamic state' in df: - ts_index = df.columns.get_loc('Thermodynamic state') + if "Thermodynamic state" in df: + ts_index = df.columns.get_loc("Thermodynamic state") thermo_state = df[df.columns[ts_index]] for i, l in enumerate(lambdas): v = [] @@ -128,13 +134,14 @@ def extract_u_nk(xvg, T, filter=True): u_k[l] = statevec # set up new multi-index - newind = ['time'] + lambdas + newind = ["time"] + lambdas u_k = u_k.reset_index().set_index(newind) - u_k.name = 'u_nk' + u_k.name = "u_nk" return u_k + @_init_attrs def extract_dHdl(xvg, T, filter=True): r"""Return gradients `dH/dl` from a Hamiltonian differences XVG file. @@ -182,7 +189,7 @@ def extract_dHdl(xvg, T, filter=True): parsed and is turned on by default. """ - beta = 1/(k_b * T) + beta = 1 / (k_b * T) headers = _get_headers(xvg) state, lambdas, statevec = _extract_state(xvg, headers) @@ -204,10 +211,13 @@ def extract_dHdl(xvg, T, filter=True): # rename columns to not include the word 'lambda', since we use this for # index below - cols = [l.split('-')[0] for l in lambdas] + cols = [l.split("-")[0] for l in lambdas] - dHdl = pd.DataFrame(dHdl.values, columns=cols, - index=pd.Index(times.values, name='time', dtype='Float64')) + dHdl = pd.DataFrame( + dHdl.values, + columns=cols, + index=pd.Index(times.values, name="time", dtype="Float64"), + ) # create columns for each lambda, indicating state each row sampled from # if state is None run as expanded ensemble data or REX @@ -215,8 +225,8 @@ def extract_dHdl(xvg, T, filter=True): # if thermodynamic state is specified map thermodynamic # state data to lambda values, else (for REX) # define state based on the legend - if 'Thermodynamic state' in df: - ts_index = df.columns.get_loc('Thermodynamic state') + if "Thermodynamic state" in df: + ts_index = df.columns.get_loc("Thermodynamic state") thermo_state = df[df.columns[ts_index]] for i, l in enumerate(lambdas): v = [] @@ -235,10 +245,10 @@ def extract_dHdl(xvg, T, filter=True): dHdl[l] = statevec # set up new multi-index - newind = ['time'] + lambdas - dHdl= dHdl.reset_index().set_index(newind) + newind = ["time"] + lambdas + dHdl = dHdl.reset_index().set_index(newind) - dHdl.name='dH/dl' + dHdl.name = "dH/dl" return dHdl @@ -289,34 +299,44 @@ def _extract_state(xvg, headers=None): state = None if headers is None: headers = _get_headers(xvg) - subtitle = _get_value_by_key(headers, 'subtitle') - if subtitle and 'state' in subtitle: - state = int(subtitle.split('state')[1].split(':')[0]) - lambdas = [word.strip(')(,') for word in subtitle.split() if 'lambda' in word] - statevec = eval(subtitle.strip().split(' = ')[-1].strip('"')) + subtitle = _get_value_by_key(headers, "subtitle") + if subtitle and "state" in subtitle: + state = int(subtitle.split("state")[1].split(":")[0]) + lambdas = [word.strip(")(,") for word in subtitle.split() if "lambda" in word] + statevec = eval(subtitle.strip().split(" = ")[-1].strip('"')) # if expanded ensemble data is used the state variable will never be assigned # parsing expanded ensemble data if state is None: lambdas = [] statevec = [] - for line in headers['_raw_lines']: - if ('legend' in line) and ('lambda' in line): - lambdas.append([word.strip(')(,') for word in line.split() if 'lambda' in word][0]) - if ('legend' in line) and (' to ' in line): - statevec.append(([float(i) for i in line.strip().split(' to ')[-1].strip('"()').split(',')])) + for line in headers["_raw_lines"]: + if ("legend" in line) and ("lambda" in line): + lambdas.append( + [word.strip(")(,") for word in line.split() if "lambda" in word][0] + ) + if ("legend" in line) and (" to " in line): + statevec.append( + ( + [ + float(i) + for i in line.strip() + .split(" to ")[-1] + .strip('"()') + .split(",") + ] + ) + ) return state, lambdas, statevec def _extract_legend(xvg): - """Extract information on state sampled for REX simulations. - - """ + """Extract information on state sampled for REX simulations.""" state_legend = {} - with anyopen(xvg, 'r') as f: + with anyopen(xvg, "r") as f: for line in f: - if ('legend' in line) and ('lambda' in line): + if ("legend" in line) and ("lambda" in line): state_legend[line.split()[4]] = float(line.split()[6].strip('"')) return state_legend @@ -344,31 +364,49 @@ def _extract_dataframe(xvg, headers=None, filter=filter): if headers is None: headers = _get_headers(xvg) - xaxis = _get_value_by_key(headers, 'xaxis', 'label') - names = [_get_value_by_key(headers, 's{}'.format(x), 'legend') for x in - range(len(headers)) if 's{}'.format(x) in headers] + xaxis = _get_value_by_key(headers, "xaxis", "label") + names = [ + _get_value_by_key(headers, "s{}".format(x), "legend") + for x in range(len(headers)) + if "s{}".format(x) in headers + ] cols = [xaxis] + names # march through column names, mark duplicates when found - cols = [col + "{}[duplicated]".format(i) if col in cols[:i] else col - for i, col, in enumerate(cols)] + cols = [ + col + "{}[duplicated]".format(i) if col in cols[:i] else col + for i, col, in enumerate(cols) + ] - header_cnt = len(headers['_raw_lines']) + header_cnt = len(headers["_raw_lines"]) if not filter: # assumes clean files - df = pd.read_csv(xvg, sep=r"\s+", header=None, skiprows=header_cnt, - na_filter=True, memory_map=True, names=cols, - dtype=np.float64, - float_precision='high') + df = pd.read_csv( + xvg, + sep=r"\s+", + header=None, + skiprows=header_cnt, + na_filter=True, + memory_map=True, + names=cols, + dtype=np.float64, + float_precision="high", + ) else: - df = pd.read_csv(xvg, sep=r"\s+", header=None, skiprows=header_cnt, - memory_map=True, on_bad_lines='skip') + df = pd.read_csv( + xvg, + sep=r"\s+", + header=None, + skiprows=header_cnt, + memory_map=True, + on_bad_lines="skip", + ) # If names=cols is passed to read_csv, rows with more than the # designated columns will be truncated and used instead of discarded. df.rename(columns={i: name for i, name in enumerate(cols)}, inplace=True) # If dtype=np.float64 and float_precision='high' are passed to read_csv, # 12.345.56 and - cannot be read. - df = df.apply(pd.to_numeric, errors='coerce') + df = df.apply(pd.to_numeric, errors="coerce") # drop duplicate df.dropna(inplace=True) @@ -423,7 +461,7 @@ def _parse_header(line, headers={}, depth=2): else: break - next_t["_val"] = ''.join(s[1:]).rstrip().strip('"') + next_t["_val"] = "".join(s[1:]).rstrip().strip('"') def _get_headers(xvg): @@ -484,17 +522,17 @@ def _get_headers(xvg): headers: dict """ - with anyopen(xvg, 'r') as f: - headers = { '_raw_lines': [] } + with anyopen(xvg, "r") as f: + headers = {"_raw_lines": []} for line in f: line = line.strip() if len(line) == 0: continue - if line.startswith('@'): + if line.startswith("@"): _parse_header(line, headers) - headers['_raw_lines'].append(line) - elif line.startswith('#'): - headers['_raw_lines'].append(line) + headers["_raw_lines"].append(line) + elif line.startswith("#"): + headers["_raw_lines"].append(line) continue # assuming to start a body section else: @@ -522,8 +560,8 @@ def _get_value_by_key(headers, key1, key2=None): val = None if key1 in headers: if key2 is not None and key2 in headers[key1]: - val = headers[key1][key2]['_val'] + val = headers[key1][key2]["_val"] else: - val = headers[key1]['_val'] + val = headers[key1]["_val"] return val diff --git a/src/alchemlyb/parsing/gomc.py b/src/alchemlyb/parsing/gomc.py index 7cf03af4..90124687 100644 --- a/src/alchemlyb/parsing/gomc.py +++ b/src/alchemlyb/parsing/gomc.py @@ -3,12 +3,13 @@ """ import pandas as pd -from .util import anyopen from . import _init_attrs +from .util import anyopen from ..postprocessors.units import R_kJmol k_b = R_kJmol + @_init_attrs def extract_u_nk(filename, T): """Return reduced potentials `u_nk` from a Hamiltonian differences dat file. @@ -34,9 +35,9 @@ def extract_u_nk(filename, T): dh_col_match = "dU/dL" h_col_match = "DelE" - pv_col_match = 'PV' - u_col_match = ['Total_En'] - beta = 1/(k_b * T) + pv_col_match = "PV" + u_col_match = ["Total_En"] + beta = 1 / (k_b * T) state, lambdas, statevec = _extract_state(filename) @@ -56,7 +57,11 @@ def extract_u_nk(filename, T): pv = df[pv_cols[0]] # GOMC also gives us total energy U directly; need this for reduced potential - u_cols = [col for col in df.columns if any(single_u_col_match in col for single_u_col_match in u_col_match)] + u_cols = [ + col + for col in df.columns + if any(single_u_col_match in col for single_u_col_match in u_col_match) + ] u = None if u_cols: u = df[u_cols[0]] @@ -64,7 +69,7 @@ def extract_u_nk(filename, T): u_k = dict() cols = list() for col in dH: - u_col = eval(col.split('->')[1][:-1]) + u_col = eval(col.split("->")[1][:-1]) # calculate reduced potential u_k = dH + pV + U u_k[u_col] = beta * dH[col].values if pv_cols: @@ -73,8 +78,9 @@ def extract_u_nk(filename, T): u_k[u_col] += beta * u.values cols.append(u_col) - u_k = pd.DataFrame(u_k, columns=cols, - index=pd.Index(times.values, name='time', dtype='Float64')) + u_k = pd.DataFrame( + u_k, columns=cols, index=pd.Index(times.values, name="time", dtype="Float64") + ) # Need to modify the lambda name cols = [l + "-lambda" for l in lambdas] @@ -83,13 +89,14 @@ def extract_u_nk(filename, T): u_k[l] = statevec[i] # set up new multi-index - newind = ['time'] + cols + newind = ["time"] + cols u_k = u_k.reset_index().set_index(newind) - u_k.name = 'u_nk' + u_k.name = "u_nk" return u_k + @_init_attrs def extract_dHdl(filename, T): """Return gradients `dH/dl` from a Hamiltonian differences free energy file. @@ -112,7 +119,7 @@ def extract_dHdl(filename, T): the constants used by the corresponding MD engine. """ - beta = 1/(k_b * T) + beta = 1 / (k_b * T) state, lambdas, statevec = _extract_state(filename) @@ -131,8 +138,11 @@ def extract_dHdl(filename, T): # make dimensionless dHdl *= beta - dHdl = pd.DataFrame(dHdl.values, columns=lambdas, - index=pd.Index(times.values, name='time', dtype='Float64')) + dHdl = pd.DataFrame( + dHdl.values, + columns=lambdas, + index=pd.Index(times.values, name="time", dtype="Float64"), + ) # Need to modify the lambda name cols = [l + "-lambda" for l in lambdas] @@ -141,10 +151,10 @@ def extract_dHdl(filename, T): dHdl[l] = statevec[i] # set up new multi-index - newind = ['time'] + cols - dHdl= dHdl.reset_index().set_index(newind) + newind = ["time"] + cols + dHdl = dHdl.reset_index().set_index(newind) - dHdl.name='dH/dl' + dHdl.name = "dH/dl" return dHdl @@ -180,33 +190,29 @@ def extract(filename, T): def _extract_state(filename): - """Extract information on state sampled, names of lambdas. - - """ + """Extract information on state sampled, names of lambdas.""" state = None - with anyopen(filename, 'r') as f: + with anyopen(filename, "r") as f: for line in f: - if ('#' in line) and ('State' in line): - state = int(line.split('State')[1].split(':')[0]) + if ("#" in line) and ("State" in line): + state = int(line.split("State")[1].split(":")[0]) # GOMC always print these two fields - lambdas = ['Coulomb', 'VDW'] - statevec = eval(line.strip().split(' = ')[-1]) + lambdas = ["Coulomb", "VDW"] + statevec = eval(line.strip().split(" = ")[-1]) break return state, lambdas, statevec def _extract_dataframe(filename): - """Extract a DataFrame from free energy data. - - """ + """Extract a DataFrame from free energy data.""" dh_col_match = "dU/dL" h_col_match = "DelE" - pv_col_match = 'PV' - u_col_match = 'Total_En' + pv_col_match = "PV" + u_col_match = "Total_En" xaxis = "time" - with anyopen(filename, 'r') as f: + with anyopen(filename, "r") as f: names = [] rows = [] for line in f: @@ -214,7 +220,7 @@ def _extract_dataframe(filename): if len(line) == 0: # avoid parsing empty line continue - elif line.startswith('#T'): + elif line.startswith("#T"): # this line has state information. No need to be parsed continue elif line.startswith("#Steps"): diff --git a/src/alchemlyb/parsing/namd.py b/src/alchemlyb/parsing/namd.py index c4181b1d..1647467c 100644 --- a/src/alchemlyb/parsing/namd.py +++ b/src/alchemlyb/parsing/namd.py @@ -1,13 +1,15 @@ """Parsers for extracting alchemical data from `NAMD `_ output files. """ -import pandas as pd -import numpy as np +import logging from os.path import basename from re import split -import logging -from .util import anyopen + +import numpy as np +import pandas as pd + from . import _init_attrs +from .util import anyopen from ..postprocessors.units import R_kJmol, kJ2kcal logger = logging.getLogger("alchemlyb.parsers.NAMD") @@ -21,12 +23,12 @@ def _filename_sort_key(s): This means that unlike with the standard Python sorted() function, "foo9" < "foo10". """ - return [int(t) if t.isdigit() else t.lower() for t in split(r'(\d+)', basename(s))] + return [int(t) if t.isdigit() else t.lower() for t in split(r"(\d+)", basename(s))] def _get_lambdas(fep_files): """Retrieves all lambda values included in the FEP files provided. - + We have to do this in order to tolerate truncated and restarted fepout files. The IDWS lambda is not present at the termination of the window, presumably for backwards compatibility with ParseFEP and probably other things. @@ -48,25 +50,25 @@ def _get_lambdas(fep_files): endpoint_windows = [] for fep_file in sorted(fep_files, key=_filename_sort_key): - with anyopen(fep_file, 'r') as f: + with anyopen(fep_file, "r") as f: for line in f: l = line.strip().split() # We might not have a #NEW line so make the best guess - if l[0] == '#NEW': + if l[0] == "#NEW": lambda1, lambda2 = float(l[6]), float(l[8]) - lambda_idws = float(l[10]) if 'LAMBDA_IDWS' in l else None - elif l[0] == '#Free': + lambda_idws = float(l[10]) if "LAMBDA_IDWS" in l else None + elif l[0] == "#Free": lambda1, lambda2, lambda_idws = float(l[7]), float(l[8]), None else: # We only care about lines with lambda values. No need to # do all that other processing below for every line - continue # pragma: no cover + continue # pragma: no cover # Keep track of whether the lambda values are increasing or decreasing, so we can return # a sorted list of the lambdas in the correct order. # If it changes during parsing of this set of fepout files, then we know something is wrong - + # Keep track of endpoints separately since in IDWS runs there must be one of opposite direction if 0.0 in (lambda1, lambda2) or 1.0 in (lambda1, lambda2): endpoint_windows.append((lambda1, lambda2)) @@ -78,23 +80,35 @@ def _get_lambdas(fep_files): is_ascending.add(lambda1 > lambda_idws) if len(is_ascending) > 1: - raise ValueError(f'Lambda values change direction in {fep_file}, relative to the other files: {lambda1} -> {lambda2} (IDWS: {lambda_idws})') + raise ValueError( + f"Lambda values change direction in {fep_file}, relative to the other files: {lambda1} -> {lambda2} (IDWS: {lambda_idws})" + ) # Make sure the lambda2 values are consistent if lambda1 in lambda_fwd_map and lambda_fwd_map[lambda1] != lambda2: - logger.error(f'fwd: lambda1 {lambda1} has lambda2 {lambda_fwd_map[lambda1]} in {fep_file} but it has already been {lambda2}') - raise ValueError('More than one lambda2 value for a particular lambda1') + logger.error( + f"fwd: lambda1 {lambda1} has lambda2 {lambda_fwd_map[lambda1]} in {fep_file} but it has already been {lambda2}" + ) + raise ValueError( + "More than one lambda2 value for a particular lambda1" + ) lambda_fwd_map[lambda1] = lambda2 # Make sure the lambda_idws values are consistent if lambda_idws is not None: - if lambda1 in lambda_bwd_map and lambda_bwd_map[lambda1] != lambda_idws: - logger.error(f'bwd: lambda1 {lambda1} has lambda_idws {lambda_bwd_map[lambda1]} but it has already been {lambda_idws}') - raise ValueError('More than one lambda_idws value for a particular lambda1') + if ( + lambda1 in lambda_bwd_map + and lambda_bwd_map[lambda1] != lambda_idws + ): + logger.error( + f"bwd: lambda1 {lambda1} has lambda_idws {lambda_bwd_map[lambda1]} but it has already been {lambda_idws}" + ) + raise ValueError( + "More than one lambda_idws value for a particular lambda1" + ) lambda_bwd_map[lambda1] = lambda_idws - is_ascending = next(iter(is_ascending)) all_lambdas = set() @@ -147,7 +161,7 @@ def extract_u_nk(fep_files, T): `fep_files` can now be a list of filenames. """ - beta = 1/(k_b * T) + beta = 1 / (k_b * T) # lists to get times and work values of each window win_ts = [] @@ -156,7 +170,7 @@ def extract_u_nk(fep_files, T): win_de_back = [] # create dataframe for results - u_nk = pd.DataFrame(columns=['time','fep-lambda']) + u_nk = pd.DataFrame(columns=["time", "fep-lambda"]) # boolean flag to parse data after equil time parsing = False @@ -176,32 +190,36 @@ def extract_u_nk(fep_files, T): for fep_file in sorted(fep_files, key=_filename_sort_key): # Note we have not set parsing=False because we could be continuing one window across # more than one fepout file - with anyopen(fep_file, 'r') as f: + with anyopen(fep_file, "r") as f: has_idws = False for line in f: l = line.strip().split() # We don't know if IDWS was enabled just from the #Free line, and we might not have # a #NEW line in this file, so we have to check for the existence of FepE_back lines # We rely on short-circuit evaluation to avoid the string comparison most of the time - if has_idws is False and l[0] == 'FepE_back:': + if has_idws is False and l[0] == "FepE_back:": has_idws = True # New window, get IDWS lambda if any # We keep track of lambdas from the #NEW line and if they disagree with the #Free line # within the same file, then complain. This can happen if truncated fepout files # are presented in the wrong order. - if l[0] == '#NEW': + if l[0] == "#NEW": if parsing: - logger.error(f'Window with lambda1: {lambda1_at_start} lambda2: {lambda2_at_start} lambda_idws: {lambda_idws_at_start} appears truncated') - logger.error(f'because a new window was encountered in {fep_file} before the previous one finished.') - raise ValueError('New window begun after truncated window') + logger.error( + f"Window with lambda1: {lambda1_at_start} lambda2: {lambda2_at_start} lambda_idws: {lambda_idws_at_start} appears truncated" + ) + logger.error( + f"because a new window was encountered in {fep_file} before the previous one finished." + ) + raise ValueError("New window begun after truncated window") lambda1_at_start, lambda2_at_start = float(l[6]), float(l[8]) - lambda_idws_at_start = float(l[10]) if 'LAMBDA_IDWS' in l else None + lambda_idws_at_start = float(l[10]) if "LAMBDA_IDWS" in l else None has_idws = True if lambda_idws_at_start is not None else False # this line marks end of window; dump data into dataframe - if l[0] == '#Free': + if l[0] == "#Free": # extract lambda values for finished window # lambda1 = sampling lambda (row), lambda2 = comparison lambda (col) lambda1 = float(l[7]) @@ -210,17 +228,25 @@ def extract_u_nk(fep_files, T): # If the lambdas are not what we thought they would be, raise an exception to ensure the calculation # fails. This can happen if fepouts where one window spans multiple fepouts are processed out of order # NB: There is no way to tell if lambda_idws changed because it isn't in the '#Free' line that ends a window - if lambda1_at_start is not None \ - and (lambda1, lambda2) != (lambda1_at_start, lambda2_at_start): - logger.error(f"Lambdas changed unexpectedly while processing {fep_file}") - logger.error(f"l1, l2: {lambda1_at_start}, {lambda2_at_start} changed to {lambda1}, {lambda2}") + if lambda1_at_start is not None and (lambda1, lambda2) != ( + lambda1_at_start, + lambda2_at_start, + ): + logger.error( + f"Lambdas changed unexpectedly while processing {fep_file}" + ) + logger.error( + f"l1, l2: {lambda1_at_start}, {lambda2_at_start} changed to {lambda1}, {lambda2}" + ) logger.error(line) - raise ValueError("Inconsistent lambda values within the same window") + raise ValueError( + "Inconsistent lambda values within the same window" + ) # As we are at the end of a window, convert last window's work and times values to np arrays # (with energy unit kT since they were kcal/mol in the fepouts) - win_de_arr = beta * np.asarray(win_de) # dE values - win_ts_arr = np.asarray(win_ts) # timesteps + win_de_arr = beta * np.asarray(win_de) # dE values + win_ts_arr = np.asarray(win_ts) # timesteps # This handles the special case where there are IDWS energies but no lambda_idws value in the # current .fepout file. This can happen when the NAMD firsttimestep is not 0, because NAMD only emits @@ -236,10 +262,16 @@ def extract_u_nk(fep_files, T): # Test for the highly pathological case where the first window is both incomplete and has IDWS # data but no lambda_idws value. if l1_idx == 0: - raise ValueError(f'IDWS data present in first window but lambda_idws not included; no way to infer the correct lambda_idws') + raise ValueError( + f"IDWS data present in first window but lambda_idws not included; no way to infer the correct lambda_idws" + ) lambda_idws_at_start = all_lambdas[l1_idx - 1] - logger.warning(f'Warning: {fep_file} has IDWS data but lambda_idws not included.') - logger.warning(f' lambda1 = {lambda1}, lambda2 = {lambda2}; inferring lambda_idws to be {lambda_idws_at_start}') + logger.warning( + f"Warning: {fep_file} has IDWS data but lambda_idws not included." + ) + logger.warning( + f" lambda1 = {lambda1}, lambda2 = {lambda2}; inferring lambda_idws to be {lambda_idws_at_start}" + ) if lambda_idws_at_start is not None: # Mimic classic DWS data @@ -248,22 +280,28 @@ def extract_u_nk(fep_files, T): win_de_back_arr = beta * np.asarray(win_de_back) n = min(len(win_de_back_arr), len(win_de_arr)) - tempDF = pd.DataFrame({ - 'time': win_ts_arr[:n], - 'fep-lambda': np.full(n,lambda1), - lambda1: 0, - lambda2: win_de_arr[:n], - lambda_idws_at_start: win_de_back_arr[:n]}) + tempDF = pd.DataFrame( + { + "time": win_ts_arr[:n], + "fep-lambda": np.full(n, lambda1), + lambda1: 0, + lambda2: win_de_arr[:n], + lambda_idws_at_start: win_de_back_arr[:n], + } + ) # print(f"{fep_file}: IDWS window {lambda1} {lambda2} {lambda_idws_at_start}") else: # print(f"{fep_file}: Forward-only window {lambda1} {lambda2}") # create dataframe of times and work values # this window's data goes in row LAMBDA1 and column LAMBDA2 - tempDF = pd.DataFrame({ - 'time': win_ts_arr, - 'fep-lambda': np.full(len(win_de_arr), lambda1), - lambda1: 0, - lambda2: win_de_arr}) + tempDF = pd.DataFrame( + { + "time": win_ts_arr, + "fep-lambda": np.full(len(win_de_arr), lambda1), + lambda1: 0, + lambda2: win_de_arr, + } + ) # join the new window's df to existing df u_nk = pd.concat([u_nk, tempDF], sort=False) @@ -275,38 +313,42 @@ def extract_u_nk(fep_files, T): win_ts_back = [] parsing = False has_idws = False - lambda1_at_start, lambda2_at_start, lambda_idws_at_start = None, None, None + lambda1_at_start, lambda2_at_start, lambda_idws_at_start = ( + None, + None, + None, + ) # append work value from 'dE' column of fepout file if parsing: - if l[0] == 'FepEnergy:': + if l[0] == "FepEnergy:": win_de.append(float(l[6])) win_ts.append(float(l[1])) - elif l[0] == 'FepE_back:': + elif l[0] == "FepE_back:": win_de_back.append(float(l[6])) win_ts_back.append(float(l[1])) # Turn parsing on after line 'STARTING COLLECTION OF ENSEMBLE AVERAGE' - if '#STARTING' in l: + if "#STARTING" in l: parsing = True - if len(win_de) != 0 or len(win_de_back) != 0: # pragma: no cover - logger.warning('Trailing data without footer line (\"#Free energy...\"). Interrupted run?') - raise ValueError('Last window is truncated') - + if len(win_de) != 0 or len(win_de_back) != 0: # pragma: no cover + logger.warning( + 'Trailing data without footer line ("#Free energy..."). Interrupted run?' + ) + raise ValueError("Last window is truncated") if lambda2 in (0.0, 1.0): # this excludes the IDWS case where a dataframe already exists for both endpoints # create last dataframe for fep-lambda at last LAMBDA2 - tempDF = pd.DataFrame({ - 'time': win_ts_arr, - 'fep-lambda': lambda2}) + tempDF = pd.DataFrame({"time": win_ts_arr, "fep-lambda": lambda2}) u_nk = pd.concat([u_nk, tempDF], sort=True) - u_nk.set_index(['time','fep-lambda'], inplace=True) + u_nk.set_index(["time", "fep-lambda"], inplace=True) return u_nk + def extract(fep_files, T): """Return reduced potentials `u_nk` from NAMD fepout file(s). @@ -342,4 +384,6 @@ def extract(fep_files, T): .. versionadded:: 1.0.0 """ - return {"u_nk": extract_u_nk(fep_files, T)} # NOTE: maybe we should also have 'dHdl': None + return { + "u_nk": extract_u_nk(fep_files, T) + } # NOTE: maybe we should also have 'dHdl': None diff --git a/src/alchemlyb/parsing/util.py b/src/alchemlyb/parsing/util.py index f8259aa6..28e5a568 100644 --- a/src/alchemlyb/parsing/util.py +++ b/src/alchemlyb/parsing/util.py @@ -1,23 +1,24 @@ """Collection of utilities used by many parsers. """ -import os -from os import PathLike -from typing import IO, Optional, Union import bz2 import gzip +import os +from os import PathLike +from typing import IO, Union + def bz2_open(filename, mode): - mode += 't' if mode in ['r','w','a','x'] else '' + mode += "t" if mode in ["r", "w", "a", "x"] else "" return bz2.open(filename, mode) def gzip_open(filename, mode): - mode += 't' if mode in ['r','w','a','x'] else '' + mode += "t" if mode in ["r", "w", "a", "x"] else "" return gzip.open(filename, mode) -def anyopen(datafile: Union[PathLike, IO], mode='r', compression=None): +def anyopen(datafile: Union[PathLike, IO], mode="r", compression=None): """Return a file stream for file or stream, even if compressed. Supports files compressed with bzip2 (.bz2) and gzip (.gz) compression @@ -59,16 +60,15 @@ def anyopen(datafile: Union[PathLike, IO], mode='r', compression=None): """ # opener for each type of file - extensions = {'.bz2': bz2_open, - '.gz': gzip_open} + extensions = {".bz2": bz2_open, ".gz": gzip_open} # compression selections available - compressions = {'bzip2': bz2_open, - 'gzip': gzip_open} + compressions = {"bzip2": bz2_open, "gzip": gzip_open} # if `datafile` is a stream - if ((hasattr(datafile, 'read') and any((i in mode for i in ('r',)))) or - (hasattr(datafile, 'write') and any((i in mode for i in ('w', 'a', 'x'))))): + if (hasattr(datafile, "read") and any((i in mode for i in ("r",)))) or ( + hasattr(datafile, "write") and any((i in mode for i in ("w", "a", "x"))) + ): # if no compression specified, just pass the stream through if compression is None: return datafile @@ -76,7 +76,9 @@ def anyopen(datafile: Union[PathLike, IO], mode='r', compression=None): compressor = compressions[compression] return compressor(datafile, mode=mode) else: - raise ValueError("`datafile` is a stream, but specified `compression` '{compression}' is not supported") + raise ValueError( + "`datafile` is a stream, but specified `compression` '{compression}' is not supported" + ) # otherwise, treat as a file # allow compression to override any extension on the file diff --git a/src/alchemlyb/postprocessors/__init__.py b/src/alchemlyb/postprocessors/__init__.py index 6e769ac4..932d2b06 100644 --- a/src/alchemlyb/postprocessors/__init__.py +++ b/src/alchemlyb/postprocessors/__init__.py @@ -1,3 +1,3 @@ __all__ = [ - 'units', + "units", ] diff --git a/src/alchemlyb/postprocessors/units.py b/src/alchemlyb/postprocessors/units.py index f5e1984d..510b4465 100644 --- a/src/alchemlyb/postprocessors/units.py +++ b/src/alchemlyb/postprocessors/units.py @@ -12,8 +12,9 @@ #: in :mod:`scipy.constants` R_kJmol = R / 1000 + def to_kT(df, T=None): - """ Convert the unit of a DataFrame to `kT`. + """Convert the unit of a DataFrame to `kT`. If temperature `T` is not provided, the DataFrame need to have attribute `temperature` and `energy_unit`. Otherwise, the temperature of the output @@ -33,28 +34,28 @@ def to_kT(df, T=None): """ new_df = df.copy() if T is not None: - new_df.attrs['temperature'] = T - elif 'temperature' not in df.attrs: - raise TypeError('Attribute temperature not found in the input ' - 'Dataframe.') + new_df.attrs["temperature"] = T + elif "temperature" not in df.attrs: + raise TypeError("Attribute temperature not found in the input " "Dataframe.") - if 'energy_unit' not in df.attrs: - raise TypeError('Attribute energy_unit not found in the input ' - 'Dataframe.') + if "energy_unit" not in df.attrs: + raise TypeError("Attribute energy_unit not found in the input " "Dataframe.") - if df.attrs['energy_unit'] == 'kT': + if df.attrs["energy_unit"] == "kT": return new_df - elif df.attrs['energy_unit'] == 'kJ/mol': - new_df /= R_kJmol * df.attrs['temperature'] - new_df.attrs['energy_unit'] = 'kT' + elif df.attrs["energy_unit"] == "kJ/mol": + new_df /= R_kJmol * df.attrs["temperature"] + new_df.attrs["energy_unit"] = "kT" return new_df - elif df.attrs['energy_unit'] == 'kcal/mol': - new_df /= R_kJmol * df.attrs['temperature'] * kJ2kcal - new_df.attrs['energy_unit'] = 'kT' + elif df.attrs["energy_unit"] == "kcal/mol": + new_df /= R_kJmol * df.attrs["temperature"] * kJ2kcal + new_df.attrs["energy_unit"] = "kT" return new_df else: - raise ValueError('energy_unit {} can only be kT, kJ/mol or ' \ - 'kcal/mol.'.format(df.attrs['energy_unit'])) + raise ValueError( + "energy_unit {} can only be kT, kJ/mol or " + "kcal/mol.".format(df.attrs["energy_unit"]) + ) def to_kcalmol(df, T=None): @@ -77,10 +78,11 @@ def to_kcalmol(df, T=None): `df` converted. """ kt_df = to_kT(df, T) - kt_df *= R_kJmol * df.attrs['temperature'] * kJ2kcal - kt_df.attrs['energy_unit'] = 'kcal/mol' + kt_df *= R_kJmol * df.attrs["temperature"] * kJ2kcal + kt_df.attrs["energy_unit"] = "kcal/mol" return kt_df + def to_kJmol(df, T=None): """Convert the unit of a DataFrame to kJ/mol. @@ -101,12 +103,13 @@ def to_kJmol(df, T=None): `df` converted. """ kt_df = to_kT(df, T) - kt_df *= R_kJmol * df.attrs['temperature'] - kt_df.attrs['energy_unit'] = 'kJ/mol' + kt_df *= R_kJmol * df.attrs["temperature"] + kt_df.attrs["energy_unit"] = "kJ/mol" return kt_df + def get_unit_converter(units): - """ Obtain the converter according to the unit string. + """Obtain the converter according to the unit string. If `units` is 'kT', the `to_kT` converter is returned. If `units` is 'kJ/mol', the `to_kJmol` converter is returned. If `units` is 'kcal/mol', @@ -125,12 +128,12 @@ def get_unit_converter(units): .. versionadded:: 0.5.0 """ - converters = {'kT': to_kT, 'kJ/mol': to_kJmol, - 'kcal/mol': to_kcalmol} + converters = {"kT": to_kT, "kJ/mol": to_kJmol, "kcal/mol": to_kcalmol} try: convert = converters[units] except KeyError: raise ValueError( f"Energy unit {units} is not supported, " - f"choose one of {list(converters.keys())}") + f"choose one of {list(converters.keys())}" + ) return convert diff --git a/src/alchemlyb/preprocessing/__init__.py b/src/alchemlyb/preprocessing/__init__.py index 6b759482..223c942e 100644 --- a/src/alchemlyb/preprocessing/__init__.py +++ b/src/alchemlyb/preprocessing/__init__.py @@ -3,16 +3,22 @@ preparing data for estimators. """ -from .subsampling import slicing, dhdl2series, u_nk2series, decorrelate_dhdl, decorrelate_u_nk -from .subsampling import statistical_inefficiency from .subsampling import equilibrium_detection +from .subsampling import ( + slicing, + dhdl2series, + u_nk2series, + decorrelate_dhdl, + decorrelate_u_nk, +) +from .subsampling import statistical_inefficiency __all__ = [ - 'slicing', - 'statistical_inefficiency', - 'equilibrium_detection', - 'decorrelate_dhdl', - 'decorrelate_u_nk', - 'dhdl2series', - 'u_nk2series' + "slicing", + "statistical_inefficiency", + "equilibrium_detection", + "decorrelate_dhdl", + "decorrelate_u_nk", + "dhdl2series", + "u_nk2series", ] diff --git a/src/alchemlyb/preprocessing/subsampling.py b/src/alchemlyb/preprocessing/subsampling.py index 3b8bd8af..653adbca 100644 --- a/src/alchemlyb/preprocessing/subsampling.py +++ b/src/alchemlyb/preprocessing/subsampling.py @@ -4,13 +4,18 @@ import warnings import pandas as pd -from pymbar.timeseries import (statisticalInefficiency, - detectEquilibration, - subsampleCorrelatedData, ) +from pymbar.timeseries import ( + statisticalInefficiency, + detectEquilibration, + subsampleCorrelatedData, +) + from .. import pass_attrs -def decorrelate_u_nk(df, method='dE', drop_duplicates=True, - sort=True, remove_burnin=False, **kwargs): + +def decorrelate_u_nk( + df, method="dE", drop_duplicates=True, sort=True, remove_burnin=False, **kwargs +): """Subsample an u_nk DataFrame based on the selected method. The method can be either 'all' (obtained as a sum over all energy @@ -57,8 +62,8 @@ def decorrelate_u_nk(df, method='dE', drop_duplicates=True, deprecate the 'dhdl'. """ - kwargs['drop_duplicates'] = drop_duplicates - kwargs['sort'] = sort + kwargs["drop_duplicates"] = drop_duplicates + kwargs["sort"] = sort series = u_nk2series(df, method) @@ -67,8 +72,10 @@ def decorrelate_u_nk(df, method='dE', drop_duplicates=True, else: return statistical_inefficiency(df, series, **kwargs) -def decorrelate_dhdl(df, drop_duplicates=True, sort=True, - remove_burnin=False, **kwargs): + +def decorrelate_dhdl( + df, drop_duplicates=True, sort=True, remove_burnin=False, **kwargs +): """Subsample a dhdl DataFrame. This is a wrapper function around the function :func:`~alchemlyb.preprocessing.subsampling.statistical_inefficiency` and @@ -111,8 +118,8 @@ def decorrelate_dhdl(df, drop_duplicates=True, sort=True, """ - kwargs['drop_duplicates'] = drop_duplicates - kwargs['sort'] = sort + kwargs["drop_duplicates"] = drop_duplicates + kwargs["sort"] = sort series = dhdl2series(df) @@ -121,8 +128,9 @@ def decorrelate_dhdl(df, drop_duplicates=True, sort=True, else: return statistical_inefficiency(df, series, **kwargs) + @pass_attrs -def u_nk2series(df, method='dE'): +def u_nk2series(df, method="dE"): """Convert an u_nk DataFrame into a series based on the selected method for subsampling. @@ -152,18 +160,22 @@ def u_nk2series(df, method='dE'): # deprecation: remove in 3.0.0 # (the deprecations should show up in the calling functions) - if method == 'dhdl': - warnings.warn("Method 'dhdl' has been deprecated, using 'dE' instead. " - "'dhdl' will be removed in alchemlyb 3.0.0.", - category=DeprecationWarning, - stacklevel=2) - method = 'dE' - elif method == 'dhdl_all': - warnings.warn("Method 'dhdl_all' has been deprecated, using 'all' instead. " - "'dhdl_all' will be removed in alchemlyb 3.0.0.", - category=DeprecationWarning, - stacklevel=2) - method = 'all' + if method == "dhdl": + warnings.warn( + "Method 'dhdl' has been deprecated, using 'dE' instead. " + "'dhdl' will be removed in alchemlyb 3.0.0.", + category=DeprecationWarning, + stacklevel=2, + ) + method = "dE" + elif method == "dhdl_all": + warnings.warn( + "Method 'dhdl_all' has been deprecated, using 'all' instead. " + "'dhdl_all' will be removed in alchemlyb 3.0.0.", + category=DeprecationWarning, + stacklevel=2, + ) + method = "all" # Check if the input is u_nk try: @@ -172,11 +184,11 @@ def u_nk2series(df, method='dE'): key = key[0] df[key] except KeyError: - raise ValueError('The input should be u_nk') + raise ValueError("The input should be u_nk") - if method == 'all': + if method == "all": series = df.sum(axis=1) - elif method == 'dE': + elif method == "dE": # Using the same logic as alchemical-analysis key = df.index.values[0][1:] if len(key) == 1: @@ -192,13 +204,12 @@ def u_nk2series(df, method='dE'): else: series = df.iloc[:, index - 1] else: - raise ValueError( - 'Decorrelation method {} not found.'.format(method)) + raise ValueError("Decorrelation method {} not found.".format(method)) return series @pass_attrs -def dhdl2series(df, method='all'): +def dhdl2series(df, method="all"): """Convert a dhdl DataFrame to a series for subsampling. The series is generated by summing over all energy components (axis 1 of @@ -235,13 +246,15 @@ def dhdl2series(df, method='all'): def _check_multiple_times(df): if isinstance(df, pd.Series): - return df.sort_index(axis=0).reset_index('time', name='').duplicated('time').any() + return ( + df.sort_index(axis=0).reset_index("time", name="").duplicated("time").any() + ) else: - return df.sort_index(axis=0).reset_index('time').duplicated('time').any() + return df.sort_index(axis=0).reset_index("time").duplicated("time").any() def _check_sorted(df): - return df.reset_index(0)['time'].is_monotonic_increasing + return df.reset_index(0)["time"].is_monotonic_increasing def _drop_duplicates(df, series=None): @@ -265,34 +278,44 @@ def _drop_duplicates(df, series=None): """ if isinstance(df, pd.Series): # remove the duplicate based on time - drop_duplicates_series = df.reset_index('time', name=''). \ - drop_duplicates('time') + drop_duplicates_series = df.reset_index("time", name="").drop_duplicates("time") # Rest the time index - lambda_names = ['time', ] + lambda_names = [ + "time", + ] lambda_names.extend(drop_duplicates_series.index.names) - df = drop_duplicates_series.set_index('time', append=True). \ - reorder_levels(lambda_names) + df = drop_duplicates_series.set_index("time", append=True).reorder_levels( + lambda_names + ) else: # remove the duplicate based on time - drop_duplicates_df = df.reset_index('time').drop_duplicates('time') + drop_duplicates_df = df.reset_index("time").drop_duplicates("time") # Rest the time index - lambda_names = ['time', ] + lambda_names = [ + "time", + ] lambda_names.extend(drop_duplicates_df.index.names) - df = drop_duplicates_df.set_index('time', append=True). \ - reorder_levels(lambda_names) + df = drop_duplicates_df.set_index("time", append=True).reorder_levels( + lambda_names + ) # Do the same withing with the series if series is not None: # remove the duplicate based on time - drop_duplicates_series = series.reset_index('time', name=''). \ - drop_duplicates('time') + drop_duplicates_series = series.reset_index("time", name="").drop_duplicates( + "time" + ) # Rest the time index - lambda_names = ['time', ] + lambda_names = [ + "time", + ] lambda_names.extend(drop_duplicates_series.index.names) - series = drop_duplicates_series.set_index('time', append=True). \ - reorder_levels(lambda_names) + series = drop_duplicates_series.set_index("time", append=True).reorder_levels( + lambda_names + ) return df, series + def _sort_by_time(df, series=None): """Sort the ``df`` by time which could be Dataframe or Series, if series is provided, sort the series as well. @@ -311,12 +334,13 @@ def _sort_by_time(df, series=None): series : Series Formatted Series. """ - df = df.sort_index(level='time') + df = df.sort_index(level="time") if series is not None: - series = series.sort_index(level='time') + series = series.sort_index(level="time") return df, series + def _prepare_input(df, series, drop_duplicates, sort): """Prepare and check the input to be used for statistical_inefficiency or equilibrium_detection. @@ -341,7 +365,8 @@ def _prepare_input(df, series, drop_duplicates, sort): raise KeyError( "Duplicate time values found; statistical inefficiency " "only works on a single, contiguous, " - "and sorted timeseries.") + "and sorted timeseries." + ) if not _check_sorted(df): if sort: @@ -349,16 +374,17 @@ def _prepare_input(df, series, drop_duplicates, sort): else: raise KeyError( "Statistical inefficiency only works as expected if " - "values are sorted by time, increasing.") + "values are sorted by time, increasing." + ) if series is not None: - if (len(series) != len(df) or - not all( - series.reset_index()['time'] == df.reset_index()['time'])): - raise ValueError( - "series and data must be sampled at the same times") + if len(series) != len(df) or not all( + series.reset_index()["time"] == df.reset_index()["time"] + ): + raise ValueError("series and data must be sampled at the same times") return df, series + def slicing(df, lower=None, upper=None, step=None, force=False): """Subsample a DataFrame using simple slicing. @@ -390,16 +416,25 @@ def slicing(df, lower=None, upper=None, step=None, force=False): raise KeyError("DataFrame rows must be sorted by time, increasing.") if not force and _check_multiple_times(df): - raise KeyError("Duplicate time values found; it's generally advised " - "to use slicing on DataFrames with unique time values " - "for each row. Use `force=True` to ignore this error.") + raise KeyError( + "Duplicate time values found; it's generally advised " + "to use slicing on DataFrames with unique time values " + "for each row. Use `force=True` to ignore this error." + ) return df -def statistical_inefficiency(df, series=None, lower=None, upper=None, - step=None, conservative=True, - drop_duplicates=False, sort=False): +def statistical_inefficiency( + df, + series=None, + lower=None, + upper=None, + step=None, + conservative=True, + drop_duplicates=False, + sort=False, +): """Subsample a DataFrame based on the calculated statistical inefficiency of a timeseries. @@ -480,8 +515,7 @@ def statistical_inefficiency(df, series=None, lower=None, upper=None, statinef = statisticalInefficiency(series, fast=False) # use the subsampleCorrelatedData function to get the subsample index - indices = subsampleCorrelatedData(series, g=statinef, - conservative=conservative) + indices = subsampleCorrelatedData(series, g=statinef, conservative=conservative) df = df.iloc[indices] else: df = slicing(df, lower=lower, upper=upper, step=step) @@ -489,8 +523,15 @@ def statistical_inefficiency(df, series=None, lower=None, upper=None, return df -def equilibrium_detection(df, series=None, lower=None, upper=None, step=None, - drop_duplicates=False, sort=False): +def equilibrium_detection( + df, + series=None, + lower=None, + upper=None, + step=None, + drop_duplicates=False, + sort=False, +): """Subsample a DataFrame using automated equilibrium detection on a timeseries. This function uses the :mod:`pymbar` implementation of the *simple diff --git a/src/alchemlyb/tests/parsing/test_amber.py b/src/alchemlyb/tests/parsing/test_amber.py index c1fe137d..0d186cc7 100644 --- a/src/alchemlyb/tests/parsing/test_amber.py +++ b/src/alchemlyb/tests/parsing/test_amber.py @@ -2,28 +2,30 @@ """ import logging + +import pandas as pd import pytest +from alchemtest.amber import load_bace_example +from alchemtest.amber import load_bace_improper +from alchemtest.amber import load_simplesolvated +from alchemtest.amber import load_testfiles from numpy.testing import assert_allclose -import pandas as pd +from alchemlyb.parsing.amber import extract from alchemlyb.parsing.amber import extract_dHdl from alchemlyb.parsing.amber import extract_u_nk -from alchemlyb.parsing.amber import extract -from alchemtest.amber import load_simplesolvated -from alchemtest.amber import load_bace_example -from alchemtest.amber import load_bace_improper -from alchemtest.amber import load_testfiles ################################################################################## ################ Check the parser behaviour with problematic files ################################################################################## + @pytest.fixture(name="testfiles", scope="module") def fixture_testfiles(): - """ Returns the testfiles data dictionary """ + """Returns the testfiles data dictionary""" bunch = load_testfiles() - return bunch['data'] + return bunch["data"] def test_file_not_found(): @@ -77,10 +79,10 @@ def test_no_control_data(caplog, testfiles): def test_no_free_energy_info(caplog, testfiles): """Test if we raise an exception if there is no free energy section""" filename = testfiles["no_free_energy_info"][0] - with pytest.raises(ValueError, match='no free energy section found'): + with pytest.raises(ValueError, match="no free energy section found"): with caplog.at_level(logging.ERROR): _ = extract(str(filename), T=298.0) - assert 'No free energy section found' in caplog.text + assert "No free energy section found" in caplog.text def test_no_useful_data(caplog, testfiles): @@ -89,7 +91,7 @@ def test_no_useful_data(caplog, testfiles): with pytest.raises(ValueError, match="does not contain any data"): with caplog.at_level(logging.ERROR): _ = extract(str(filename), T=298.0) - assert 'does not contain any data' in caplog.text + assert "does not contain any data" in caplog.text def test_no_temp0_set(caplog, testfiles): @@ -119,16 +121,16 @@ def test_long_and_wrong_number_MBAR(caplog, testfiles): with pytest.raises(ValueError, match="the number of lambda windows read"): with caplog.at_level(logging.ERROR): _ = extract_u_nk(str(filename), T=300.0) - assert 'the number of lambda windows read' in caplog.text + assert "the number of lambda windows read" in caplog.text def test_no_starting_time(caplog, testfiles): """Test if raise an exception if the starting time is not read""" filename = testfiles["no_starting_simulation_time"][0] - with pytest.raises(ValueError, match='No starting simulation time in file'): + with pytest.raises(ValueError, match="No starting simulation time in file"): with caplog.at_level(logging.ERROR): _ = extract(str(filename), T=298.0) - assert 'No starting simulation time in file' in caplog.text + assert "No starting simulation time in file" in caplog.text def test_parse_without_spaces_around_equal(testfiles): @@ -138,23 +140,24 @@ def test_parse_without_spaces_around_equal(testfiles): """ filename = testfiles["no_spaces_around_equal"][0] df_dict = extract(str(filename), T=298.0) - assert isinstance(df_dict['dHdl'], pd.DataFrame) + assert isinstance(df_dict["dHdl"], pd.DataFrame) ################################################################################## ################ Check the parser behaviour with standard single files ################################################################################## + @pytest.fixture(name="single_u_nk", scope="module") def fixture_single_u_nk(): """return a single file to check u_unk parsing""" - return load_bace_example().data['complex']['vdw'][0] + return load_bace_example().data["complex"]["vdw"][0] @pytest.fixture(name="single_dHdl", scope="module") def fixture_single_dHdl(): """return a single file to check dHdl parsing""" - return load_simplesolvated().data['charge'][0] + return load_simplesolvated().data["charge"][0] def test_dHdl_time_reading(single_dHdl): @@ -175,18 +178,18 @@ def test_extract_with_both_data(single_u_nk): """Test that dHdl and u_nk have the correct form when extracted from files with the single "extract" funcion.""" df_dict = extract(single_u_nk, T=298.0) - assert df_dict['dHdl'].index.names == ('time', 'lambdas') - assert df_dict['dHdl'].shape == (500, 1) - assert df_dict['u_nk'].index.names == ('time', 'lambdas') + assert df_dict["dHdl"].index.names == ("time", "lambdas") + assert df_dict["dHdl"].shape == (500, 1) + assert df_dict["u_nk"].index.names == ("time", "lambdas") def test_extract_with_only_dhdl_data(single_dHdl): """Test that parsing with the extract function a file - with just dHdl gives the correct results""" + with just dHdl gives the correct results""" df_dict = extract(single_dHdl, T=298.0) - assert df_dict['dHdl'].index.names == ('time', 'lambdas') - assert df_dict['dHdl'].shape == (500, 1) - assert df_dict['u_nk'] is None + assert df_dict["dHdl"].index.names == ("time", "lambdas") + assert df_dict["dHdl"].shape == (500, 1) + assert df_dict["u_nk"] is None def test_wrong_T_should_raise_warning(single_dHdl, T=300.0): @@ -195,24 +198,21 @@ def test_wrong_T_should_raise_warning(single_dHdl, T=300.0): read from the AMBER file gives a warning """ with pytest.raises( - ValueError, - match="is different from the temperature passed as parameter"): + ValueError, match="is different from the temperature passed as parameter" + ): _ = extract(single_dHdl, T=T) - ################################################################### ################ Check the behaviour on proper datasets ################################################################### -@pytest.mark.parametrize("filename", - [filename - for leg in load_simplesolvated()['data'].values() - for filename in leg]) -def test_dHdl(filename, - names=('time', 'lambdas'), - shape=(500, 1)): +@pytest.mark.parametrize( + "filename", + [filename for leg in load_simplesolvated()["data"].values() for filename in leg], +) +def test_dHdl(filename, names=("time", "lambdas"), shape=(500, 1)): """Test that dHdl has the correct form when extracted from files.""" dHdl = extract_dHdl(filename, T=298.0) @@ -220,27 +220,33 @@ def test_dHdl(filename, assert dHdl.shape == shape -@pytest.mark.parametrize("mbar_filename", - [mbar_filename - for leg in load_bace_example()['data']['complex'].values() - for mbar_filename in leg]) -def test_u_nk(mbar_filename, - names=('time', 'lambdas')): +@pytest.mark.parametrize( + "mbar_filename", + [ + mbar_filename + for leg in load_bace_example()["data"]["complex"].values() + for mbar_filename in leg + ], +) +def test_u_nk(mbar_filename, names=("time", "lambdas")): """Test the u_nk has the correct form when extracted from files""" u_nk = extract_u_nk(mbar_filename, T=298.0) assert u_nk.index.names == names -@pytest.mark.parametrize("improper_filename", - [improper_filename - for leg in load_bace_improper()['data'].values() - for improper_filename in leg]) -def test_u_nk_improper(improper_filename, - names=('time', 'lambdas')): +@pytest.mark.parametrize( + "improper_filename", + [ + improper_filename + for leg in load_bace_improper()["data"].values() + for improper_filename in leg + ], +) +def test_u_nk_improper(improper_filename, names=("time", "lambdas")): """Test the u_nk has the correct form when extracted from files""" try: u_nk = extract_u_nk(improper_filename, T=298.0) assert u_nk.index.names == names except Exception: - assert '0.5626' in improper_filename + assert "0.5626" in improper_filename diff --git a/src/alchemlyb/tests/parsing/test_gmx.py b/src/alchemlyb/tests/parsing/test_gmx.py index d85ad1bf..1959c925 100644 --- a/src/alchemlyb/tests/parsing/test_gmx.py +++ b/src/alchemlyb/tests/parsing/test_gmx.py @@ -3,118 +3,144 @@ """ import bz2 -import pytest -from alchemlyb.parsing.gmx import extract_dHdl, extract_u_nk, extract +import pytest from alchemtest.gmx import load_benzene -from alchemtest.gmx import load_expanded_ensemble_case_1, load_expanded_ensemble_case_2, load_expanded_ensemble_case_3 -from alchemtest.gmx import load_water_particle_with_total_energy +from alchemtest.gmx import ( + load_expanded_ensemble_case_1, + load_expanded_ensemble_case_2, + load_expanded_ensemble_case_3, +) from alchemtest.gmx import load_water_particle_with_potential_energy +from alchemtest.gmx import load_water_particle_with_total_energy from alchemtest.gmx import load_water_particle_without_energy from numpy.testing import assert_almost_equal +from alchemlyb.parsing.gmx import extract_dHdl, extract_u_nk, extract + def test_dHdl(): - """Test that dHdl has the correct form when extracted from files. - - """ + """Test that dHdl has the correct form when extracted from files.""" dataset = load_benzene() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: dHdl = extract_dHdl(filename, T=300) - assert dHdl.index.names == ['time', 'fep-lambda'] + assert dHdl.index.names == ["time", "fep-lambda"] assert dHdl.shape == (4001, 1) -def test_u_nk(): - """Test that u_nk has the correct form when extracted from files. - """ +def test_u_nk(): + """Test that u_nk has the correct form when extracted from files.""" dataset = load_benzene() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: u_nk = extract_u_nk(filename, T=300) - assert u_nk.index.names == ['time', 'fep-lambda'] - if leg == 'Coulomb': + assert u_nk.index.names == ["time", "fep-lambda"] + if leg == "Coulomb": assert u_nk.shape == (4001, 5) - elif leg == 'VDW': + elif leg == "VDW": assert u_nk.shape == (4001, 16) -def test_u_nk_case1(): - """Test that u_nk has the correct form when extracted from expanded ensemble files (case 1). - """ +def test_u_nk_case1(): + """Test that u_nk has the correct form when extracted from expanded ensemble files (case 1).""" dataset = load_expanded_ensemble_case_1() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: u_nk = extract_u_nk(filename, T=300, filter=False) - assert u_nk.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda'] + assert u_nk.index.names == [ + "time", + "fep-lambda", + "coul-lambda", + "vdw-lambda", + "restraint-lambda", + ] assert u_nk.shape == (50001, 28) -def test_dHdl_case1(): - """Test that dHdl has the correct form when extracted from expanded ensemble files (case 1). - """ +def test_dHdl_case1(): + """Test that dHdl has the correct form when extracted from expanded ensemble files (case 1).""" dataset = load_expanded_ensemble_case_1() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: dHdl = extract_dHdl(filename, T=300, filter=False) - assert dHdl.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda'] + assert dHdl.index.names == [ + "time", + "fep-lambda", + "coul-lambda", + "vdw-lambda", + "restraint-lambda", + ] assert dHdl.shape == (50001, 4) -def test_u_nk_case2(): - """Test that u_nk has the correct form when extracted from expanded ensemble files (case 2). - """ +def test_u_nk_case2(): + """Test that u_nk has the correct form when extracted from expanded ensemble files (case 2).""" dataset = load_expanded_ensemble_case_2() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: u_nk = extract_u_nk(filename, T=300, filter=False) - assert u_nk.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda'] + assert u_nk.index.names == [ + "time", + "fep-lambda", + "coul-lambda", + "vdw-lambda", + "restraint-lambda", + ] assert u_nk.shape == (25001, 28) -def test_u_nk_case3(): - """Test that u_nk has the correct form when extracted from REX files (case 3). - """ +def test_u_nk_case3(): + """Test that u_nk has the correct form when extracted from REX files (case 3).""" dataset = load_expanded_ensemble_case_3() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: u_nk = extract_u_nk(filename, T=300, filter=False) - assert u_nk.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda'] + assert u_nk.index.names == [ + "time", + "fep-lambda", + "coul-lambda", + "vdw-lambda", + "restraint-lambda", + ] assert u_nk.shape == (2500, 28) -def test_dHdl_case3(): - """Test that dHdl has the correct form when extracted from REX files (case 3). - """ +def test_dHdl_case3(): + """Test that dHdl has the correct form when extracted from REX files (case 3).""" dataset = load_expanded_ensemble_case_3() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: dHdl = extract_dHdl(filename, T=300, filter=False) - assert dHdl.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda'] + assert dHdl.index.names == [ + "time", + "fep-lambda", + "coul-lambda", + "vdw-lambda", + "restraint-lambda", + ] assert dHdl.shape == (2500, 4) -def test_u_nk_with_total_energy(): - """Test that the reduced potential is calculated correctly when the total energy is given. - """ +def test_u_nk_with_total_energy(): + """Test that the reduced potential is calculated correctly when the total energy is given.""" # Load dataset dataset = load_water_particle_with_total_energy() @@ -124,15 +150,16 @@ def test_u_nk_with_total_energy(): # Check one specific value in the dataframe assert_almost_equal( - extract_u_nk(dataset['data']['AllStates'][0], T=300).loc[0][(0.0,0.0)].values[0], + extract_u_nk(dataset["data"]["AllStates"][0], T=300) + .loc[0][(0.0, 0.0)] + .values[0], -11211.577658852531, - decimal=6 + decimal=6, ) -def test_u_nk_with_potential_energy(): - """Test that the reduced potential is calculated correctly when the potential energy is given. - """ +def test_u_nk_with_potential_energy(): + """Test that the reduced potential is calculated correctly when the potential energy is given.""" # Load dataset dataset = load_water_particle_with_potential_energy() @@ -142,16 +169,16 @@ def test_u_nk_with_potential_energy(): # Check one specific value in the dataframe assert_almost_equal( - extract_u_nk(dataset['data']['AllStates'][0], T=300).loc[0][(0.0,0.0)].values[0], + extract_u_nk(dataset["data"]["AllStates"][0], T=300) + .loc[0][(0.0, 0.0)] + .values[0], -15656.557252200757, - decimal=6 + decimal=6, ) def test_u_nk_without_energy(): - """Test that the reduced potential is calculated correctly when no energy is given. - - """ + """Test that the reduced potential is calculated correctly when no energy is given.""" # Load dataset dataset = load_water_particle_without_energy() @@ -161,105 +188,114 @@ def test_u_nk_without_energy(): # Check one specific value in the dataframe assert_almost_equal( - extract_u_nk(dataset['data']['AllStates'][0], T=300).loc[0][(0.0,0.0)].values[0], + extract_u_nk(dataset["data"]["AllStates"][0], T=300) + .loc[0][(0.0, 0.0)] + .values[0], 0.0, - decimal=6 + decimal=6, ) def _diag_sum(dataset): - """Calculate the sum of diagonal elements (i, i) - - """ + """Calculate the sum of diagonal elements (i, i)""" # Initialize the sum variable ds = 0.0 - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: u_nk = extract_u_nk(filename, T=300) # Calculate the sum of diagonal elements: for i, lambda_ in enumerate(u_nk.columns): - #18.6 is the time step - ds += u_nk.loc[i*186/10][lambda_].values[0] + # 18.6 is the time step + ds += u_nk.loc[i * 186 / 10][lambda_].values[0] return ds + def test_extract_u_nk_unit(): - '''Test if extract_u_nk assign the attr correctly''' + """Test if extract_u_nk assign the attr correctly""" dataset = load_benzene() - u_nk = extract_u_nk(dataset['data']['Coulomb'][0], 310) - assert u_nk.attrs['temperature'] == 310 - assert u_nk.attrs['energy_unit'] == 'kT' + u_nk = extract_u_nk(dataset["data"]["Coulomb"][0], 310) + assert u_nk.attrs["temperature"] == 310 + assert u_nk.attrs["energy_unit"] == "kT" + def test_extract_dHdl_unit(): - '''Test if extract_u_nk assign the attr correctly''' + """Test if extract_u_nk assign the attr correctly""" dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) - assert dhdl.attrs['temperature'] == 310 - assert dhdl.attrs['energy_unit'] == 'kT' + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) + assert dhdl.attrs["temperature"] == 310 + assert dhdl.attrs["energy_unit"] == "kT" + def test_calling_extract(): - '''Test if the extract function is working''' + """Test if the extract function is working""" dataset = load_benzene() - df_dict = extract(dataset['data']['Coulomb'][0], 310) - assert df_dict['dHdl'].attrs['temperature'] == 310 - assert df_dict['dHdl'].attrs['energy_unit'] == 'kT' - assert df_dict['u_nk'].attrs['temperature'] == 310 - assert df_dict['u_nk'].attrs['energy_unit'] == 'kT' - -class TestRobustGMX(): - '''Test dropping the row that is wrong in different way''' + df_dict = extract(dataset["data"]["Coulomb"][0], 310) + assert df_dict["dHdl"].attrs["temperature"] == 310 + assert df_dict["dHdl"].attrs["energy_unit"] == "kT" + assert df_dict["u_nk"].attrs["temperature"] == 310 + assert df_dict["u_nk"].attrs["energy_unit"] == "kT" + + +class TestRobustGMX: + """Test dropping the row that is wrong in different way""" + @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def data(): - dhdl = extract_dHdl(load_benzene()['data']['Coulomb'][0], 310) - with bz2.open(load_benzene()['data']['Coulomb'][0], "rt") as bz_file: + dhdl = extract_dHdl(load_benzene()["data"]["Coulomb"][0], 310) + with bz2.open(load_benzene()["data"]["Coulomb"][0], "rt") as bz_file: text = bz_file.read() return text, len(dhdl) def test_sanity(self, data, tmp_path): - '''Test if the test routine is working.''' + """Test if the test routine is working.""" text, length = data - new_text = tmp_path / 'text.xvg' + new_text = tmp_path / "text.xvg" new_text.write_text(text) dhdl = extract_dHdl(new_text, 310) assert len(dhdl) == length def test_truncated_row(self, data, tmp_path): - '''Test the case where the last row has been truncated.''' + """Test the case where the last row has been truncated.""" text, length = data - new_text = tmp_path / 'text.xvg' - new_text.write_text(text + '40010.0 27.0\n') + new_text = tmp_path / "text.xvg" + new_text.write_text(text + "40010.0 27.0\n") dhdl = extract_dHdl(new_text, 310, filter=True) assert len(dhdl) == length def test_truncated_number(self, data, tmp_path): - '''Test the case where the last row has been truncated and a - has - been left.''' + """Test the case where the last row has been truncated and a - has + been left.""" text, length = data - new_text = tmp_path / 'text.xvg' - new_text.write_text(text + '40010.0 27.0 -\n') + new_text = tmp_path / "text.xvg" + new_text.write_text(text + "40010.0 27.0 -\n") dhdl = extract_dHdl(new_text, 310, filter=True) assert len(dhdl) == length def test_weirdnumber(self, data, tmp_path): - '''Test the case where the last number has been appended a weird - number.''' + """Test the case where the last number has been appended a weird + number.""" text, length = data - new_text = tmp_path / 'text.xvg' + new_text = tmp_path / "text.xvg" # Note the 27.040010.0 which is the sum of 27.0 and 40010.0 - new_text.write_text(text + '40010.0 27.040010.0 27.0 0.0 6.7 13.5 20.2 27.0 0.7 27.0 0.0 6.7 ' - '13.5 20.2 27.0 0.7\n') + new_text.write_text( + text + "40010.0 27.040010.0 27.0 0.0 6.7 13.5 20.2 27.0 0.7 27.0 0.0 6.7 " + "13.5 20.2 27.0 0.7\n" + ) dhdl = extract_dHdl(new_text, 310, filter=True) assert len(dhdl) == length def test_too_many_cols(self, data, tmp_path): - '''Test the case where the row has too many columns.''' + """Test the case where the row has too many columns.""" text, length = data - new_text = tmp_path / 'text.xvg' - new_text.write_text(text + - '40010.0 27.0 0.0 6.7 13.5 20.2 27.0 0.7 27.0 0.0 6.7 13.5 20.2 27.0 0.7\n') + new_text = tmp_path / "text.xvg" + new_text.write_text( + text + + "40010.0 27.0 0.0 6.7 13.5 20.2 27.0 0.7 27.0 0.0 6.7 13.5 20.2 27.0 0.7\n" + ) dhdl = extract_dHdl(new_text, 310, filter=True) assert len(dhdl) == length diff --git a/src/alchemlyb/tests/parsing/test_gomc.py b/src/alchemlyb/tests/parsing/test_gomc.py index d61b3789..241add45 100644 --- a/src/alchemlyb/tests/parsing/test_gomc.py +++ b/src/alchemlyb/tests/parsing/test_gomc.py @@ -2,43 +2,40 @@ """ -from alchemlyb.parsing.gomc import extract_dHdl, extract_u_nk, extract from alchemtest.gomc import load_benzene +from alchemlyb.parsing.gomc import extract_dHdl, extract_u_nk, extract + def test_dHdl(): - """Test that dHdl has the correct form when extracted from files. - - """ + """Test that dHdl has the correct form when extracted from files.""" dataset = load_benzene() - for filename in dataset['data']: + for filename in dataset["data"]: dHdl = extract_dHdl(filename, T=298) - assert dHdl.index.names == ['time', 'Coulomb-lambda', 'VDW-lambda'] + assert dHdl.index.names == ["time", "Coulomb-lambda", "VDW-lambda"] assert dHdl.shape == (1000, 2) -def test_u_nk(): - """Test that u_nk has the correct form when extracted from files. - """ +def test_u_nk(): + """Test that u_nk has the correct form when extracted from files.""" dataset = load_benzene() - for filename in dataset['data']: + for filename in dataset["data"]: u_nk = extract_u_nk(filename, T=298) - assert u_nk.index.names == ['time', 'Coulomb-lambda', 'VDW-lambda'] + assert u_nk.index.names == ["time", "Coulomb-lambda", "VDW-lambda"] assert u_nk.shape == (1000, 23) -def test_extract(): - """Test that u_nk and dHdl have the correct form when extracted from files. - """ +def test_extract(): + """Test that u_nk and dHdl have the correct form when extracted from files.""" dataset = load_benzene() - df_dict = extract(dataset['data'][0], T=298) + df_dict = extract(dataset["data"][0], T=298) - assert df_dict['u_nk'].index.names == ['time', 'Coulomb-lambda', 'VDW-lambda'] - assert df_dict['u_nk'].shape == (1000, 23) - assert df_dict['dHdl'].index.names == ['time', 'Coulomb-lambda', 'VDW-lambda'] - assert df_dict['dHdl'].shape == (1000, 2) + assert df_dict["u_nk"].index.names == ["time", "Coulomb-lambda", "VDW-lambda"] + assert df_dict["u_nk"].shape == (1000, 23) + assert df_dict["dHdl"].index.names == ["time", "Coulomb-lambda", "VDW-lambda"] + assert df_dict["dHdl"].shape == (1000, 2) diff --git a/src/alchemlyb/tests/parsing/test_namd.py b/src/alchemlyb/tests/parsing/test_namd.py index 8c4c1858..f5168e8f 100644 --- a/src/alchemlyb/tests/parsing/test_namd.py +++ b/src/alchemlyb/tests/parsing/test_namd.py @@ -1,16 +1,17 @@ """NAMD parser tests. """ +import bz2 from os.path import basename from re import search -import bz2 -import pytest -from alchemlyb.parsing.namd import extract_u_nk, extract -from alchemtest.namd import load_tyr2ala +import pytest from alchemtest.namd import load_idws from alchemtest.namd import load_restarted from alchemtest.namd import load_restarted_reversed +from alchemtest.namd import load_tyr2ala + +from alchemlyb.parsing.namd import extract_u_nk, extract # Indices of lambda values in the following line in NAMD fepout files: # #NEW FEP WINDOW: LAMBDA SET TO 0.6 LAMBDA2 0.7 LAMBDA_IDWS 0.5 @@ -27,27 +28,30 @@ def dataset(): return load_tyr2ala() -@pytest.mark.parametrize("direction,shape", - [('forward', (21021, 21)), - ('backward', (21021, 21)), - ]) + +@pytest.mark.parametrize( + "direction,shape", + [ + ("forward", (21021, 21)), + ("backward", (21021, 21)), + ], +) def test_u_nk(dataset, direction, shape): - """Test that u_nk has the correct form when extracted from files. - """ - for filename in dataset['data'][direction]: + """Test that u_nk has the correct form when extracted from files.""" + for filename in dataset["data"][direction]: u_nk = extract_u_nk(filename, T=300) - assert u_nk.index.names == ['time', 'fep-lambda'] + assert u_nk.index.names == ["time", "fep-lambda"] assert u_nk.shape == shape + def test_u_nk_idws(): - """Test that u_nk has the correct form when extracted from files. - """ + """Test that u_nk has the correct form when extracted from files.""" - filenames = load_idws()['data']['forward'] + filenames = load_idws()["data"]["forward"] u_nk = extract_u_nk(filenames, T=300) - assert u_nk.index.names == ['time', 'fep-lambda'] + assert u_nk.index.names == ["time", "fep-lambda"] assert u_nk.shape == (29252, 11) @@ -64,7 +68,7 @@ def _corrupt_fepout(fepout_in, params, tmp_path): ---------- fepout_in: str Path to fepout file to be modified. This file will not be overwritten. - + params: list of tuples For each tuple, the first element must be a str that will be passed to startswith() to identify the line(s) to modify (e.g. "#NEW"). The @@ -82,13 +86,17 @@ def _corrupt_fepout(fepout_in, params, tmp_path): """ fepout_out = tmp_path / basename(fepout_in) - with bz2.open(fepout_out, 'wt') as f_out: - with bz2.open(fepout_in, 'rt') as f_in: + with bz2.open(fepout_out, "wt") as f_out: + with bz2.open(fepout_in, "rt") as f_in: for line in f_in: for prefix, func in params: if line.startswith(prefix): tokens_out = func(line.split()) - line = ' '.join(tokens_out) + '\n' if tokens_out is not None else None + line = ( + " ".join(tokens_out) + "\n" + if tokens_out is not None + else None + ) if line is not None: f_out.write(line) return str(fepout_out) @@ -99,9 +107,10 @@ def restarted_dataset_inconsistent(restarted_dataset, tmp_path): """Returns intentionally messed up dataset where lambda1 and lambda2 at start and end of a window are different.""" - filenames = sorted(restarted_dataset['data']['both']) + filenames = sorted(restarted_dataset["data"]["both"]) changed = False + def func_free_line(l): nonlocal changed if float(l[7]) >= 0.7 and float(l[7]) < 0.9: @@ -110,13 +119,15 @@ def func_free_line(l): return l for i in range(len(filenames)): - filenames[i] = _corrupt_fepout(filenames[i], [('#Free', func_free_line)], tmp_path) + filenames[i] = _corrupt_fepout( + filenames[i], [("#Free", func_free_line)], tmp_path + ) # Only actually modify one window so we don't trigger the wrong exception if changed is True: break # Don't directly modify the glob object - restarted_dataset['data']['both'] = filenames + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -129,26 +140,32 @@ def restarted_dataset_idws_without_lambda_idws(restarted_dataset, tmp_path): # First window won't have any IDWS data so we just drop all its files and fudge the lambdas # in the next window to include 0.0 or 1.0 (as appropriate) so we still have a nominally complete calculation - - filenames = [x for x in sorted(restarted_dataset['data']['both']) if search('000[a-z]?.fepout', x) is None] + + filenames = [ + x + for x in sorted(restarted_dataset["data"]["both"]) + if search("000[a-z]?.fepout", x) is None + ] def func_new_line(l): - if float(l[LAMBDA1_IDX_NEW]) > 0.5: # 1->0 (reversed) calculation - l[LAMBDA1_IDX_NEW] == '1.0' - else: # regular 0->1 calculation - l[LAMBDA1_IDX_NEW] = '0.0' + if float(l[LAMBDA1_IDX_NEW]) > 0.5: # 1->0 (reversed) calculation + l[LAMBDA1_IDX_NEW] == "1.0" + else: # regular 0->1 calculation + l[LAMBDA1_IDX_NEW] = "0.0" # Drop the lambda_idws return l[:9] - + def func_free_line(l): - if float(l[LAMBDA1_IDX_FREE]) > 0.5: # 1->0 (reversed) calculation - l[LAMBDA1_IDX_FREE] == '1.0' - else: # regular 0->1 calculation - l[LAMBDA1_IDX_FREE] = '0.0' + if float(l[LAMBDA1_IDX_FREE]) > 0.5: # 1->0 (reversed) calculation + l[LAMBDA1_IDX_FREE] == "1.0" + else: # regular 0->1 calculation + l[LAMBDA1_IDX_FREE] = "0.0" return l - - filenames[0] = _corrupt_fepout(filenames[0], [('#NEW', func_new_line), ('#Free', func_free_line)], tmp_path) - restarted_dataset['data']['both'] = filenames + + filenames[0] = _corrupt_fepout( + filenames[0], [("#NEW", func_new_line), ("#Free", func_free_line)], tmp_path + ) + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -157,7 +174,7 @@ def restarted_dataset_toomany_lambda2(restarted_dataset, tmp_path): """Returns intentionally messed up dataset, where there are too many lambda2 values for a given lambda1.""" - filenames = sorted(restarted_dataset['data']['both']) + filenames = sorted(restarted_dataset["data"]["both"]) # For the same l1 and lidws we retain old lambda2 values thus ensuring a collision # Also, don't make a window where lambda1 >= lambda2 because this will trigger the @@ -165,22 +182,23 @@ def restarted_dataset_toomany_lambda2(restarted_dataset, tmp_path): def func_new_line(l): if float(l[LAMBDA2_IDX_NEW]) <= 0.2: return l - l[LAMBDA1_IDX_NEW] = '0.2' - if len(l) > 9 and l[9] == 'LAMBDA_IDWS': - l[LAMBDA_IDWS_IDX_NEW] = '0.1' + l[LAMBDA1_IDX_NEW] = "0.2" + if len(l) > 9 and l[9] == "LAMBDA_IDWS": + l[LAMBDA_IDWS_IDX_NEW] = "0.1" return l def func_free_line(l): if float(l[LAMBDA2_IDX_FREE]) <= 0.2: return l - l[LAMBDA1_IDX_FREE] = '0.2' + l[LAMBDA1_IDX_FREE] = "0.2" return l for i in range(len(filenames)): - filenames[i] = \ - _corrupt_fepout(filenames[i], [('#NEW', func_new_line), ('#Free', func_free_line)], tmp_path) + filenames[i] = _corrupt_fepout( + filenames[i], [("#NEW", func_new_line), ("#Free", func_free_line)], tmp_path + ) - restarted_dataset['data']['both'] = filenames + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -189,7 +207,7 @@ def restarted_dataset_toomany_lambda_idws(restarted_dataset, tmp_path): """Returns intentionally messed up dataset, where there are too many lambda2 values for a given lambda1.""" - filenames = sorted(restarted_dataset['data']['both']) + filenames = sorted(restarted_dataset["data"]["both"]) # For the same lambda1 and lambda2 we retain the first set of lambda1/lambda2 values # and replicate them across all windows thus ensuring that there will be more than @@ -198,7 +216,7 @@ def restarted_dataset_toomany_lambda_idws(restarted_dataset, tmp_path): def func_new_line(l): nonlocal this_lambda1, this_lambda2 - + if this_lambda1 is None: this_lambda1, this_lambda2 = l[LAMBDA1_IDX_NEW], l[LAMBDA2_IDX_NEW] # Ensure that changing these lambda values won't cause a reversal in direction and trigger @@ -212,9 +230,11 @@ def func_free_line(l): return l for i in range(len(filenames)): - filenames[i] = _corrupt_fepout(filenames[i], [('#NEW', func_new_line)], tmp_path) + filenames[i] = _corrupt_fepout( + filenames[i], [("#NEW", func_new_line)], tmp_path + ) - restarted_dataset['data']['both'] = filenames + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -222,7 +242,7 @@ def func_free_line(l): def restarted_dataset_direction_changed(restarted_dataset, tmp_path): """Returns intentionally messed up dataset, with one window where the lambda values are reversed.""" - filenames = sorted(restarted_dataset['data']['both']) + filenames = sorted(restarted_dataset["data"]["both"]) def func_new_line(l): l[6], l[8], l[10] = l[10], l[8], l[6] @@ -231,12 +251,16 @@ def func_new_line(l): def func_free_line(l): l[7], l[8] = l[8], l[7] return l - + # Reverse the direction of lambdas for this window idx_to_corrupt = filenames.index(sorted(filenames)[-3]) - fname1 = _corrupt_fepout(filenames[idx_to_corrupt], [('#NEW', func_new_line), ('#Free', func_free_line)], tmp_path) + fname1 = _corrupt_fepout( + filenames[idx_to_corrupt], + [("#NEW", func_new_line), ("#Free", func_free_line)], + tmp_path, + ) filenames[idx_to_corrupt] = fname1 - restarted_dataset['data']['both'] = filenames + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -244,15 +268,17 @@ def func_free_line(l): def restarted_dataset_all_windows_truncated(restarted_dataset, tmp_path): """Returns dataset where all windows are truncated (no #Free... footer lines).""" - filenames = sorted(restarted_dataset['data']['both']) + filenames = sorted(restarted_dataset["data"]["both"]) def func_free_line(l): return None for i in range(len(filenames)): - filenames[i] = _corrupt_fepout(filenames[i], [('#Free', func_free_line)], tmp_path) - - restarted_dataset['data']['both'] = filenames + filenames[i] = _corrupt_fepout( + filenames[i], [("#Free", func_free_line)], tmp_path + ) + + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -260,13 +286,15 @@ def func_free_line(l): def restarted_dataset_last_window_truncated(restarted_dataset, tmp_path): """Returns dataset where the last window is truncated (no #Free... footer line).""" - filenames = sorted(restarted_dataset['data']['both']) + filenames = sorted(restarted_dataset["data"]["both"]) def func_free_line(l): return None - filenames[-1] = _corrupt_fepout(filenames[-1], [('#Free', func_free_line)], tmp_path) - restarted_dataset['data']['both'] = filenames + filenames[-1] = _corrupt_fepout( + filenames[-1], [("#Free", func_free_line)], tmp_path + ) + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -274,72 +302,91 @@ def test_u_nk_restarted(): """Test that u_nk has the correct form when extracted from an IDWS FEP run that includes terminations and restarts. """ - filenames = load_restarted()['data']['both'] + filenames = load_restarted()["data"]["both"] u_nk = extract_u_nk(filenames, T=300) - - assert u_nk.index.names == ['time', 'fep-lambda'] + + assert u_nk.index.names == ["time", "fep-lambda"] assert u_nk.shape == (30061, 11) def test_u_nk_restarted_missing_window_header(tmp_path): """Test that u_nk has the correct form when a #NEW line is missing from the restarted dataset and the parser has to infer lambda_idws for that window.""" - filenames = sorted(load_restarted()['data']['both']) + filenames = sorted(load_restarted()["data"]["both"]) # Remove "#NEW" line - filenames[4] = _corrupt_fepout(filenames[4], [('#NEW', lambda l: None),], tmp_path) + filenames[4] = _corrupt_fepout( + filenames[4], + [ + ("#NEW", lambda l: None), + ], + tmp_path, + ) u_nk = extract_u_nk(filenames, T=300) - - assert u_nk.index.names == ['time', 'fep-lambda'] + + assert u_nk.index.names == ["time", "fep-lambda"] assert u_nk.shape == (30061, 11) def test_u_nk_restarted_reversed(): - filenames = load_restarted_reversed()['data']['both'] + filenames = load_restarted_reversed()["data"]["both"] u_nk = extract_u_nk(filenames, T=300) - - assert u_nk.index.names == ['time', 'fep-lambda'] + + assert u_nk.index.names == ["time", "fep-lambda"] assert u_nk.shape == (30170, 11) def test_extract(): - filenames = load_restarted_reversed()['data']['both'] + filenames = load_restarted_reversed()["data"]["both"] df_dict = extract(filenames, T=300) - assert df_dict['u_nk'].index.names == ['time', 'fep-lambda'] - assert df_dict['u_nk'].shape == (30170, 11) - assert 'dHdl' not in df_dict + assert df_dict["u_nk"].index.names == ["time", "fep-lambda"] + assert df_dict["u_nk"].shape == (30170, 11) + assert "dHdl" not in df_dict def test_u_nk_restarted_reversed_missing_window_header(tmp_path): """Test that u_nk has the correct form when a #NEW line is missing from the restarted_reversed dataset and the parser has to infer lambda_idws for that window.""" - filenames = sorted(load_restarted_reversed()['data']['both']) + filenames = sorted(load_restarted_reversed()["data"]["both"]) # Remove "#NEW" line - filenames[4] = _corrupt_fepout(filenames[4], [('#NEW', lambda l: None),], tmp_path) + filenames[4] = _corrupt_fepout( + filenames[4], + [ + ("#NEW", lambda l: None), + ], + tmp_path, + ) u_nk = extract_u_nk(filenames, T=300) - - assert u_nk.index.names == ['time', 'fep-lambda'] + + assert u_nk.index.names == ["time", "fep-lambda"] assert u_nk.shape == (30170, 11) def test_u_nk_restarted_direction_changed(restarted_dataset_direction_changed): """Test that when lambda values change direction within a dataset, parsing throws an error.""" - with pytest.raises(ValueError, match='Lambda values change direction'): - u_nk = extract_u_nk(restarted_dataset_direction_changed['data']['both'], T=300) + with pytest.raises(ValueError, match="Lambda values change direction"): + u_nk = extract_u_nk(restarted_dataset_direction_changed["data"]["both"], T=300) -def test_u_nk_restarted_idws_without_lambda_idws(restarted_dataset_idws_without_lambda_idws): +def test_u_nk_restarted_idws_without_lambda_idws( + restarted_dataset_idws_without_lambda_idws, +): """Test that when the first window has IDWS data but no lambda_idws, parsing throws an error. - + In this situation, the lambda_idws cannot be inferred, because there's no previous lambda value available. """ - with pytest.raises(ValueError, match='IDWS data present in first window but lambda_idws not included'): - u_nk = extract_u_nk(restarted_dataset_idws_without_lambda_idws['data']['both'], T=300) + with pytest.raises( + ValueError, + match="IDWS data present in first window but lambda_idws not included", + ): + u_nk = extract_u_nk( + restarted_dataset_idws_without_lambda_idws["data"]["both"], T=300 + ) def test_u_nk_restarted_inconsistent(restarted_dataset_inconsistent): @@ -347,33 +394,45 @@ def test_u_nk_restarted_inconsistent(restarted_dataset_inconsistent): parsing throws an error. """ - with pytest.raises(ValueError, match='Inconsistent lambda values within the same window'): - u_nk = extract_u_nk(restarted_dataset_inconsistent['data']['both'], T=300) + with pytest.raises( + ValueError, match="Inconsistent lambda values within the same window" + ): + u_nk = extract_u_nk(restarted_dataset_inconsistent["data"]["both"], T=300) def test_u_nk_restarted_toomany_lambda_idws(restarted_dataset_toomany_lambda_idws): """Test that when there is more than one lambda_idws for a given lambda1, parsing throws an error.""" - with pytest.raises(ValueError, match='More than one lambda_idws value for a particular lambda1'): - u_nk = extract_u_nk(restarted_dataset_toomany_lambda_idws['data']['both'], T=300) + with pytest.raises( + ValueError, match="More than one lambda_idws value for a particular lambda1" + ): + u_nk = extract_u_nk( + restarted_dataset_toomany_lambda_idws["data"]["both"], T=300 + ) def test_u_nk_restarted_toomany_lambda2(restarted_dataset_toomany_lambda2): """Test that when there is more than one lambda2 for a given lambda1, parsing throws an error.""" - with pytest.raises(ValueError, match='More than one lambda2 value for a particular lambda1'): - u_nk = extract_u_nk(restarted_dataset_toomany_lambda2['data']['both'], T=300) + with pytest.raises( + ValueError, match="More than one lambda2 value for a particular lambda1" + ): + u_nk = extract_u_nk(restarted_dataset_toomany_lambda2["data"]["both"], T=300) def test_u_nk_restarted_all_windows_truncated(restarted_dataset_all_windows_truncated): """Test that when there is more than one lambda2 for a given lambda1, parsing throws an error.""" - with pytest.raises(ValueError, match='New window begun after truncated window'): - u_nk = extract_u_nk(restarted_dataset_all_windows_truncated['data']['both'], T=300) + with pytest.raises(ValueError, match="New window begun after truncated window"): + u_nk = extract_u_nk( + restarted_dataset_all_windows_truncated["data"]["both"], T=300 + ) def test_u_nk_restarted_last_window_truncated(restarted_dataset_last_window_truncated): """Test that when there is more than one lambda2 for a given lambda1, parsing throws an error.""" - with pytest.raises(ValueError, match='Last window is truncated'): - u_nk = extract_u_nk(restarted_dataset_last_window_truncated['data']['both'], T=300) + with pytest.raises(ValueError, match="Last window is truncated"): + u_nk = extract_u_nk( + restarted_dataset_last_window_truncated["data"]["both"], T=300 + ) diff --git a/src/alchemlyb/tests/parsing/test_util.py b/src/alchemlyb/tests/parsing/test_util.py index 334ee0c2..85107d61 100644 --- a/src/alchemlyb/tests/parsing/test_util.py +++ b/src/alchemlyb/tests/parsing/test_util.py @@ -1,32 +1,29 @@ import io -import pytest +import pytest from alchemtest.gmx import load_expanded_ensemble_case_1 + from alchemlyb.parsing.util import anyopen def test_gzip(): - """Test that gzip reads .gz files in the correct (text) mode. - - """ + """Test that gzip reads .gz files in the correct (text) mode.""" dataset = load_expanded_ensemble_case_1() - for leg in dataset['data']: - for filename in dataset['data'][leg]: - with anyopen(filename, 'r') as f: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: + with anyopen(filename, "r") as f: assert type(f.readline()) is str def test_gzip_stream(): - """Test that `anyopen` reads streams with specified compression. - - """ + """Test that `anyopen` reads streams with specified compression.""" dataset = load_expanded_ensemble_case_1() - for leg in dataset['data']: - for filename in dataset['data'][leg]: - with open(filename, 'rb') as f: - with anyopen(f, mode='r', compression='gzip') as f_uc: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: + with open(filename, "rb") as f: + with anyopen(f, mode="r", compression="gzip") as f_uc: assert type(f_uc.readline()) is str @@ -37,11 +34,11 @@ def test_gzip_stream_wrong(): """ dataset = load_expanded_ensemble_case_1() - for leg in dataset['data']: - for filename in dataset['data'][leg]: - with open(filename, 'rb') as f: - with anyopen(f, mode='r', compression='bzip2') as f_uc: - with pytest.raises(OSError, match='Invalid data stream'): + for leg in dataset["data"]: + for filename in dataset["data"][leg]: + with open(filename, "rb") as f: + with anyopen(f, mode="r", compression="bzip2") as f_uc: + with pytest.raises(OSError, match="Invalid data stream"): assert type(f_uc.readline()) is str @@ -52,33 +49,30 @@ def test_gzip_stream_wrong_no_compression(): """ dataset = load_expanded_ensemble_case_1() - for leg in dataset['data']: - for filename in dataset['data'][leg]: - with open(filename, 'rb') as f: - with anyopen(f, mode='r') as f_uc: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: + with open(filename, "rb") as f: + with anyopen(f, mode="r") as f_uc: assert type(f_uc.readline()) is bytes -@pytest.mark.parametrize('extension', ['bz2', 'gz']) +@pytest.mark.parametrize("extension", ["bz2", "gz"]) def test_file_roundtrip(extension, tmp_path): - """Test that roundtripping write/read to a file works with `anyopen`. - - """ + """Test that roundtripping write/read to a file works with `anyopen`.""" data = "my momma told me to pick the very best one and you are not it" - filepath = tmp_path / f'testfile.txt.{extension}' - with anyopen(filepath, mode='w') as f: + filepath = tmp_path / f"testfile.txt.{extension}" + with anyopen(filepath, mode="w") as f: f.write(data) - with anyopen(filepath, 'r') as f: + with anyopen(filepath, "r") as f: data_out = f.read() assert data_out == data -@pytest.mark.parametrize('extension,compression', - [('bz2', 'gzip'), ('gz', 'bzip2')]) +@pytest.mark.parametrize("extension,compression", [("bz2", "gzip"), ("gz", "bzip2")]) def test_file_roundtrip_force_compression(extension, compression, tmp_path): """Test that roundtripping write/read to a file works with `anyopen`, in which we force compression despite different extension. @@ -87,50 +81,45 @@ def test_file_roundtrip_force_compression(extension, compression, tmp_path): data = "my momma told me to pick the very best one and you are not it" - filepath = tmp_path / f'testfile.txt.{extension}' - with anyopen(filepath, mode='w', compression=compression) as f: + filepath = tmp_path / f"testfile.txt.{extension}" + with anyopen(filepath, mode="w", compression=compression) as f: f.write(data) - with anyopen(filepath, 'r', compression=compression) as f: + with anyopen(filepath, "r", compression=compression) as f: data_out = f.read() assert data_out == data -@pytest.mark.parametrize('compression', ['bzip2', 'gzip']) +@pytest.mark.parametrize("compression", ["bzip2", "gzip"]) def test_stream_roundtrip(compression): - """Test that roundtripping write/read to a stream works with `anyopen` - - """ + """Test that roundtripping write/read to a stream works with `anyopen`""" data = "my momma told me to pick the very best one and you are not it" with io.BytesIO() as stream: - # write to stream - with anyopen(stream, mode='w', compression=compression) as f: + with anyopen(stream, mode="w", compression=compression) as f: f.write(data) # start at the beginning stream.seek(0) # read from stream - with anyopen(stream, 'r', compression=compression) as f: + with anyopen(stream, "r", compression=compression) as f: data_out = f.read() assert data_out == data -def test_stream_unsupported_compression(): - """Test that we throw a ValueError when an unsupported compression is used. - """ +def test_stream_unsupported_compression(): + """Test that we throw a ValueError when an unsupported compression is used.""" - compression="fakez" + compression = "fakez" data = b"my momma told me to pick the very best one and you are not it" with io.BytesIO() as stream: - # write to stream stream.write(data) @@ -139,5 +128,5 @@ def test_stream_unsupported_compression(): # read from stream with pytest.raises(ValueError): - with anyopen(stream, 'r', compression=compression) as f: + with anyopen(stream, "r", compression=compression) as f: data_out = f.read() diff --git a/src/alchemlyb/tests/test_convergence.py b/src/alchemlyb/tests/test_convergence.py index bae2b350..3a774c31 100644 --- a/src/alchemlyb/tests/test_convergence.py +++ b/src/alchemlyb/tests/test_convergence.py @@ -1,93 +1,137 @@ import numpy as np import pandas as pd import pytest - from alchemtest.gmx import load_benzene -from alchemlyb.parsing import gmx + from alchemlyb.convergence import forward_backward_convergence, fwdrev_cumavg_Rc, A_c from alchemlyb.convergence.convergence import _cummean +from alchemlyb.parsing import gmx @pytest.fixture() def gmx_benzene(): dataset = load_benzene() - return [gmx.extract_dHdl(dhdl, T=300) for dhdl in dataset['data']['Coulomb']], \ - [gmx.extract_u_nk(dhdl, T=300) for dhdl in dataset['data']['Coulomb']] + return [gmx.extract_dHdl(dhdl, T=300) for dhdl in dataset["data"]["Coulomb"]], [ + gmx.extract_u_nk(dhdl, T=300) for dhdl in dataset["data"]["Coulomb"] + ] + def test_convergence_ti(gmx_benzene): dHdl, u_nk = gmx_benzene - convergence = forward_backward_convergence(dHdl, 'TI') + convergence = forward_backward_convergence(dHdl, "TI") assert convergence.shape == (10, 5) - assert convergence.loc[0, 'Forward'] == pytest.approx(3.07, 0.01) - assert convergence.loc[0, 'Backward'] == pytest.approx(3.11, 0.01) - assert convergence.loc[9, 'Forward'] == pytest.approx(3.09, 0.01) - assert convergence.loc[9, 'Backward'] == pytest.approx(3.09, 0.01) + assert convergence.loc[0, "Forward"] == pytest.approx(3.07, 0.01) + assert convergence.loc[0, "Backward"] == pytest.approx(3.11, 0.01) + assert convergence.loc[9, "Forward"] == pytest.approx(3.09, 0.01) + assert convergence.loc[9, "Backward"] == pytest.approx(3.09, 0.01) + -@pytest.mark.parametrize('estimator', ['MBAR', 'BAR']) +@pytest.mark.parametrize("estimator", ["MBAR", "BAR"]) def test_convergence_fep(gmx_benzene, estimator): dHdl, u_nk = gmx_benzene convergence = forward_backward_convergence(u_nk, estimator) assert convergence.shape == (10, 5) - assert convergence.loc[0, 'Forward'] == pytest.approx(3.02, 0.01) - assert convergence.loc[0, 'Backward'] == pytest.approx(3.06, 0.01) - assert convergence.loc[9, 'Forward'] == pytest.approx(3.05, 0.01) - assert convergence.loc[9, 'Backward'] == pytest.approx(3.04, 0.01) + assert convergence.loc[0, "Forward"] == pytest.approx(3.02, 0.01) + assert convergence.loc[0, "Backward"] == pytest.approx(3.06, 0.01) + assert convergence.loc[9, "Forward"] == pytest.approx(3.05, 0.01) + assert convergence.loc[9, "Backward"] == pytest.approx(3.04, 0.01) + def test_convergence_wrong_estimator(gmx_benzene): dHdl, u_nk = gmx_benzene with pytest.raises(ValueError, match="is not available in"): - forward_backward_convergence(u_nk, 'WWW') + forward_backward_convergence(u_nk, "WWW") + def test_convergence_wrong_cases(gmx_benzene): dHdl, u_nk = gmx_benzene with pytest.warns(DeprecationWarning, match="Using lower-case strings for"): - forward_backward_convergence(u_nk, 'mbar') + forward_backward_convergence(u_nk, "mbar") + def test_convergence_method(gmx_benzene): dHdl, u_nk = gmx_benzene - convergence = forward_backward_convergence(u_nk, 'MBAR', num=2, method='adaptive') + convergence = forward_backward_convergence(u_nk, "MBAR", num=2, method="adaptive") assert len(convergence) == 2 + def test_cummean_short(): - '''Test the case where the input is shorter than the expected output''' + """Test the case where the input is shorter than the expected output""" value = _cummean(np.empty(10), 100) assert len(value) == 10 + def test_cummean_long(): - '''Test the case where the input is longer than the expected output''' + """Test the case where the input is longer than the expected output""" value = _cummean(np.empty(20), 10) assert len(value) == 10 + def test_cummean_long_none_integter(): - '''Test the case where the input is not a integer multiple of the expected output''' + """Test the case where the input is not a integer multiple of the expected output""" value = _cummean(np.empty(25), 10) assert len(value) == 10 + def test_R_c_converged(): - data = pd.Series(data=[0,]*100) - data.attrs['temperature'] = 310 - data.attrs['energy_unit'] = 'kcal/mol' + data = pd.Series( + data=[ + 0, + ] + * 100 + ) + data.attrs["temperature"] = 310 + data.attrs["energy_unit"] = "kcal/mol" value, running_average = fwdrev_cumavg_Rc(data) np.testing.assert_allclose(value, 0.0) + def test_R_c_notconverged(): data = pd.Series(data=range(21)) - data.attrs['temperature'] = 310 - data.attrs['energy_unit'] = 'kcal/mol' + data.attrs["temperature"] = 310 + data.attrs["energy_unit"] = "kcal/mol" value, running_average = fwdrev_cumavg_Rc(data, tol=0.1, precision=0.05) np.testing.assert_allclose(value, 1.0) + def test_R_c_real(): - data = pd.Series(data=np.hstack((range(10), [4.5,]*10))) - data.attrs['temperature'] = 310 - data.attrs['energy_unit'] = 'kcal/mol' + data = pd.Series( + data=np.hstack( + ( + range(10), + [ + 4.5, + ] + * 10, + ) + ) + ) + data.attrs["temperature"] = 310 + data.attrs["energy_unit"] = "kcal/mol" value, running_average = fwdrev_cumavg_Rc(data, tol=2.0) np.testing.assert_allclose(value, 0.35) + def test_A_c_real(): - data = pd.Series(data=np.hstack((range(10), [4.5,]*10))) - data.attrs['temperature'] = 310 - data.attrs['energy_unit'] = 'kcal/mol' - value = A_c([data, ] * 2, tol=2.0) + data = pd.Series( + data=np.hstack( + ( + range(10), + [ + 4.5, + ] + * 10, + ) + ) + ) + data.attrs["temperature"] = 310 + data.attrs["energy_unit"] = "kcal/mol" + value = A_c( + [ + data, + ] + * 2, + tol=2.0, + ) np.testing.assert_allclose(value, 0.65) diff --git a/src/alchemlyb/tests/test_fep_estimators.py b/src/alchemlyb/tests/test_fep_estimators.py index f61a1a43..d399c3d6 100644 --- a/src/alchemlyb/tests/test_fep_estimators.py +++ b/src/alchemlyb/tests/test_fep_estimators.py @@ -1,105 +1,136 @@ """Tests for all FEP-based estimators in ``alchemlyb``. """ -import pytest - -import numpy as np -import pandas as pd - -import alchemlyb -from alchemlyb.parsing import gmx, amber, namd, gomc -from alchemlyb.estimators import MBAR, BAR, AutoMBAR -import alchemtest.gmx import alchemtest.amber +import alchemtest.gmx import alchemtest.gomc import alchemtest.namd +import numpy as np +import pytest +from alchemtest.generic import load_MBAR_BGFS from alchemtest.gmx import load_benzene, load_ABFE + +import alchemlyb +from alchemlyb.estimators import MBAR, BAR, AutoMBAR +from alchemlyb.parsing import gmx, amber, namd, gomc from alchemlyb.parsing.gmx import extract_u_nk -from alchemtest.generic import load_MBAR_BGFS + def gmx_benzene_coul_u_nk(): dataset = alchemtest.gmx.load_benzene() - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300) - for filename in dataset['data']['Coulomb']]) + u_nk = alchemlyb.concat( + [gmx.extract_u_nk(filename, T=300) for filename in dataset["data"]["Coulomb"]] + ) return u_nk + def gmx_benzene_vdw_u_nk(): dataset = alchemtest.gmx.load_benzene() - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300) - for filename in dataset['data']['VDW']]) + u_nk = alchemlyb.concat( + [gmx.extract_u_nk(filename, T=300) for filename in dataset["data"]["VDW"]] + ) return u_nk + def gmx_expanded_ensemble_case_1(): dataset = alchemtest.gmx.load_expanded_ensemble_case_1() - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300, filter=False) - for filename in dataset['data']['AllStates']]) + u_nk = alchemlyb.concat( + [ + gmx.extract_u_nk(filename, T=300, filter=False) + for filename in dataset["data"]["AllStates"] + ] + ) return u_nk + def gmx_expanded_ensemble_case_2(): dataset = alchemtest.gmx.load_expanded_ensemble_case_2() - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300, filter=False) - for filename in dataset['data']['AllStates']]) + u_nk = alchemlyb.concat( + [ + gmx.extract_u_nk(filename, T=300, filter=False) + for filename in dataset["data"]["AllStates"] + ] + ) return u_nk + def gmx_expanded_ensemble_case_3(): dataset = alchemtest.gmx.load_expanded_ensemble_case_3() - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300, filter=False) - for filename in dataset['data']['AllStates']]) + u_nk = alchemlyb.concat( + [ + gmx.extract_u_nk(filename, T=300, filter=False) + for filename in dataset["data"]["AllStates"] + ] + ) return u_nk + def gmx_water_particle_with_total_energy(): dataset = alchemtest.gmx.load_water_particle_with_total_energy() - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300) - for filename in dataset['data']['AllStates']]) + u_nk = alchemlyb.concat( + [gmx.extract_u_nk(filename, T=300) for filename in dataset["data"]["AllStates"]] + ) return u_nk + def gmx_water_particle_with_potential_energy(): dataset = alchemtest.gmx.load_water_particle_with_potential_energy() - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300) - for filename in dataset['data']['AllStates']]) + u_nk = alchemlyb.concat( + [gmx.extract_u_nk(filename, T=300) for filename in dataset["data"]["AllStates"]] + ) return u_nk + def gmx_water_particle_without_energy(): dataset = alchemtest.gmx.load_water_particle_without_energy() - u_nk = alchemlyb.concat([gmx.extract_u_nk(filename, T=300) - for filename in dataset['data']['AllStates']]) + u_nk = alchemlyb.concat( + [gmx.extract_u_nk(filename, T=300) for filename in dataset["data"]["AllStates"]] + ) return u_nk + def amber_bace_example_complex_vdw(): dataset = alchemtest.amber.load_bace_example() - u_nk = alchemlyb.concat([amber.extract_u_nk(filename, T=298.0) - for filename in dataset['data']['complex']['vdw']]) + u_nk = alchemlyb.concat( + [ + amber.extract_u_nk(filename, T=298.0) + for filename in dataset["data"]["complex"]["vdw"] + ] + ) return u_nk + def gomc_benzene_u_nk(): dataset = alchemtest.gomc.load_benzene() - u_nk = alchemlyb.concat([gomc.extract_u_nk(filename, T=298) - for filename in dataset['data']]) + u_nk = alchemlyb.concat( + [gomc.extract_u_nk(filename, T=298) for filename in dataset["data"]] + ) return u_nk + def namd_tyr2ala(): dataset = alchemtest.namd.load_tyr2ala() - u_nk1 = namd.extract_u_nk(dataset['data']['forward'][0], T=300) - u_nk2 = namd.extract_u_nk(dataset['data']['backward'][0], T=300) + u_nk1 = namd.extract_u_nk(dataset["data"]["forward"][0], T=300) + u_nk2 = namd.extract_u_nk(dataset["data"]["backward"][0], T=300) # combine dataframes of fwd and rev directions u_nk1[u_nk1.isna()] = u_nk2 @@ -107,29 +138,30 @@ def namd_tyr2ala(): return u_nk + def namd_idws(): dataset = alchemtest.namd.load_idws() - u_nk = namd.extract_u_nk(dataset['data']['forward'], T=300) + u_nk = namd.extract_u_nk(dataset["data"]["forward"], T=300) return u_nk + def namd_idws_restarted(): dataset = alchemtest.namd.load_restarted() - u_nk = namd.extract_u_nk(dataset['data']['both'], T=300) + u_nk = namd.extract_u_nk(dataset["data"]["both"], T=300) return u_nk + def namd_idws_restarted_reversed(): dataset = alchemtest.namd.load_restarted_reversed() - u_nk = namd.extract_u_nk(dataset['data']['both'], T=300) + u_nk = namd.extract_u_nk(dataset["data"]["both"], T=300) return u_nk class FEPestimatorMixin: - """Mixin for all FEP Estimator test classes. - - """ + """Mixin for all FEP Estimator test classes.""" def compare_delta_f(self, X_delta_f): est = self.cls().fit(X_delta_f[0]) @@ -145,23 +177,25 @@ def get_delta_f(self, est): class TestMBAR(FEPestimatorMixin): - """Tests for MBAR. + """Tests for MBAR.""" - """ cls = MBAR - @pytest.fixture(scope="class", - params=[(gmx_benzene_coul_u_nk, 3.041, 0.02088), - (gmx_benzene_vdw_u_nk, -3.007, 0.04519), - (gmx_expanded_ensemble_case_1, 75.923, 0.14124), - (gmx_expanded_ensemble_case_2, 75.915, 0.14372), - (gmx_expanded_ensemble_case_3, 76.173, 0.11345), - (gmx_water_particle_with_total_energy, -11.680, 0.083655), - (gmx_water_particle_with_potential_energy, -11.675, 0.083589), - (gmx_water_particle_without_energy, -11.654, 0.083415), - (amber_bace_example_complex_vdw, 2.41149, 0.0620658), - (gomc_benzene_u_nk, -0.79994, 0.091579), - ]) + @pytest.fixture( + scope="class", + params=[ + (gmx_benzene_coul_u_nk, 3.041, 0.02088), + (gmx_benzene_vdw_u_nk, -3.007, 0.04519), + (gmx_expanded_ensemble_case_1, 75.923, 0.14124), + (gmx_expanded_ensemble_case_2, 75.915, 0.14372), + (gmx_expanded_ensemble_case_3, 76.173, 0.11345), + (gmx_water_particle_with_total_energy, -11.680, 0.083655), + (gmx_water_particle_with_potential_energy, -11.675, 0.083589), + (gmx_water_particle_without_energy, -11.654, 0.083415), + (amber_bace_example_complex_vdw, 2.41149, 0.0620658), + (gomc_benzene_u_nk, -0.79994, 0.091579), + ], + ) def X_delta_f(self, request): get_unk, E, dE = request.param return get_unk(), E, dE @@ -169,54 +203,62 @@ def X_delta_f(self, request): def test_mbar(self, X_delta_f): self.compare_delta_f(X_delta_f) + class TestAutoMBAR(TestMBAR): cls = AutoMBAR -class TestMBAR_fail(): + +class TestMBAR_fail: @pytest.fixture(scope="class") def n_uk_list(self): - n_uk_list = [gmx.extract_u_nk(dhdl, T=300) for dhdl in - load_ABFE()['data']['complex']] + n_uk_list = [ + gmx.extract_u_nk(dhdl, T=300) for dhdl in load_ABFE()["data"]["complex"] + ] return n_uk_list def test_failback_adaptive(self, n_uk_list): # The hybr will fail on this while adaptive will work - mbar = AutoMBAR().fit(alchemlyb.concat([n_uk[:2] for n_uk in - n_uk_list])) - assert np.isclose(mbar.d_delta_f_[(0.0, 0.0, 0.0)][(1.0, 1.0, 1.0)], 1.76832, 0.1) + mbar = AutoMBAR().fit(alchemlyb.concat([n_uk[:2] for n_uk in n_uk_list])) + assert np.isclose( + mbar.d_delta_f_[(0.0, 0.0, 0.0)][(1.0, 1.0, 1.0)], 1.76832, 0.1 + ) + def test_AutoMBAR_BGFS(): # A case where only BFGS would work mbar = AutoMBAR() - u_nk = np.load(load_MBAR_BGFS()['data']['u_nk']) - N_k = np.load(load_MBAR_BGFS()['data']['N_k']) - solver_options = {"maximum_iterations": 10000,"verbose": False} + u_nk = np.load(load_MBAR_BGFS()["data"]["u_nk"]) + N_k = np.load(load_MBAR_BGFS()["data"]["N_k"]) + solver_options = {"maximum_iterations": 10000, "verbose": False} solver_protocol = {"method": None, "options": solver_options} mbar, out = mbar._do_MBAR(u_nk.T, N_k, solver_protocol) assert np.isclose(out[0][1][0], 12.552409, 0.1) + class TestBAR(FEPestimatorMixin): - """Tests for BAR. + """Tests for BAR.""" - """ cls = BAR - @pytest.fixture(scope="class", - params = [(gmx_benzene_coul_u_nk, 3.044, 0.01640), - (gmx_benzene_vdw_u_nk, -3.033, 0.03438), - (gmx_expanded_ensemble_case_1, 75.993, 0.11056), - (gmx_expanded_ensemble_case_2, 76.009, 0.11220), - (gmx_expanded_ensemble_case_3, 76.219, 0.08886), - (gmx_water_particle_with_total_energy, -11.675, 0.065055), - (gmx_water_particle_with_potential_energy, -11.724, 0.064964), - (gmx_water_particle_without_energy, -11.660, 0.064914), - (amber_bace_example_complex_vdw, 2.39294, 0.051192), - (namd_tyr2ala, 11.0044, 0.10235), - (namd_idws, 0.221147, 0.041003), - (namd_idws_restarted, 7.081127, 0.0344211), - (namd_idws_restarted_reversed, -4.18405, 0.03457), - (gomc_benzene_u_nk, -0.87095, 0.071263), - ]) + @pytest.fixture( + scope="class", + params=[ + (gmx_benzene_coul_u_nk, 3.044, 0.01640), + (gmx_benzene_vdw_u_nk, -3.033, 0.03438), + (gmx_expanded_ensemble_case_1, 75.993, 0.11056), + (gmx_expanded_ensemble_case_2, 76.009, 0.11220), + (gmx_expanded_ensemble_case_3, 76.219, 0.08886), + (gmx_water_particle_with_total_energy, -11.675, 0.065055), + (gmx_water_particle_with_potential_energy, -11.724, 0.064964), + (gmx_water_particle_without_energy, -11.660, 0.064914), + (amber_bace_example_complex_vdw, 2.39294, 0.051192), + (namd_tyr2ala, 11.0044, 0.10235), + (namd_idws, 0.221147, 0.041003), + (namd_idws_restarted, 7.081127, 0.0344211), + (namd_idws_restarted_reversed, -4.18405, 0.03457), + (gomc_benzene_u_nk, -0.87095, 0.071263), + ], + ) def X_delta_f(self, request): get_unk, E, dE = request.param return get_unk(), E, dE @@ -228,40 +270,44 @@ def get_delta_f(self, est): ee = 0.0 for i in range(len(est.d_delta_f_) - 1): - ee += est.d_delta_f_.values[i][i+1]**2 + ee += est.d_delta_f_.values[i][i + 1] ** 2 # Use .iloc[0, -1] as we want to cater for both # delta_f_.loc[0.0, 1.0] and delta_f_.loc[(0.0, 0.0), (0.0, 1.0)] return est.delta_f_.iloc[0, -1], ee**0.5 -class Test_Units(): - '''Test the units.''' + +class Test_Units: + """Test the units.""" @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def u_nk(): bz = load_benzene().data u_nk_coul = alchemlyb.concat( - [extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']]) - u_nk_coul.attrs = extract_u_nk(load_benzene().data['Coulomb'][0], T=300).attrs + [extract_u_nk(xvg, T=300) for xvg in bz["Coulomb"]] + ) + u_nk_coul.attrs = extract_u_nk(load_benzene().data["Coulomb"][0], T=300).attrs return u_nk_coul def test_bar(self, u_nk): bar = BAR().fit(u_nk) - assert bar.delta_f_.attrs['temperature'] == 300 - assert bar.delta_f_.attrs['energy_unit'] == 'kT' - assert bar.d_delta_f_.attrs['temperature'] == 300 - assert bar.d_delta_f_.attrs['energy_unit'] == 'kT' + assert bar.delta_f_.attrs["temperature"] == 300 + assert bar.delta_f_.attrs["energy_unit"] == "kT" + assert bar.d_delta_f_.attrs["temperature"] == 300 + assert bar.d_delta_f_.attrs["energy_unit"] == "kT" def test_mbar(self, u_nk): mbar = MBAR().fit(u_nk) - assert mbar.delta_f_.attrs['temperature'] == 300 - assert mbar.delta_f_.attrs['energy_unit'] == 'kT' - assert mbar.d_delta_f_.attrs['temperature'] == 300 - assert mbar.d_delta_f_.attrs['energy_unit'] == 'kT' - -class TestEstimatorMixOut(): - '''Ensure that the attribute d_delta_f_, delta_f_, states_ cannot be - modified. ''' + assert mbar.delta_f_.attrs["temperature"] == 300 + assert mbar.delta_f_.attrs["energy_unit"] == "kT" + assert mbar.d_delta_f_.attrs["temperature"] == 300 + assert mbar.d_delta_f_.attrs["energy_unit"] == "kT" + + +class TestEstimatorMixOut: + """Ensure that the attribute d_delta_f_, delta_f_, states_ cannot be + modified.""" + @pytest.mark.parametrize("estimator", [MBAR, BAR]) def test_d_delta_f_(self, estimator): _estimator = estimator() diff --git a/src/alchemlyb/tests/test_import.py b/src/alchemlyb/tests/test_import.py index 50a5d933..ef467ec8 100644 --- a/src/alchemlyb/tests/test_import.py +++ b/src/alchemlyb/tests/test_import.py @@ -1,4 +1,5 @@ import alchemlyb + def test_name(): - assert alchemlyb.__name__ == 'alchemlyb' + assert alchemlyb.__name__ == "alchemlyb" diff --git a/src/alchemlyb/tests/test_preprocessing.py b/src/alchemlyb/tests/test_preprocessing.py index 22e54fdc..0d204cdb 100644 --- a/src/alchemlyb/tests/test_preprocessing.py +++ b/src/alchemlyb/tests/test_preprocessing.py @@ -11,15 +11,21 @@ import alchemlyb from alchemlyb.parsing import gmx, namd from alchemlyb.parsing.gmx import extract_u_nk, extract_dHdl -from alchemlyb.preprocessing import (slicing, statistical_inefficiency, - equilibrium_detection, - decorrelate_u_nk, decorrelate_dhdl, - u_nk2series, dhdl2series) +from alchemlyb.preprocessing import ( + slicing, + statistical_inefficiency, + equilibrium_detection, + decorrelate_u_nk, + decorrelate_dhdl, + u_nk2series, + dhdl2series, +) def gmx_benzene_dHdl(): dataset = alchemtest.gmx.load_benzene() - return gmx.extract_dHdl(dataset['data']['Coulomb'][0], T=300) + return gmx.extract_dHdl(dataset["data"]["Coulomb"][0], T=300) + # When issue #206 is addressed make the gmx_benzene_dHdl() function the # fixture, remove the wrapper below, and replace @@ -28,39 +34,48 @@ def gmx_benzene_dHdl(): def gmx_benzene_dHdl_fixture(): return gmx_benzene_dHdl() + @pytest.fixture() def gmx_ABFE(): dataset = alchemtest.gmx.load_ABFE() - return gmx.extract_u_nk(dataset['data']['complex'][0], T=300) + return gmx.extract_u_nk(dataset["data"]["complex"][0], T=300) + @pytest.fixture() def gmx_ABFE_dhdl(): dataset = alchemtest.gmx.load_ABFE() - return gmx.extract_dHdl(dataset['data']['complex'][0], T=300) + return gmx.extract_dHdl(dataset["data"]["complex"][0], T=300) + @pytest.fixture() def gmx_ABFE_u_nk(): dataset = alchemtest.gmx.load_ABFE() - return gmx.extract_u_nk(dataset['data']['complex'][-1], T=300) + return gmx.extract_u_nk(dataset["data"]["complex"][-1], T=300) + @pytest.fixture() def gmx_benzene_u_nk_fixture(): dataset = alchemtest.gmx.load_benzene() - return gmx.extract_u_nk(dataset['data']['Coulomb'][0], T=300) + return gmx.extract_u_nk(dataset["data"]["Coulomb"][0], T=300) + def gmx_benzene_u_nk(): dataset = alchemtest.gmx.load_benzene() - return gmx.extract_u_nk(dataset['data']['Coulomb'][0], T=300) + return gmx.extract_u_nk(dataset["data"]["Coulomb"][0], T=300) def gmx_benzene_dHdl_full(): dataset = alchemtest.gmx.load_benzene() - return alchemlyb.concat([gmx.extract_dHdl(i, T=300) for i in dataset['data']['Coulomb']]) + return alchemlyb.concat( + [gmx.extract_dHdl(i, T=300) for i in dataset["data"]["Coulomb"]] + ) def gmx_benzene_u_nk_full(): dataset = alchemtest.gmx.load_benzene() - return alchemlyb.concat([gmx.extract_u_nk(i, T=300) for i in dataset['data']['Coulomb']]) + return alchemlyb.concat( + [gmx.extract_u_nk(i, T=300) for i in dataset["data"]["Coulomb"]] + ) def _check_data_is_outside_bounds(data, lower, upper): @@ -70,40 +85,42 @@ def _check_data_is_outside_bounds(data, lower, upper): This is used by slicing tests to make sure that the data provided is appropriate for the tests. """ - assert any(data.reset_index()['time'] < lower) - assert any(data.reset_index()['time'] > upper) + assert any(data.reset_index()["time"] < lower) + assert any(data.reset_index()["time"] > upper) class TestSlicing: - """Test slicing functionality. + """Test slicing functionality.""" - """ def slicer(self, *args, **kwargs): return slicing(*args, **kwargs) - @pytest.mark.parametrize(('data', 'size'), [(gmx_benzene_dHdl(), 661), - (gmx_benzene_u_nk(), 661)]) + @pytest.mark.parametrize( + ("data", "size"), [(gmx_benzene_dHdl(), 661), (gmx_benzene_u_nk(), 661)] + ) def test_basic_slicing(self, data, size): assert len(self.slicer(data, lower=1000, upper=34000, step=5)) == size def test_unchanged(self): # NAMD energy files only have dE for adjacent lambdas, this ensures # that the slicer will not drop these rows as they have NaN values. - file = load_idws().data['forward'][0] + file = load_idws().data["forward"][0] u_nk = namd.extract_u_nk(file, 298) # Do the pre-processing as the u_nk are from all lambdas - groups = u_nk.groupby('fep-lambda') + groups = u_nk.groupby("fep-lambda") for key, group in groups: - group = group[~group.index.duplicated(keep='first')] + group = group[~group.index.duplicated(keep="first")] df = self.slicer(group, None, None, None) assert len(df) == len(group) - @pytest.mark.parametrize(('dataloader', 'lower', 'upper'), - [ - ('gmx_benzene_dHdl_fixture', 1000, 34000), - ('gmx_benzene_u_nk_fixture', 1000, 34000), - ]) + @pytest.mark.parametrize( + ("dataloader", "lower", "upper"), + [ + ("gmx_benzene_dHdl_fixture", 1000, 34000), + ("gmx_benzene_u_nk_fixture", 1000, 34000), + ], + ) def test_data_is_unchanged(self, dataloader, lower, upper, request): """ Test that slicing does not change the underlying data @@ -115,17 +132,16 @@ def test_data_is_unchanged(self, dataloader, lower, upper, request): # Slice data, and test that we didn't change the input data original_length = len(data) - sliced = self.slicer(data, - lower=lower, - upper=upper, - step=5) + sliced = self.slicer(data, lower=lower, upper=upper, step=5) assert len(data) == original_length - @pytest.mark.parametrize(('dataloader', 'lower', 'upper'), - [ - ('gmx_benzene_dHdl_fixture', 1000, 34000), - ('gmx_benzene_u_nk_fixture', 1000, 34000), - ]) + @pytest.mark.parametrize( + ("dataloader", "lower", "upper"), + [ + ("gmx_benzene_dHdl_fixture", 1000, 34000), + ("gmx_benzene_u_nk_fixture", 1000, 34000), + ], + ) def test_lower_and_upper_bound(self, dataloader, lower, upper, request): """ Test that the lower and upper time is respected @@ -137,19 +153,13 @@ def test_lower_and_upper_bound(self, dataloader, lower, upper, request): # Slice data, and test that we don't observe times outside # the prescribed range - sliced = self.slicer(data, - lower=lower, - upper=upper, - step=5) - assert all(sliced.reset_index()['time'] >= lower) - assert all(sliced.reset_index()['time'] <= upper) - - @pytest.mark.parametrize('data', [gmx_benzene_dHdl(), - gmx_benzene_u_nk()]) - def test_disordered_exception(self, data): - """Test that a shuffled DataFrame yields a KeyError. + sliced = self.slicer(data, lower=lower, upper=upper, step=5) + assert all(sliced.reset_index()["time"] >= lower) + assert all(sliced.reset_index()["time"] <= upper) - """ + @pytest.mark.parametrize("data", [gmx_benzene_dHdl(), gmx_benzene_u_nk()]) + def test_disordered_exception(self, data): + """Test that a shuffled DataFrame yields a KeyError.""" indices = data.index.values np.random.shuffle(indices) @@ -158,102 +168,85 @@ def test_disordered_exception(self, data): with pytest.raises(KeyError): self.slicer(df, lower=200) - @pytest.mark.parametrize('data', [gmx_benzene_dHdl_full(), - gmx_benzene_u_nk_full()]) + @pytest.mark.parametrize("data", [gmx_benzene_dHdl_full(), gmx_benzene_u_nk_full()]) def test_duplicated_exception(self, data): - """Test that a DataFrame with duplicate times yields a KeyError. - - """ + """Test that a DataFrame with duplicate times yields a KeyError.""" with pytest.raises(KeyError): self.slicer(data.sort_index(axis=0), lower=200) def test_subsample_bounds_and_step(self, gmx_ABFE): - """Make sure that slicing the series also works - """ - subsample = statistical_inefficiency(gmx_ABFE, - gmx_ABFE.sum(axis=1), - lower=100, - upper=400, - step=2) + """Make sure that slicing the series also works""" + subsample = statistical_inefficiency( + gmx_ABFE, gmx_ABFE.sum(axis=1), lower=100, upper=400, step=2 + ) assert len(subsample) == 76 def test_multiindex_duplicated(self, gmx_ABFE): - subsample = statistical_inefficiency(gmx_ABFE, - gmx_ABFE.sum(axis=1)) + subsample = statistical_inefficiency(gmx_ABFE, gmx_ABFE.sum(axis=1)) assert len(subsample) == 501 def test_sort_off(self, gmx_ABFE): unsorted = alchemlyb.concat([gmx_ABFE[-500:], gmx_ABFE[:500]]) with pytest.raises(KeyError): - statistical_inefficiency(unsorted, - unsorted.sum(axis=1), - sort=False) + statistical_inefficiency(unsorted, unsorted.sum(axis=1), sort=False) def test_sort_on(self, gmx_ABFE): unsorted = alchemlyb.concat([gmx_ABFE[-500:], gmx_ABFE[:500]]) - subsample = statistical_inefficiency(unsorted, - unsorted.sum(axis=1), - sort=True) - assert subsample.reset_index(0)['time'].is_monotonic_increasing + subsample = statistical_inefficiency(unsorted, unsorted.sum(axis=1), sort=True) + assert subsample.reset_index(0)["time"].is_monotonic_increasing def test_sort_on_noseries(self, gmx_ABFE): unsorted = alchemlyb.concat([gmx_ABFE[-500:], gmx_ABFE[:500]]) - subsample = statistical_inefficiency(unsorted, - None, - sort=True) - assert subsample.reset_index(0)['time'].is_monotonic_increasing + subsample = statistical_inefficiency(unsorted, None, sort=True) + assert subsample.reset_index(0)["time"].is_monotonic_increasing def test_duplication_off(self, gmx_ABFE): duplicated = alchemlyb.concat([gmx_ABFE, gmx_ABFE]) with pytest.raises(KeyError): - statistical_inefficiency(duplicated, - duplicated.sum(axis=1), - drop_duplicates=False) + statistical_inefficiency( + duplicated, duplicated.sum(axis=1), drop_duplicates=False + ) def test_duplication_on_dataframe(self, gmx_ABFE): duplicated = alchemlyb.concat([gmx_ABFE, gmx_ABFE]) - subsample = statistical_inefficiency(duplicated, - duplicated.sum(axis=1), - drop_duplicates=True) + subsample = statistical_inefficiency( + duplicated, duplicated.sum(axis=1), drop_duplicates=True + ) assert len(subsample) < 1000 def test_duplication_on_dataframe_noseries(self, gmx_ABFE): duplicated = alchemlyb.concat([gmx_ABFE, gmx_ABFE]) - subsample = statistical_inefficiency(duplicated, - None, - drop_duplicates=True) + subsample = statistical_inefficiency(duplicated, None, drop_duplicates=True) assert len(subsample) == 1001 def test_duplication_on_series(self, gmx_ABFE): duplicated = alchemlyb.concat([gmx_ABFE, gmx_ABFE]) - subsample = statistical_inefficiency(duplicated.sum(axis=1), - duplicated.sum(axis=1), - drop_duplicates=True) + subsample = statistical_inefficiency( + duplicated.sum(axis=1), duplicated.sum(axis=1), drop_duplicates=True + ) assert len(subsample) < 1000 def test_duplication_on_series_noseries(self, gmx_ABFE): duplicated = alchemlyb.concat([gmx_ABFE, gmx_ABFE]) - subsample = statistical_inefficiency(duplicated.sum(axis=1), - None, - drop_duplicates=True) + subsample = statistical_inefficiency( + duplicated.sum(axis=1), None, drop_duplicates=True + ) assert len(subsample) == 1001 -class CorrelatedPreprocessors: - @pytest.mark.parametrize(('data', 'size'), [(gmx_benzene_dHdl(), 4001), - (gmx_benzene_u_nk(), 4001)]) +class CorrelatedPreprocessors: + @pytest.mark.parametrize( + ("data", "size"), [(gmx_benzene_dHdl(), 4001), (gmx_benzene_u_nk(), 4001)] + ) def test_subsampling(self, data, size): """Basic test for execution; resulting size of dataset sensitive to machine and depends on algorithm. """ assert len(self.slicer(data, series=data.loc[:, data.columns[0]])) <= size - @pytest.mark.parametrize('data', [gmx_benzene_dHdl(), - gmx_benzene_u_nk()]) + @pytest.mark.parametrize("data", [gmx_benzene_dHdl(), gmx_benzene_u_nk()]) def test_no_series(self, data): - """Check that we get the same result as simple slicing with no Series. - - """ + """Check that we get the same result as simple slicing with no Series.""" df_sub = self.slicer(data, lower=200, upper=5000, step=2) df_sliced = slicing(data, lower=200, upper=5000, step=2) @@ -261,43 +254,53 @@ def test_no_series(self, data): class TestStatisticalInefficiency(TestSlicing, CorrelatedPreprocessors): - def slicer(self, *args, **kwargs): return statistical_inefficiency(*args, **kwargs) - @pytest.mark.parametrize(('conservative', 'data', 'size'), - [ - (True, gmx_benzene_dHdl(), 2001), # 0.00: g = 1.0559445620585415 - (True, gmx_benzene_u_nk(), 2001), # 'fep': g = 1.0560203916559594 - (False, gmx_benzene_dHdl(), 3789), - (False, gmx_benzene_u_nk(), 3571), - ]) + @pytest.mark.parametrize( + ("conservative", "data", "size"), + [ + (True, gmx_benzene_dHdl(), 2001), + # 0.00: g = 1.0559445620585415 + (True, gmx_benzene_u_nk(), 2001), + # 'fep': g = 1.0560203916559594 + (False, gmx_benzene_dHdl(), 3789), + (False, gmx_benzene_u_nk(), 3571), + ], + ) def test_conservative(self, data, size, conservative): - sliced = self.slicer(data, series=data.loc[:, data.columns[0]], conservative=conservative) + sliced = self.slicer( + data, series=data.loc[:, data.columns[0]], conservative=conservative + ) # results can vary slightly with different machines # so possibly do # delta = 10 # assert size - delta < len(sliced) < size + delta assert len(sliced) == size - @pytest.mark.parametrize('series', [ - gmx_benzene_dHdl()['fep'][:20], # wrong length - gmx_benzene_dHdl()['fep'][::-1], # wrong time stamps (reversed) - ]) + @pytest.mark.parametrize( + "series", + [ + gmx_benzene_dHdl()["fep"][:20], # wrong length + gmx_benzene_dHdl()["fep"][::-1], # wrong time stamps (reversed) + ], + ) def test_raise_ValueError_for_mismatched_data(self, series): data = gmx_benzene_dHdl() with pytest.raises(ValueError): self.slicer(data, series=series) - @pytest.mark.parametrize(('dataloader', 'lower', 'upper'), - [ - ('gmx_benzene_dHdl_fixture', 1000, 34000), - ('gmx_benzene_u_nk_fixture', 1000, 34000), - ]) - @pytest.mark.parametrize('use_series', [True, False]) - @pytest.mark.parametrize('conservative', [True, False]) + @pytest.mark.parametrize( + ("dataloader", "lower", "upper"), + [ + ("gmx_benzene_dHdl_fixture", 1000, 34000), + ("gmx_benzene_u_nk_fixture", 1000, 34000), + ], + ) + @pytest.mark.parametrize("use_series", [True, False]) + @pytest.mark.parametrize("conservative", [True, False]) def test_data_is_unchanged( - self, dataloader, use_series, lower, upper, conservative, request + self, dataloader, use_series, lower, upper, conservative, request ): """ Test that using statistical_inefficiency does not change the underlying data @@ -317,23 +320,27 @@ def test_data_is_unchanged( # Slice data, and test that we didn't change the input data original_length = len(data) - self.slicer(data, - series=series, - lower=lower, - upper=upper, - step=5, - conservative=conservative) + self.slicer( + data, + series=series, + lower=lower, + upper=upper, + step=5, + conservative=conservative, + ) assert len(data) == original_length - @pytest.mark.parametrize(('dataloader', 'lower', 'upper'), - [ - ('gmx_benzene_dHdl_fixture', 1000, 34000), - ('gmx_benzene_u_nk_fixture', 1000, 34000), - ]) - @pytest.mark.parametrize('use_series', [True, False]) - @pytest.mark.parametrize('conservative', [True, False]) + @pytest.mark.parametrize( + ("dataloader", "lower", "upper"), + [ + ("gmx_benzene_dHdl_fixture", 1000, 34000), + ("gmx_benzene_u_nk_fixture", 1000, 34000), + ], + ) + @pytest.mark.parametrize("use_series", [True, False]) + @pytest.mark.parametrize("conservative", [True, False]) def test_lower_and_upper_bound_slicer( - self, dataloader, use_series, lower, upper, conservative, request + self, dataloader, use_series, lower, upper, conservative, request ): """ Test that the lower and upper time is respected when using statistical_inefficiency @@ -353,23 +360,27 @@ def test_lower_and_upper_bound_slicer( # Slice data, and test that we don't observe times outside # the prescribed range - sliced = self.slicer(data, - series=series, - lower=lower, - upper=upper, - step=5, - conservative=conservative) - assert all(sliced.reset_index()['time'] >= lower) - assert all(sliced.reset_index()['time'] <= upper) - - @pytest.mark.parametrize(('dataloader', 'lower', 'upper'), - [ - ('gmx_benzene_dHdl_fixture', 1000, 34000), - ('gmx_benzene_u_nk_fixture', 1000, 34000), - ]) - @pytest.mark.parametrize('conservative', [True, False]) + sliced = self.slicer( + data, + series=series, + lower=lower, + upper=upper, + step=5, + conservative=conservative, + ) + assert all(sliced.reset_index()["time"] >= lower) + assert all(sliced.reset_index()["time"] <= upper) + + @pytest.mark.parametrize( + ("dataloader", "lower", "upper"), + [ + ("gmx_benzene_dHdl_fixture", 1000, 34000), + ("gmx_benzene_u_nk_fixture", 1000, 34000), + ], + ) + @pytest.mark.parametrize("conservative", [True, False]) def test_slicing_inefficiency_equivalence( - self, dataloader, lower, upper, conservative, request + self, dataloader, lower, upper, conservative, request ): """ Test that first slicing the data frame, then subsampling is equivalent to @@ -382,143 +393,201 @@ def test_slicing_inefficiency_equivalence( # Slice dataframe, then subsample it based on the sum of its components sliced_data = slicing(data, lower=lower, upper=upper) - subsampled_sliced_data = self.slicer(sliced_data, - series=sliced_data.sum(axis=1), - conservative=conservative) + subsampled_sliced_data = self.slicer( + sliced_data, series=sliced_data.sum(axis=1), conservative=conservative + ) # Subsample the dataframe based on the sum of its components while # also specifying the slicing range - subsampled_data = self.slicer(data, - series=data.sum(axis=1), - lower=lower, - upper=upper, - conservative=conservative) + subsampled_data = self.slicer( + data, + series=data.sum(axis=1), + lower=lower, + upper=upper, + conservative=conservative, + ) assert (subsampled_sliced_data == subsampled_data).all(axis=None) class TestEquilibriumDetection(TestSlicing, CorrelatedPreprocessors): - def slicer(self, *args, **kwargs): return equilibrium_detection(*args, **kwargs) -class Test_Units(): - '''Test the preprocessing module.''' + +class Test_Units: + """Test the preprocessing module.""" + @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def dhdl(): dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) return dhdl def test_slicing(self, dhdl): - '''Test if extract_u_nk assign the attr correctly''' + """Test if extract_u_nk assign the attr correctly""" dataset = load_benzene() - u_nk = extract_u_nk(dataset['data']['Coulomb'][0], 310) + u_nk = extract_u_nk(dataset["data"]["Coulomb"][0], 310) new_u_nk = slicing(u_nk) - assert new_u_nk.attrs['temperature'] == 310 - assert new_u_nk.attrs['energy_unit'] == 'kT' + assert new_u_nk.attrs["temperature"] == 310 + assert new_u_nk.attrs["energy_unit"] == "kT" def test_statistical_inefficiency(self, dhdl): - '''Test if extract_u_nk assign the attr correctly''' + """Test if extract_u_nk assign the attr correctly""" dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) new_dhdl = statistical_inefficiency(dhdl) - assert new_dhdl.attrs['temperature'] == 310 - assert new_dhdl.attrs['energy_unit'] == 'kT' + assert new_dhdl.attrs["temperature"] == 310 + assert new_dhdl.attrs["energy_unit"] == "kT" def test_equilibrium_detection(self, dhdl): - '''Test if extract_u_nk assign the attr correctly''' + """Test if extract_u_nk assign the attr correctly""" dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) new_dhdl = equilibrium_detection(dhdl) - assert new_dhdl.attrs['temperature'] == 310 - assert new_dhdl.attrs['energy_unit'] == 'kT' + assert new_dhdl.attrs["temperature"] == 310 + assert new_dhdl.attrs["energy_unit"] == "kT" + -@pytest.mark.parametrize(('method', 'size'), [('all', 2001), - ('dE', 2001)]) +@pytest.mark.parametrize(("method", "size"), [("all", 2001), ("dE", 2001)]) def test_decorrelate_u_nk_single_l(gmx_benzene_u_nk_fixture, method, size): - assert len(decorrelate_u_nk(gmx_benzene_u_nk_fixture, method=method, - drop_duplicates=True, - sort=True)) == size + assert ( + len( + decorrelate_u_nk( + gmx_benzene_u_nk_fixture, method=method, drop_duplicates=True, sort=True + ) + ) + == size + ) + def test_decorrelate_u_nk_burnin(gmx_benzene_u_nk_fixture): - assert len(decorrelate_u_nk(gmx_benzene_u_nk_fixture, method='dE', - drop_duplicates=True, - sort=True, remove_burnin=True)) == 2849 + assert ( + len( + decorrelate_u_nk( + gmx_benzene_u_nk_fixture, + method="dE", + drop_duplicates=True, + sort=True, + remove_burnin=True, + ) + ) + == 2849 + ) -def test_decorrelate_dhdl_burnin(gmx_benzene_dHdl_fixture): - assert len(decorrelate_dhdl(gmx_benzene_dHdl_fixture, - drop_duplicates=True, - sort=True, remove_burnin=True)) == 2848 -@pytest.mark.parametrize(('method', 'size'), [('all', 1001), - ('dE', 334)]) +def test_decorrelate_dhdl_burnin(gmx_benzene_dHdl_fixture): + assert ( + len( + decorrelate_dhdl( + gmx_benzene_dHdl_fixture, + drop_duplicates=True, + sort=True, + remove_burnin=True, + ) + ) + == 2848 + ) + + +@pytest.mark.parametrize(("method", "size"), [("all", 1001), ("dE", 334)]) def test_decorrelate_u_nk_multiple_l(gmx_ABFE_u_nk, method, size): - assert len(decorrelate_u_nk(gmx_ABFE_u_nk, method=method,)) == size + assert ( + len( + decorrelate_u_nk( + gmx_ABFE_u_nk, + method=method, + ) + ) + == size + ) + def test_decorrelate_dhdl_single_l(gmx_benzene_u_nk_fixture): - assert len(decorrelate_dhdl(gmx_benzene_u_nk_fixture, drop_duplicates=True, - sort=True)) == 2001 + assert ( + len(decorrelate_dhdl(gmx_benzene_u_nk_fixture, drop_duplicates=True, sort=True)) + == 2001 + ) + def test_decorrelate_dhdl_multiple_l(gmx_ABFE_dhdl): - assert len(decorrelate_dhdl(gmx_ABFE_dhdl,)) == 501 + assert ( + len( + decorrelate_dhdl( + gmx_ABFE_dhdl, + ) + ) + == 501 + ) + def test_raise_non_uk(gmx_ABFE_dhdl): with pytest.raises(ValueError): - decorrelate_u_nk(gmx_ABFE_dhdl, ) + decorrelate_u_nk( + gmx_ABFE_dhdl, + ) -class TestDhdl2series(): + +class TestDhdl2series: @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def dhdl(): dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 300) + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 300) return dhdl - @pytest.mark.parametrize("methodargs", [{}, {'method': 'all'}]) + @pytest.mark.parametrize("methodargs", [{}, {"method": "all"}]) def test_dhdl2series(self, dhdl, methodargs): series = dhdl2series(dhdl, **methodargs) assert len(series) == len(dhdl) assert_allclose(series, dhdl.sum(axis=1)) def test_other_method_ValueError(self, dhdl): - with pytest.raises(ValueError, - match="Only method='all' is supported for dhdl2series()."): + with pytest.raises( + ValueError, match="Only method='all' is supported for dhdl2series()." + ): dhdl2series(dhdl, method="dE") -class TestU_nk2series(): + +class TestU_nk2series: @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def u_nk(): dataset = load_benzene() - u_nk = extract_u_nk(dataset['data']['Coulomb'][0], 300) + u_nk = extract_u_nk(dataset["data"]["Coulomb"][0], 300) return u_nk - @pytest.mark.parametrize("methodargs,reference", # reference = sum - [({}, 9207.80229000283), - ({'method': 'all'}, 85982.34668751864), - ({'method': 'dE'}, 9207.80229000283), - ]) + @pytest.mark.parametrize( + "methodargs,reference", # reference = sum + [ + ({}, 9207.80229000283), + ({"method": "all"}, 85982.34668751864), + ({"method": "dE"}, 9207.80229000283), + ], + ) def test_u_nk2series(self, u_nk, methodargs, reference): series = u_nk2series(u_nk, **methodargs) assert len(series) == len(u_nk) assert_allclose(series.sum(), reference) - @pytest.mark.parametrize("methodargs,reference", # reference = sum - [({'method': 'dhdl_all'}, 85982.34668751864), - ({'method': 'dhdl'}, 9207.80229000283), - ]) + @pytest.mark.parametrize( + "methodargs,reference", # reference = sum + [ + ({"method": "dhdl_all"}, 85982.34668751864), + ({"method": "dhdl"}, 9207.80229000283), + ], + ) def test_u_nk2series_deprecated(self, u_nk, methodargs, reference): - with pytest.warns(DeprecationWarning, - match=r"Method 'dhdl.*' has been deprecated, using '.*' instead\. " - r"'dhdl.*' will be removed in alchemlyb 3\.0\.0\."): + with pytest.warns( + DeprecationWarning, + match=r"Method 'dhdl.*' has been deprecated, using '.*' instead\. " + r"'dhdl.*' will be removed in alchemlyb 3\.0\.0\.", + ): series = u_nk2series(u_nk, **methodargs) assert len(series) == len(u_nk) assert_allclose(series.sum(), reference) - def test_other_method_ValueError(self, u_nk): - with pytest.raises(ValueError, - match='Decorrelation method bogus not found.'): + with pytest.raises(ValueError, match="Decorrelation method bogus not found."): u_nk2series(u_nk, method="bogus") diff --git a/src/alchemlyb/tests/test_ti_estimators.py b/src/alchemlyb/tests/test_ti_estimators.py index 38b623a8..b510a09b 100644 --- a/src/alchemlyb/tests/test_ti_estimators.py +++ b/src/alchemlyb/tests/test_ti_estimators.py @@ -1,111 +1,142 @@ """Tests for all TI-based estimators in ``alchemlyb``. """ -import pytest - +import alchemtest.amber +import alchemtest.gmx +import alchemtest.gomc import pandas as pd +import pytest +from alchemtest.gmx import load_benzene, load_ABFE import alchemlyb -from alchemlyb.parsing import gmx, amber, gomc from alchemlyb.estimators import TI -import alchemtest.gmx -import alchemtest.amber -import alchemtest.gomc -from alchemtest.gmx import load_benzene, load_ABFE +from alchemlyb.parsing import gmx, amber, gomc from alchemlyb.parsing.gmx import extract_dHdl def gmx_benzene_coul_dHdl(): dataset = alchemtest.gmx.load_benzene() - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300) - for filename in dataset['data']['Coulomb']]) + dHdl = alchemlyb.concat( + [gmx.extract_dHdl(filename, T=300) for filename in dataset["data"]["Coulomb"]] + ) return dHdl + def gmx_benzene_vdw_dHdl(): dataset = alchemtest.gmx.load_benzene() - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300) - for filename in dataset['data']['VDW']]) + dHdl = alchemlyb.concat( + [gmx.extract_dHdl(filename, T=300) for filename in dataset["data"]["VDW"]] + ) return dHdl + def gmx_expanded_ensemble_case_1_dHdl(): dataset = alchemtest.gmx.load_expanded_ensemble_case_1() - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300, filter=False) - for filename in dataset['data']['AllStates']]) + dHdl = alchemlyb.concat( + [ + gmx.extract_dHdl(filename, T=300, filter=False) + for filename in dataset["data"]["AllStates"] + ] + ) return dHdl + def gmx_expanded_ensemble_case_2_dHdl(): dataset = alchemtest.gmx.load_expanded_ensemble_case_2() - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300, filter=False) - for filename in dataset['data']['AllStates']]) + dHdl = alchemlyb.concat( + [ + gmx.extract_dHdl(filename, T=300, filter=False) + for filename in dataset["data"]["AllStates"] + ] + ) return dHdl + def gmx_expanded_ensemble_case_3_dHdl(): dataset = alchemtest.gmx.load_expanded_ensemble_case_3() - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300, filter=False) - for filename in dataset['data']['AllStates']]) + dHdl = alchemlyb.concat( + [ + gmx.extract_dHdl(filename, T=300, filter=False) + for filename in dataset["data"]["AllStates"] + ] + ) return dHdl + def gmx_water_particle_with_total_energy_dHdl(): dataset = alchemtest.gmx.load_water_particle_with_total_energy() - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300) - for filename in dataset['data']['AllStates']]) + dHdl = alchemlyb.concat( + [gmx.extract_dHdl(filename, T=300) for filename in dataset["data"]["AllStates"]] + ) return dHdl + def gmx_water_particle_with_potential_energy_dHdl(): dataset = alchemtest.gmx.load_water_particle_with_potential_energy() - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300) - for filename in dataset['data']['AllStates']]) + dHdl = alchemlyb.concat( + [gmx.extract_dHdl(filename, T=300) for filename in dataset["data"]["AllStates"]] + ) return dHdl + def gmx_water_particle_without_energy_dHdl(): dataset = alchemtest.gmx.load_water_particle_without_energy() - dHdl = alchemlyb.concat([gmx.extract_dHdl(filename, T=300) - for filename in dataset['data']['AllStates']]) + dHdl = alchemlyb.concat( + [gmx.extract_dHdl(filename, T=300) for filename in dataset["data"]["AllStates"]] + ) return dHdl + def amber_simplesolvated_charge_dHdl(): dataset = alchemtest.amber.load_simplesolvated() - dHdl = alchemlyb.concat([amber.extract_dHdl(filename, T=298.0) - for filename in dataset['data']['charge']]) + dHdl = alchemlyb.concat( + [ + amber.extract_dHdl(filename, T=298.0) + for filename in dataset["data"]["charge"] + ] + ) return dHdl + def amber_simplesolvated_vdw_dHdl(): dataset = alchemtest.amber.load_simplesolvated() - dHdl = alchemlyb.concat([amber.extract_dHdl(filename, T=298.0) - for filename in dataset['data']['vdw']]) + dHdl = alchemlyb.concat( + [amber.extract_dHdl(filename, T=298.0) for filename in dataset["data"]["vdw"]] + ) return dHdl + def gomc_benzene_dHdl(): dataset = alchemtest.gomc.load_benzene() - dHdl = alchemlyb.concat([gomc.extract_dHdl(filename, T=298) - for filename in dataset['data']]) + dHdl = alchemlyb.concat( + [gomc.extract_dHdl(filename, T=298) for filename in dataset["data"]] + ) return dHdl class TIestimatorMixin: - def test_get_delta_f(self, X_delta_f): dHdl, E, dE = X_delta_f est = self.cls().fit(dHdl) @@ -117,93 +148,120 @@ def test_get_delta_f(self, X_delta_f): assert E == pytest.approx(delta_f, rel=1e-3) assert dE == pytest.approx(d_delta_f, rel=1e-3) + class TestTI(TIestimatorMixin): - """Tests for TI. + """Tests for TI.""" - """ cls = TI T = 298.0 kT_amber = amber.k_b * T - @pytest.fixture(scope="class", - params = [(gmx_benzene_coul_dHdl, 3.089, 0.02157), - (gmx_benzene_vdw_dHdl, -3.056, 0.04863), - (gmx_expanded_ensemble_case_1_dHdl, 76.220, 0.15568), - (gmx_expanded_ensemble_case_2_dHdl, 76.247, 0.15889), - (gmx_expanded_ensemble_case_3_dHdl, 76.387, 0.12532), - (gmx_water_particle_with_total_energy_dHdl, -11.696, 0.091775), - (gmx_water_particle_with_potential_energy_dHdl, -11.751, 0.091149), - (gmx_water_particle_without_energy_dHdl, -11.687, 0.091604), - (amber_simplesolvated_charge_dHdl, -60.114/kT_amber, 0.08186/kT_amber), - (amber_simplesolvated_vdw_dHdl, 3.824/kT_amber, 0.13254/kT_amber), - ]) + @pytest.fixture( + scope="class", + params=[ + (gmx_benzene_coul_dHdl, 3.089, 0.02157), + (gmx_benzene_vdw_dHdl, -3.056, 0.04863), + (gmx_expanded_ensemble_case_1_dHdl, 76.220, 0.15568), + (gmx_expanded_ensemble_case_2_dHdl, 76.247, 0.15889), + (gmx_expanded_ensemble_case_3_dHdl, 76.387, 0.12532), + (gmx_water_particle_with_total_energy_dHdl, -11.696, 0.091775), + (gmx_water_particle_with_potential_energy_dHdl, -11.751, 0.091149), + (gmx_water_particle_without_energy_dHdl, -11.687, 0.091604), + (amber_simplesolvated_charge_dHdl, -60.114 / kT_amber, 0.08186 / kT_amber), + (amber_simplesolvated_vdw_dHdl, 3.824 / kT_amber, 0.13254 / kT_amber), + ], + ) def X_delta_f(self, request): get_dHdl, E, dE = request.param return get_dHdl(), E, dE + def test_TI_separate_dhdl_multiple_column(): dHdl = gomc_benzene_dHdl() estimator = TI().fit(dHdl) assert all([isinstance(dhdl, pd.Series) for dhdl in estimator.separate_dhdl()]) assert sorted([len(dhdl) for dhdl in estimator.separate_dhdl()]) == [8, 16] + def test_TI_separate_dhdl_single_column(): dHdl = gmx_benzene_coul_dHdl() estimator = TI().fit(dHdl) assert all([isinstance(dhdl, pd.Series) for dhdl in estimator.separate_dhdl()]) - assert [len(dhdl) for dhdl in estimator.separate_dhdl()] == [5, ] + assert [len(dhdl) for dhdl in estimator.separate_dhdl()] == [ + 5, + ] + def test_TI_separate_dhdl_no_pertubed(): - '''The test for the case where two lambda are there and one is not pertubed''' + """The test for the case where two lambda are there and one is not pertubed""" dHdl = gmx_benzene_coul_dHdl() - dHdl.insert(1, 'bound-lambda', [1.0, ] * len(dHdl)) - dHdl.insert(1, 'bound', [1.0, ] * len(dHdl)) - dHdl.set_index('bound-lambda', append=True, inplace=True) + dHdl.insert( + 1, + "bound-lambda", + [ + 1.0, + ] + * len(dHdl), + ) + dHdl.insert( + 1, + "bound", + [ + 1.0, + ] + * len(dHdl), + ) + dHdl.set_index("bound-lambda", append=True, inplace=True) estimator = TI().fit(dHdl) assert all([isinstance(dhdl, pd.Series) for dhdl in estimator.separate_dhdl()]) - assert [len(dhdl) for dhdl in estimator.separate_dhdl()] == [5, ] + assert [len(dhdl) for dhdl in estimator.separate_dhdl()] == [ + 5, + ] + + +class Test_Units: + """Test the units.""" -class Test_Units(): - '''Test the units.''' @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def dhdl(): bz = load_benzene().data dHdl_coul = alchemlyb.concat( - [extract_dHdl(xvg, T=300) for xvg in bz['Coulomb']]) + [extract_dHdl(xvg, T=300) for xvg in bz["Coulomb"]] + ) return dHdl_coul def test_ti(self, dhdl): ti = TI().fit(dhdl) - assert ti.delta_f_.attrs['temperature'] == 300 - assert ti.delta_f_.attrs['energy_unit'] == 'kT' - assert ti.d_delta_f_.attrs['temperature'] == 300 - assert ti.d_delta_f_.attrs['energy_unit'] == 'kT' - assert ti.dhdl.attrs['temperature'] == 300 - assert ti.dhdl.attrs['energy_unit'] == 'kT' + assert ti.delta_f_.attrs["temperature"] == 300 + assert ti.delta_f_.attrs["energy_unit"] == "kT" + assert ti.d_delta_f_.attrs["temperature"] == 300 + assert ti.d_delta_f_.attrs["energy_unit"] == "kT" + assert ti.dhdl.attrs["temperature"] == 300 + assert ti.dhdl.attrs["energy_unit"] == "kT" def test_ti_separate_dhdl(self, dhdl): ti = TI().fit(dhdl) dhdl_list = ti.separate_dhdl() for dhdl in dhdl_list: - assert dhdl.attrs['temperature'] == 300 - assert dhdl.attrs['energy_unit'] == 'kT' + assert dhdl.attrs["temperature"] == 300 + assert dhdl.attrs["energy_unit"] == "kT" + + +class Test_MultipleColumnUnits: + """Test the case where the index has multiple columns""" -class Test_MultipleColumnUnits(): - '''Test the case where the index has multiple columns''' @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def dhdl(): - data = load_ABFE()['data']['complex'] - dhdl = alchemlyb.concat( - [extract_dHdl(data[i], - 300) for i in range(30)]) + data = load_ABFE()["data"]["complex"] + dhdl = alchemlyb.concat([extract_dHdl(data[i], 300) for i in range(30)]) return dhdl def test_ti_separate_dhdl(self, dhdl): ti = TI().fit(dhdl) dhdl_list = ti.separate_dhdl() for dhdl in dhdl_list: - assert dhdl.attrs['temperature'] == 300 - assert dhdl.attrs['energy_unit'] == 'kT' \ No newline at end of file + assert dhdl.attrs["temperature"] == 300 + assert dhdl.attrs["energy_unit"] == "kT" diff --git a/src/alchemlyb/tests/test_units.py b/src/alchemlyb/tests/test_units.py index 8dc059e9..6fcfe5b9 100644 --- a/src/alchemlyb/tests/test_units.py +++ b/src/alchemlyb/tests/test_units.py @@ -1,35 +1,43 @@ -import pytest import pandas as pd +import pytest +from alchemtest.gmx import load_benzene import alchemlyb from alchemlyb import pass_attrs -from alchemtest.gmx import load_benzene from alchemlyb.parsing.gmx import extract_dHdl, extract_u_nk from alchemlyb.postprocessors.units import to_kT -from alchemlyb.preprocessing import (dhdl2series, u_nk2series, - decorrelate_u_nk, decorrelate_dhdl, - slicing, statistical_inefficiency, - equilibrium_detection) +from alchemlyb.preprocessing import ( + dhdl2series, + u_nk2series, + decorrelate_u_nk, + decorrelate_dhdl, + slicing, + statistical_inefficiency, + equilibrium_detection, +) + def test_noT(): - '''Test no temperature error''' + """Test no temperature error""" dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) - dhdl.attrs.pop('temperature', None) + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) + dhdl.attrs.pop("temperature", None) with pytest.raises(TypeError): to_kT(dhdl) + def test_nounit(): - '''Test no unit error''' + """Test no unit error""" dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) - dhdl.attrs.pop('energy_unit', None) + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) + dhdl.attrs.pop("energy_unit", None) with pytest.raises(TypeError): to_kT(dhdl) + def test_concat(): - '''Test if different attrs could will give rise to error.''' - d = {'col1': [1, 2], 'col2': [3, 4]} + """Test if different attrs could will give rise to error.""" + d = {"col1": [1, 2], "col2": [3, 4]} df1 = pd.DataFrame(data=d) df1.attrs = {1: 1} df2 = pd.DataFrame(data=d) @@ -37,68 +45,73 @@ def test_concat(): with pytest.raises(ValueError): alchemlyb.concat([df1, df2]) + def test_concat_empty(): - '''Test if empty raise the right error.''' + """Test if empty raise the right error.""" with pytest.raises(ValueError): alchemlyb.concat([]) + def test_setT(): - '''Test setting temperature.''' - df = pd.DataFrame(data={'col1': [1, 2]}) - df.attrs = {'temperature': 300, 'energy_unit': 'kT'} + """Test setting temperature.""" + df = pd.DataFrame(data={"col1": [1, 2]}) + df.attrs = {"temperature": 300, "energy_unit": "kT"} new = to_kT(df, 310) - assert new.attrs['temperature'] == 310 + assert new.attrs["temperature"] == 310 + + +class Test_Conversion: + """Test the preprocessing module.""" -class Test_Conversion(): - '''Test the preprocessing module.''' @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def dhdl(): dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) return dhdl def test_kt2kt_number(self, dhdl): new_dhdl = to_kT(dhdl) - assert 12.9 == pytest.approx(new_dhdl.loc[(0.0,0.0)], 0.1) + assert 12.9 == pytest.approx(new_dhdl.loc[(0.0, 0.0)], 0.1) def test_kt2kt_unit(self, dhdl): new_dhdl = to_kT(dhdl) - assert new_dhdl.attrs['energy_unit'] == 'kT' + assert new_dhdl.attrs["energy_unit"] == "kT" def test_kj2kt_unit(self, dhdl): - dhdl.attrs['energy_unit'] = 'kJ/mol' + dhdl.attrs["energy_unit"] = "kJ/mol" new_dhdl = to_kT(dhdl) - assert new_dhdl.attrs['energy_unit'] == 'kT' + assert new_dhdl.attrs["energy_unit"] == "kT" def test_kj2kt_number(self, dhdl): - dhdl.attrs['energy_unit'] = 'kJ/mol' + dhdl.attrs["energy_unit"] = "kJ/mol" new_dhdl = to_kT(dhdl) - assert 5.0 == pytest.approx(new_dhdl.loc[(0.0,0.0)], 0.1) + assert 5.0 == pytest.approx(new_dhdl.loc[(0.0, 0.0)], 0.1) def test_kcal2kt_unit(self, dhdl): - dhdl.attrs['energy_unit'] = 'kcal/mol' + dhdl.attrs["energy_unit"] = "kcal/mol" new_dhdl = to_kT(dhdl) - assert new_dhdl.attrs['energy_unit'] == 'kT' + assert new_dhdl.attrs["energy_unit"] == "kT" def test_kcal2kt_number(self, dhdl): - dhdl.attrs['energy_unit'] = 'kcal/mol' + dhdl.attrs["energy_unit"] = "kcal/mol" new_dhdl = to_kT(dhdl) - assert 21.0 == pytest.approx(new_dhdl.loc[(0.0,0.0)], 0.1) + assert 21.0 == pytest.approx(new_dhdl.loc[(0.0, 0.0)], 0.1) def test_unknown2kt(self, dhdl): - dhdl.attrs['energy_unit'] = 'ddd' + dhdl.attrs["energy_unit"] = "ddd" with pytest.raises(ValueError): to_kT(dhdl) + def test_pd_concat(): - '''Test if concat will preserve the metadata. + """Test if concat will preserve the metadata. When this test is being made, the pd.concat will discard the attrs of the input dataframe. However, this should get fixed in the future. pandas-dev/pandas#28283 - ''' - d = {'col1': [1, 2], 'col2': [3, 4]} + """ + d = {"col1": [1, 2], "col2": [3, 4]} df1 = pd.DataFrame(data=d) df1.attrs = {1: 1} df2 = pd.DataFrame(data=d) @@ -106,8 +119,9 @@ def test_pd_concat(): df = pd.concat([df1, df2]) assert df.attrs == {1: 1} + def test_pass_attrs(): - d = {'col1': [1, 2], 'col2': [3, 4]} + d = {"col1": [1, 2], "col2": [3, 4]} df1 = pd.DataFrame(data=d) df1.attrs = {1: 1} df2 = pd.DataFrame(data=d) @@ -116,40 +130,48 @@ def test_pass_attrs(): @pass_attrs def concat(df1, df2): return pd.concat([df1, df2]) + assert concat(df1, df2).attrs == {1: 1} + def test_pd_slice(): - '''Test if slicing will preserve the metadata.''' - d = {'col1': [1, 2], 'col2': [3, 4]} + """Test if slicing will preserve the metadata.""" + d = {"col1": [1, 2], "col2": [3, 4]} df = pd.DataFrame(data=d) df.attrs = {1: 1} assert df[::2].attrs == {1: 1} -class TestRetainUnit(): - '''This test tests if the functions that should retain the unit would actually - retain the units.''' + +class TestRetainUnit: + """This test tests if the functions that should retain the unit would actually + retain the units.""" + @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def dhdl(): dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) return dhdl @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def u_nk(): dataset = load_benzene() - u_nk = extract_u_nk(dataset['data']['Coulomb'][0], 310) + u_nk = extract_u_nk(dataset["data"]["Coulomb"][0], 310) return u_nk - @pytest.mark.parametrize('func,fixture_in', - [(dhdl2series, 'dhdl'), - (u_nk2series, 'u_nk'), - (decorrelate_u_nk, 'u_nk'), - (decorrelate_dhdl, 'dhdl'), - (slicing, 'dhdl'), - (statistical_inefficiency, 'dhdl'), - (equilibrium_detection, 'dhdl')]) + @pytest.mark.parametrize( + "func,fixture_in", + [ + (dhdl2series, "dhdl"), + (u_nk2series, "u_nk"), + (decorrelate_u_nk, "u_nk"), + (decorrelate_dhdl, "dhdl"), + (slicing, "dhdl"), + (statistical_inefficiency, "dhdl"), + (equilibrium_detection, "dhdl"), + ], + ) def test_function(self, func, fixture_in, request): result = func(request.getfixturevalue(fixture_in)) - assert result.attrs['energy_unit'] is not None + assert result.attrs["energy_unit"] is not None diff --git a/src/alchemlyb/tests/test_version.py b/src/alchemlyb/tests/test_version.py index 4f2afc78..ddab2ab6 100644 --- a/src/alchemlyb/tests/test_version.py +++ b/src/alchemlyb/tests/test_version.py @@ -1,5 +1,6 @@ import alchemlyb + def test_version(): try: version = alchemlyb.__version__ @@ -8,9 +9,10 @@ def test_version(): assert len(version) > 0 + def test_version_get_versions(): import alchemlyb._version + version = alchemlyb._version.get_versions() assert alchemlyb.__version__ == version["version"] - diff --git a/src/alchemlyb/tests/test_visualisation.py b/src/alchemlyb/tests/test_visualisation.py index 509ecffd..8866822b 100644 --- a/src/alchemlyb/tests/test_visualisation.py +++ b/src/alchemlyb/tests/test_visualisation.py @@ -1,42 +1,51 @@ import matplotlib import matplotlib.pyplot as plt -import pandas as pd import numpy as np +import pandas as pd import pytest +from alchemtest.gmx import load_benzene import alchemlyb -from alchemtest.gmx import load_benzene -from alchemlyb.parsing.gmx import extract_u_nk, extract_dHdl +from alchemlyb.convergence import forward_backward_convergence from alchemlyb.estimators import MBAR, TI, BAR +from alchemlyb.parsing.gmx import extract_u_nk, extract_dHdl +from alchemlyb.visualisation import plot_convergence +from alchemlyb.visualisation.dF_state import plot_dF_state from alchemlyb.visualisation.mbar_matrix import plot_mbar_overlap_matrix from alchemlyb.visualisation.ti_dhdl import plot_ti_dhdl -from alchemlyb.visualisation.dF_state import plot_dF_state -from alchemlyb.visualisation import plot_convergence -from alchemlyb.convergence import forward_backward_convergence + def test_plot_mbar_omatrix(): - '''Just test if the plot runs''' + """Just test if the plot runs""" bz = load_benzene().data - u_nk_coul = alchemlyb.concat([extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']]) + u_nk_coul = alchemlyb.concat([extract_u_nk(xvg, T=300) for xvg in bz["Coulomb"]]) mbar_coul = MBAR() mbar_coul.fit(u_nk_coul) - assert isinstance(plot_mbar_overlap_matrix(mbar_coul.overlap_matrix), - matplotlib.axes.Axes) - assert isinstance(plot_mbar_overlap_matrix(mbar_coul.overlap_matrix, [1,]), - matplotlib.axes.Axes) + assert isinstance( + plot_mbar_overlap_matrix(mbar_coul.overlap_matrix), matplotlib.axes.Axes + ) + assert isinstance( + plot_mbar_overlap_matrix( + mbar_coul.overlap_matrix, + [ + 1, + ], + ), + matplotlib.axes.Axes, + ) # Bump up coverage overlap_maxtrix = mbar_coul.overlap_matrix - overlap_maxtrix[0,0] = 0.0025 + overlap_maxtrix[0, 0] = 0.0025 overlap_maxtrix[-1, -1] = 0.9975 - assert isinstance(plot_mbar_overlap_matrix(overlap_maxtrix), - matplotlib.axes.Axes) + assert isinstance(plot_mbar_overlap_matrix(overlap_maxtrix), matplotlib.axes.Axes) + def test_plot_ti_dhdl(): - '''Just test if the plot runs''' + """Just test if the plot runs""" bz = load_benzene().data - dHdl_coul = alchemlyb.concat([extract_dHdl(xvg, T=300) for xvg in bz['Coulomb']]) + dHdl_coul = alchemlyb.concat([extract_dHdl(xvg, T=300) for xvg in bz["Coulomb"]]) ti_coul = TI() ti_coul.fit(dHdl_coul) @@ -45,36 +54,35 @@ def test_plot_ti_dhdl(): plt.close(ax.figure) fig, ax = plt.subplots(figsize=(8, 6)) - assert isinstance(plot_ti_dhdl(ti_coul, ax=ax), - matplotlib.axes.Axes) - assert isinstance(plot_ti_dhdl(ti_coul, labels=['Coul']), - matplotlib.axes.Axes) - assert isinstance(plot_ti_dhdl(ti_coul, labels=['Coul'], colors=['r']), - matplotlib.axes.Axes) + assert isinstance(plot_ti_dhdl(ti_coul, ax=ax), matplotlib.axes.Axes) + assert isinstance(plot_ti_dhdl(ti_coul, labels=["Coul"]), matplotlib.axes.Axes) + assert isinstance( + plot_ti_dhdl(ti_coul, labels=["Coul"], colors=["r"]), matplotlib.axes.Axes + ) plt.close(fig) - dHdl_vdw = alchemlyb.concat([extract_dHdl(xvg, T=300) for xvg in bz['VDW']]) + dHdl_vdw = alchemlyb.concat([extract_dHdl(xvg, T=300) for xvg in bz["VDW"]]) ti_vdw = TI().fit(dHdl_vdw) ax = plot_ti_dhdl([ti_coul, ti_vdw]) assert isinstance(ax, matplotlib.axes.Axes) plt.close(ax.figure) ti_coul.dhdl = pd.DataFrame.from_dict( - {'fep': range(100)}, - orient='index', - columns=np.arange(100)/100).T + {"fep": range(100)}, orient="index", columns=np.arange(100) / 100 + ).T ti_coul.dhdl.attrs = dHdl_vdw.attrs ax = plot_ti_dhdl(ti_coul) assert isinstance(ax, matplotlib.axes.Axes) plt.close(ax.figure) + def test_plot_dF_state(): - '''Just test if the plot runs''' + """Just test if the plot runs""" bz = load_benzene().data - u_nk_coul = alchemlyb.concat([extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']]) - dHdl_coul = alchemlyb.concat([extract_dHdl(xvg, T=300) for xvg in bz['Coulomb']]) - u_nk_vdw = alchemlyb.concat([extract_u_nk(xvg, T=300) for xvg in bz['VDW']]) - dHdl_vdw = alchemlyb.concat([extract_dHdl(xvg, T=300) for xvg in bz['VDW']]) + u_nk_coul = alchemlyb.concat([extract_u_nk(xvg, T=300) for xvg in bz["Coulomb"]]) + dHdl_coul = alchemlyb.concat([extract_dHdl(xvg, T=300) for xvg in bz["Coulomb"]]) + u_nk_vdw = alchemlyb.concat([extract_u_nk(xvg, T=300) for xvg in bz["VDW"]]) + dHdl_vdw = alchemlyb.concat([extract_dHdl(xvg, T=300) for xvg in bz["VDW"]]) ti_coul = TI().fit(dHdl_coul) ti_vdw = TI().fit(dHdl_vdw) @@ -83,39 +91,47 @@ def test_plot_dF_state(): mbar_coul = MBAR().fit(u_nk_coul) mbar_vdw = MBAR().fit(u_nk_vdw) - dhdl_data = [(ti_coul, ti_vdw), - (bar_coul, bar_vdw), - (mbar_coul, mbar_vdw), ] - fig = plot_dF_state(dhdl_data, orientation='portrait') + dhdl_data = [ + (ti_coul, ti_vdw), + (bar_coul, bar_vdw), + (mbar_coul, mbar_vdw), + ] + fig = plot_dF_state(dhdl_data, orientation="portrait") assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) - fig = plot_dF_state(dhdl_data, orientation='landscape') + fig = plot_dF_state(dhdl_data, orientation="landscape") assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) - fig = plot_dF_state(dhdl_data, labels=['MBAR', 'TI', 'BAR']) + fig = plot_dF_state(dhdl_data, labels=["MBAR", "TI", "BAR"]) assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) with pytest.raises(ValueError): - fig = plot_dF_state(dhdl_data, labels=['MBAR', 'TI',]) - - fig = plot_dF_state(dhdl_data, colors=['#C45AEC', '#33CC33', '#F87431']) + fig = plot_dF_state( + dhdl_data, + labels=[ + "MBAR", + "TI", + ], + ) + + fig = plot_dF_state(dhdl_data, colors=["#C45AEC", "#33CC33", "#F87431"]) assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) with pytest.raises(ValueError): - fig = plot_dF_state(dhdl_data, colors=['#C45AEC', '#33CC33']) + fig = plot_dF_state(dhdl_data, colors=["#C45AEC", "#33CC33"]) with pytest.raises(ValueError): - fig = plot_dF_state(dhdl_data, orientation='xxx') + fig = plot_dF_state(dhdl_data, orientation="xxx") - fig = plot_dF_state(ti_coul, orientation='landscape') + fig = plot_dF_state(ti_coul, orientation="landscape") assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) - fig = plot_dF_state(ti_coul, orientation='portrait') + fig = plot_dF_state(ti_coul, orientation="portrait") assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) @@ -127,80 +143,98 @@ def test_plot_dF_state(): assert isinstance(fig, matplotlib.figure.Figure) plt.close(fig) + def test_plot_convergence_dataframe(): bz = load_benzene().data - data_list = [extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']] - df = forward_backward_convergence(data_list, 'MBAR') + data_list = [extract_u_nk(xvg, T=300) for xvg in bz["Coulomb"]] + df = forward_backward_convergence(data_list, "MBAR") ax = plot_convergence(df) assert isinstance(ax, matplotlib.axes.Axes) plt.close(ax.figure) + def test_plot_convergence_dataframe_noerr(): # Test the input from R_c - data = pd.DataFrame(data={'Forward': range(100), - 'Backward': range(100), - 'data_fraction': np.linspace(0,1,100)}) - data.attrs = {'temperature': 300, 'energy_unit': 'kT'} + data = pd.DataFrame( + data={ + "Forward": range(100), + "Backward": range(100), + "data_fraction": np.linspace(0, 1, 100), + } + ) + data.attrs = {"temperature": 300, "energy_unit": "kT"} ax = plot_convergence(data, final_error=2) assert isinstance(ax, matplotlib.axes.Axes) plt.close(ax.figure) + def test_plot_convergence(): bz = load_benzene().data - data_list = [extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']] + data_list = [extract_u_nk(xvg, T=300) for xvg in bz["Coulomb"]] forward = [] forward_error = [] backward = [] backward_error = [] num_points = 10 - for i in range(1, num_points+1): + for i in range(1, num_points + 1): # Do the forward - slice = int(len(data_list[0])/num_points*i) + slice = int(len(data_list[0]) / num_points * i) u_nk_coul = alchemlyb.concat([data[:slice] for data in data_list]) estimate = MBAR().fit(u_nk_coul) - forward.append(estimate.delta_f_.loc[0.0,1.0]) - forward_error.append(estimate.d_delta_f_.loc[0.0,1.0]) + forward.append(estimate.delta_f_.loc[0.0, 1.0]) + forward_error.append(estimate.d_delta_f_.loc[0.0, 1.0]) # Do the backward u_nk_coul = alchemlyb.concat([data[-slice:] for data in data_list]) estimate = MBAR().fit(u_nk_coul) - backward.append(estimate.delta_f_.loc[0.0,1.0]) - backward_error.append(estimate.d_delta_f_.loc[0.0,1.0]) - - df = pd.DataFrame(data={'Forward': forward, - 'Forward_Error': forward_error, - 'Backward': backward, - 'Backward_Error': backward_error}) + backward.append(estimate.delta_f_.loc[0.0, 1.0]) + backward_error.append(estimate.d_delta_f_.loc[0.0, 1.0]) + + df = pd.DataFrame( + data={ + "Forward": forward, + "Forward_Error": forward_error, + "Backward": backward, + "Backward_Error": backward_error, + } + ) df.attrs = estimate.delta_f_.attrs ax = plot_convergence(df) assert isinstance(ax, matplotlib.axes.Axes) plt.close(ax.figure) -class Test_Units(): + +class Test_Units: @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def estimaters(): bz = load_benzene().data dHdl_coul = alchemlyb.concat( - [extract_dHdl(xvg, T=300) for xvg in bz['Coulomb']]) + [extract_dHdl(xvg, T=300) for xvg in bz["Coulomb"]] + ) ti = TI().fit(dHdl_coul) u_nk_coul = alchemlyb.concat( - [extract_u_nk(xvg, T=300) for xvg in bz['Coulomb']]) + [extract_u_nk(xvg, T=300) for xvg in bz["Coulomb"]] + ) mbar = MBAR().fit(u_nk_coul) return ti, mbar @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def convergence(): - df = pd.DataFrame(data={'Forward': range(10), - 'Forward_Error': range(10), - 'Backward': range(10), - 'Backward_Error': range(10)}) - df.attrs = {'temperature': 300, 'energy_unit': 'kT'} + df = pd.DataFrame( + data={ + "Forward": range(10), + "Forward_Error": range(10), + "Backward": range(10), + "Backward_Error": range(10), + } + ) + df.attrs = {"temperature": 300, "energy_unit": "kT"} return df - @pytest.mark.parametrize('units', [None, 'kT', 'kJ/mol', 'kcal/mol']) + @pytest.mark.parametrize("units", [None, "kT", "kJ/mol", "kcal/mol"]) def test_plot_dF_state(self, estimaters, units): fig = plot_dF_state(estimaters, units=units) assert isinstance(fig, matplotlib.figure.Figure) @@ -208,9 +242,9 @@ def test_plot_dF_state(self, estimaters, units): def test_plot_dF_state_unknown(self, estimaters): with pytest.raises(ValueError): - fig = plot_dF_state(estimaters, units='ddd') + fig = plot_dF_state(estimaters, units="ddd") - @pytest.mark.parametrize('units', [None, 'kT', 'kJ/mol', 'kcal/mol']) + @pytest.mark.parametrize("units", [None, "kT", "kJ/mol", "kcal/mol"]) def test_plot_ti_dhdl(self, estimaters, units): ti, mbar = estimaters ax = plot_ti_dhdl(ti, units=units) @@ -220,9 +254,9 @@ def test_plot_ti_dhdl(self, estimaters, units): def test_plot_ti_dhdl_unknown(self, estimaters): ti, mbar = estimaters with pytest.raises(ValueError): - fig = plot_ti_dhdl(ti, units='ddd') + fig = plot_ti_dhdl(ti, units="ddd") - @pytest.mark.parametrize('units', [None, 'kT', 'kJ/mol', 'kcal/mol']) + @pytest.mark.parametrize("units", [None, "kT", "kJ/mol", "kcal/mol"]) def test_plot_convergence(self, convergence, units): ax = plot_convergence(convergence) assert isinstance(ax, matplotlib.axes.Axes) diff --git a/src/alchemlyb/tests/test_workflow.py b/src/alchemlyb/tests/test_workflow.py index a4308145..cc31e611 100644 --- a/src/alchemlyb/tests/test_workflow.py +++ b/src/alchemlyb/tests/test_workflow.py @@ -1,11 +1,14 @@ +import os + +import pandas as pd import pytest + from alchemlyb.workflows import base -import pandas as pd -import os -class Test_automatic_base(): + +class Test_automatic_base: @staticmethod - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") workflow = base.WorkflowBase(out=str(outdir)) @@ -13,9 +16,9 @@ def workflow(tmp_path_factory): return workflow def test_write(self, workflow): - '''Patch the output directory to tmpdir''' - workflow.result.to_pickle(os.path.join(workflow.out, 'result.pkl')) - assert os.path.exists(os.path.join(workflow.out, 'result.pkl')) + """Patch the output directory to tmpdir""" + workflow.result.to_pickle(os.path.join(workflow.out, "result.pkl")) + assert os.path.exists(os.path.join(workflow.out, "result.pkl")) def test_read(self, workflow): assert len(workflow.u_nk_list) == 0 diff --git a/src/alchemlyb/tests/test_workflow_ABFE.py b/src/alchemlyb/tests/test_workflow_ABFE.py index f282041b..e1d6a4f2 100644 --- a/src/alchemlyb/tests/test_workflow_ABFE.py +++ b/src/alchemlyb/tests/test_workflow_ABFE.py @@ -1,204 +1,235 @@ +import os + import numpy as np import pytest -import os +from alchemtest.amber import load_bace_example +from alchemtest.gmx import load_ABFE, load_benzene from alchemlyb.workflows.abfe import ABFE -from alchemtest.gmx import load_ABFE, load_benzene -from alchemtest.amber import load_bace_example -class Test_automatic_ABFE(): - '''Test the full automatic workflow for load_ABFE from alchemtest.gmx for - three stage transformation.''' + +class Test_automatic_ABFE: + """Test the full automatic workflow for load_ABFE from alchemtest.gmx for + three stage transformation.""" @staticmethod - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") - dir = os.path.dirname(load_ABFE()['data']['complex'][0]) - workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir, - prefix='dhdl', suffix='xvg', T=310, outdirectory=str(outdir)) - workflow.run(skiptime=10, uncorr='dE', threshold=50, - estimators=('MBAR', 'BAR', 'TI'), overlap='O_MBAR.pdf', - breakdown=True, forwrev=10) + dir = os.path.dirname(load_ABFE()["data"]["complex"][0]) + workflow = ABFE( + units="kcal/mol", + software="GROMACS", + dir=dir, + prefix="dhdl", + suffix="xvg", + T=310, + outdirectory=str(outdir), + ) + workflow.run( + skiptime=10, + uncorr="dE", + threshold=50, + estimators=("MBAR", "BAR", "TI"), + overlap="O_MBAR.pdf", + breakdown=True, + forwrev=10, + ) return workflow def test_read(self, workflow): - '''test if the files has been loaded correctly.''' + """test if the files has been loaded correctly.""" assert len(workflow.u_nk_list) == 30 assert len(workflow.dHdl_list) == 30 assert all([len(u_nk) == 1001 for u_nk in workflow.u_nk_list]) assert all([len(dHdl) == 1001 for dHdl in workflow.dHdl_list]) def test_subsample(self, workflow): - '''Test if the data has been shrinked by subsampling.''' + """Test if the data has been shrinked by subsampling.""" assert len(workflow.u_nk_sample_list) == 30 assert len(workflow.dHdl_sample_list) == 30 assert all([len(u_nk) < 1001 for u_nk in workflow.u_nk_sample_list]) assert all([len(dHdl) < 1001 for dHdl in workflow.dHdl_sample_list]) def test_estimator(self, workflow): - '''Test if all three estimators have been used.''' + """Test if all three estimators have been used.""" assert len(workflow.estimator) == 3 - assert 'MBAR' in workflow.estimator - assert 'TI' in workflow.estimator - assert 'BAR' in workflow.estimator + assert "MBAR" in workflow.estimator + assert "TI" in workflow.estimator + assert "BAR" in workflow.estimator def test_summary(self, workflow): - '''Test if if the summary is right.''' + """Test if if the summary is right.""" summary = workflow.generate_result() - assert np.isclose(summary['MBAR']['Stages']['TOTAL'], 21.8, 0.1) - assert np.isclose(summary['TI']['Stages']['TOTAL'], 21.8, 0.1) - assert np.isclose(summary['BAR']['Stages']['TOTAL'], 21.8, 0.1) + assert np.isclose(summary["MBAR"]["Stages"]["TOTAL"], 21.8, 0.1) + assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 21.8, 0.1) + assert np.isclose(summary["BAR"]["Stages"]["TOTAL"], 21.8, 0.1) def test_plot_O_MBAR(self, workflow): - '''test if the O_MBAR.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'O_MBAR.pdf')) + """test if the O_MBAR.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "O_MBAR.pdf")) def test_plot_dhdl_TI(self, workflow): - '''test if the dhdl_TI.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dhdl_TI.pdf')) + """test if the dhdl_TI.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dhdl_TI.pdf")) def test_plot_dF_state(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_state.pdf')) - assert os.path.isfile(os.path.join(workflow.out, 'dF_state_long.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_state.pdf")) + assert os.path.isfile(os.path.join(workflow.out, "dF_state_long.pdf")) def test_check_convergence(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_t.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_t.pdf")) assert len(workflow.convergence) == 10 def test_estimator_method(self, workflow, monkeypatch): - '''Test if the method keyword could be passed to the AutoMBAR estimator.''' - monkeypatch.setattr(workflow, 'estimator', - dict()) - workflow.estimate(estimators='MBAR', method='adaptive') - assert 'MBAR' in workflow.estimator + """Test if the method keyword could be passed to the AutoMBAR estimator.""" + monkeypatch.setattr(workflow, "estimator", dict()) + workflow.estimate(estimators="MBAR", method="adaptive") + assert "MBAR" in workflow.estimator def test_convergence_method(self, workflow, monkeypatch): - '''Test if the method keyword could be passed to the AutoMBAR estimator from convergence.''' - monkeypatch.setattr(workflow, 'convergence', None) - workflow.check_convergence(2, estimator='MBAR', method='adaptive') + """Test if the method keyword could be passed to the AutoMBAR estimator from convergence.""" + monkeypatch.setattr(workflow, "convergence", None) + workflow.check_convergence(2, estimator="MBAR", method="adaptive") assert len(workflow.convergence) == 2 + class Test_manual_ABFE(Test_automatic_ABFE): - '''Test the manual workflow for load_ABFE from alchemtest.gmx for three - stage transformation.''' + """Test the manual workflow for load_ABFE from alchemtest.gmx for three + stage transformation.""" @staticmethod - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") - dir = os.path.dirname(load_ABFE()['data']['complex'][0]) - workflow = ABFE(software='GROMACS', dir=dir, prefix='dhdl', - suffix='xvg', T=310, outdirectory=str(outdir)) - workflow.update_units('kcal/mol') + dir = os.path.dirname(load_ABFE()["data"]["complex"][0]) + workflow = ABFE( + software="GROMACS", + dir=dir, + prefix="dhdl", + suffix="xvg", + T=310, + outdirectory=str(outdir), + ) + workflow.update_units("kcal/mol") workflow.read() - workflow.preprocess(skiptime=10, uncorr='dE', threshold=50) - workflow.estimate(estimators=('MBAR', 'BAR', 'TI')) - workflow.plot_overlap_matrix(overlap='O_MBAR.pdf') - workflow.plot_ti_dhdl(dhdl_TI='dhdl_TI.pdf') - workflow.plot_dF_state(dF_state='dF_state.pdf') - workflow.check_convergence(10, dF_t='dF_t.pdf') + workflow.preprocess(skiptime=10, uncorr="dE", threshold=50) + workflow.estimate(estimators=("MBAR", "BAR", "TI")) + workflow.plot_overlap_matrix(overlap="O_MBAR.pdf") + workflow.plot_ti_dhdl(dhdl_TI="dhdl_TI.pdf") + workflow.plot_dF_state(dF_state="dF_state.pdf") + workflow.check_convergence(10, dF_t="dF_t.pdf") return workflow def test_plot_dF_state(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_state.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_state.pdf")) def test_convergence_nosample_u_nk(self, workflow, monkeypatch): - '''test if the convergence routine would use the unsampled data - when the data has not been subsampled.''' - monkeypatch.setattr(workflow, 'u_nk_sample_list', - None) + """test if the convergence routine would use the unsampled data + when the data has not been subsampled.""" + monkeypatch.setattr(workflow, "u_nk_sample_list", None) workflow.check_convergence(10) assert len(workflow.convergence) == 10 def test_dhdl_TI_noTI(self, workflow, monkeypatch): - '''Test to plot the dhdl_TI when ti estimator is not there''' + """Test to plot the dhdl_TI when ti estimator is not there""" no_TI = workflow.estimator - no_TI.pop('TI') - monkeypatch.setattr(workflow, 'estimator', - no_TI) + no_TI.pop("TI") + monkeypatch.setattr(workflow, "estimator", no_TI) with pytest.raises(ValueError): - workflow.plot_ti_dhdl(dhdl_TI='dhdl_TI.pdf') + workflow.plot_ti_dhdl(dhdl_TI="dhdl_TI.pdf") def test_noMBAR_for_plot_overlap_matrix(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'estimator', {}) + monkeypatch.setattr(workflow, "estimator", {}) assert workflow.plot_overlap_matrix() is None def test_no_u_nk_for_check_convergence(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'u_nk_list', None) - monkeypatch.setattr(workflow, 'u_nk_sample_list', None) + monkeypatch.setattr(workflow, "u_nk_list", None) + monkeypatch.setattr(workflow, "u_nk_sample_list", None) with pytest.raises(ValueError): - workflow.check_convergence(10, estimator='MBAR') + workflow.check_convergence(10, estimator="MBAR") def test_no_dHdl_for_check_convergence(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'dHdl_list', None) - monkeypatch.setattr(workflow, 'dHdl_sample_list', None) + monkeypatch.setattr(workflow, "dHdl_list", None) + monkeypatch.setattr(workflow, "dHdl_sample_list", None) with pytest.raises(ValueError): - workflow.check_convergence(10, estimator='TI') + workflow.check_convergence(10, estimator="TI") def test_no_update_units(self, workflow): assert workflow.update_units() is None def test_no_name_estimate(self, workflow): with pytest.raises(ValueError): - workflow.estimate('aaa') + workflow.estimate("aaa") -class Test_automatic_benzene(): - '''Test the full automatic workflow for load_benzene from alchemtest.gmx for - single stage transformation.''' +class Test_automatic_benzene: + """Test the full automatic workflow for load_benzene from alchemtest.gmx for + single stage transformation.""" @staticmethod - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") - dir = os.path.dirname(os.path.dirname( - load_benzene()['data']['Coulomb'][0])) - dir = os.path.join(dir, '*') - workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir, - prefix='dhdl', suffix='bz2', T=310, - outdirectory=outdir) - workflow.run(skiptime=0, uncorr='dE', threshold=50, - estimators=('MBAR', 'BAR', 'TI'), overlap='O_MBAR.pdf', - breakdown=True, forwrev=10) + dir = os.path.dirname(os.path.dirname(load_benzene()["data"]["Coulomb"][0])) + dir = os.path.join(dir, "*") + workflow = ABFE( + units="kcal/mol", + software="GROMACS", + dir=dir, + prefix="dhdl", + suffix="bz2", + T=310, + outdirectory=outdir, + ) + workflow.run( + skiptime=0, + uncorr="dE", + threshold=50, + estimators=("MBAR", "BAR", "TI"), + overlap="O_MBAR.pdf", + breakdown=True, + forwrev=10, + ) return workflow def test_read(self, workflow): - '''test if the files has been loaded correctly.''' + """test if the files has been loaded correctly.""" assert len(workflow.u_nk_list) == 5 assert len(workflow.dHdl_list) == 5 assert all([len(u_nk) == 4001 for u_nk in workflow.u_nk_list]) assert all([len(dHdl) == 4001 for dHdl in workflow.dHdl_list]) def test_estimator(self, workflow): - '''Test if all three estimators have been used.''' + """Test if all three estimators have been used.""" assert len(workflow.estimator) == 3 - assert 'MBAR' in workflow.estimator - assert 'TI' in workflow.estimator - assert 'BAR' in workflow.estimator + assert "MBAR" in workflow.estimator + assert "TI" in workflow.estimator + assert "BAR" in workflow.estimator def test_O_MBAR(self, workflow): - '''test if the O_MBAR.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'O_MBAR.pdf')) + """test if the O_MBAR.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "O_MBAR.pdf")) def test_dhdl_TI(self, workflow): - '''test if the dhdl_TI.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dhdl_TI.pdf')) + """test if the dhdl_TI.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dhdl_TI.pdf")) def test_dF_state(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_state.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_state.pdf")) def test_convergence(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_t.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_t.pdf")) assert len(workflow.convergence) == 10 -class Test_unpertubed_lambda(): - '''Test the if two lamdas present and one of them is not pertubed. + +class Test_unpertubed_lambda: + """Test the if two lamdas present and one of them is not pertubed. fep bound time fep-lambda bound-lambda @@ -209,87 +240,118 @@ class Test_unpertubed_lambda(): 40.0 0.5 0 7.768072 0 Where only fep-lambda changes but the bonded-lambda is always 0. - ''' + """ @staticmethod - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") - dir = os.path.dirname(os.path.dirname( - load_benzene()['data']['Coulomb'][0])) - dir = os.path.join(dir, '*') - workflow = ABFE(software='GROMACS', dir=dir, prefix='dhdl', - suffix='bz2', T=310, outdirectory=outdir) + dir = os.path.dirname(os.path.dirname(load_benzene()["data"]["Coulomb"][0])) + dir = os.path.join(dir, "*") + workflow = ABFE( + software="GROMACS", + dir=dir, + prefix="dhdl", + suffix="bz2", + T=310, + outdirectory=outdir, + ) workflow.read() # Block the n_uk workflow.u_nk_list = [] # Add another lambda column for dHdl in workflow.dHdl_list: - dHdl.insert(1, 'bound-lambda', [1.0, ] * len(dHdl)) - dHdl.insert(1, 'bound', [1.0, ] * len(dHdl)) - dHdl.set_index('bound-lambda', append=True, inplace=True) - - workflow.estimate(estimators=('TI', )) - workflow.plot_ti_dhdl(dhdl_TI='dhdl_TI.pdf') - workflow.plot_dF_state(dF_state='dF_state.pdf') - workflow.check_convergence(10, dF_t='dF_t.pdf', estimator='TI') + dHdl.insert( + 1, + "bound-lambda", + [ + 1.0, + ] + * len(dHdl), + ) + dHdl.insert( + 1, + "bound", + [ + 1.0, + ] + * len(dHdl), + ) + dHdl.set_index("bound-lambda", append=True, inplace=True) + + workflow.estimate(estimators=("TI",)) + workflow.plot_ti_dhdl(dhdl_TI="dhdl_TI.pdf") + workflow.plot_dF_state(dF_state="dF_state.pdf") + workflow.check_convergence(10, dF_t="dF_t.pdf", estimator="TI") return workflow def test_dhdl_TI(self, workflow): - '''test if the dhdl_TI.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dhdl_TI.pdf')) + """test if the dhdl_TI.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dhdl_TI.pdf")) def test_dF_state(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_state.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_state.pdf")) def test_convergence(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_t.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_t.pdf")) assert len(workflow.convergence) == 10 def test_single_estimator_ti(self, workflow): - workflow.estimate(estimators='TI') + workflow.estimate(estimators="TI") summary = workflow.generate_result() - assert np.isclose(summary['TI']['Stages']['TOTAL'], 2.946, 0.1) + assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 2.946, 0.1) + -class Test_methods(): - '''Test various methods.''' +class Test_methods: + """Test various methods.""" @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") - dir = os.path.dirname(os.path.dirname( - load_benzene()['data']['Coulomb'][0])) - dir = os.path.join(dir, '*') - workflow = ABFE(software='GROMACS', dir=dir, prefix='dhdl', - suffix='bz2', T=310, outdirectory=outdir) + dir = os.path.dirname(os.path.dirname(load_benzene()["data"]["Coulomb"][0])) + dir = os.path.join(dir, "*") + workflow = ABFE( + software="GROMACS", + dir=dir, + prefix="dhdl", + suffix="bz2", + T=310, + outdirectory=outdir, + ) workflow.read() return workflow def test_run_none(self, workflow): - '''Don't run anything''' - workflow.run(uncorr=None, estimators=None, overlap=None, breakdown=None, - forwrev=None) + """Don't run anything""" + workflow.run( + uncorr=None, estimators=None, overlap=None, breakdown=None, forwrev=None + ) def test_run_single_estimator(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'u_nk_list', []) - monkeypatch.setattr(workflow, 'dHdl_list', []) - workflow.run(uncorr=None, estimators='MBAR', overlap=None, breakdown=True, - forwrev=None) + monkeypatch.setattr(workflow, "u_nk_list", []) + monkeypatch.setattr(workflow, "dHdl_list", []) + workflow.run( + uncorr=None, estimators="MBAR", overlap=None, breakdown=True, forwrev=None + ) def test_run_invalid_estimator(self, workflow): - with pytest.raises(ValueError, - match=r'Estimator aaa is not supported.'): - workflow.run(uncorr=None, estimators='aaa', overlap=None, breakdown=None, - forwrev=None) - - @pytest.mark.parametrize('read_u_nk', [True, False]) - @pytest.mark.parametrize('read_dHdl', [True, False]) + with pytest.raises(ValueError, match=r"Estimator aaa is not supported."): + workflow.run( + uncorr=None, + estimators="aaa", + overlap=None, + breakdown=None, + forwrev=None, + ) + + @pytest.mark.parametrize("read_u_nk", [True, False]) + @pytest.mark.parametrize("read_dHdl", [True, False]) def test_read_TI_FEP(self, workflow, monkeypatch, read_u_nk, read_dHdl): - monkeypatch.setattr(workflow, 'u_nk_list', []) - monkeypatch.setattr(workflow, 'dHdl_list', []) + monkeypatch.setattr(workflow, "u_nk_list", []) + monkeypatch.setattr(workflow, "dHdl_list", []) workflow.read(read_u_nk, read_dHdl) if read_u_nk: assert len(workflow.u_nk_list) == 5 @@ -303,104 +365,112 @@ def test_read_TI_FEP(self, workflow, monkeypatch, read_u_nk, read_dHdl): def test_read_invalid_u_nk(self, workflow, monkeypatch): def extract_u_nk(self, T): - raise IOError('Error read u_nk.') - monkeypatch.setattr(workflow, '_extract_u_nk', - extract_u_nk) - with pytest.raises(OSError, - match=r'Error reading u_nk .*dhdl\.xvg\.bz2'): + raise IOError("Error read u_nk.") + + monkeypatch.setattr(workflow, "_extract_u_nk", extract_u_nk) + with pytest.raises(OSError, match=r"Error reading u_nk .*dhdl\.xvg\.bz2"): workflow.read() def test_read_invalid_dHdl(self, workflow, monkeypatch): def extract_dHdl(self, T): - raise IOError('Error read dHdl.') - monkeypatch.setattr(workflow, '_extract_dHdl', - extract_dHdl) - with pytest.raises(OSError, - match=r'Error reading dHdl .*dhdl\.xvg\.bz2'): + raise IOError("Error read dHdl.") + + monkeypatch.setattr(workflow, "_extract_dHdl", extract_dHdl) + with pytest.raises(OSError, match=r"Error reading dHdl .*dhdl\.xvg\.bz2"): workflow.read() def test_uncorr_threshold(self, workflow, monkeypatch): - '''Test if the full data will be used when the number of data points - are less than the threshold.''' - monkeypatch.setattr(workflow, 'u_nk_list', - [u_nk[:40] for u_nk in workflow.u_nk_list]) - monkeypatch.setattr(workflow, 'dHdl_list', - [dHdl[:40] for dHdl in workflow.dHdl_list]) + """Test if the full data will be used when the number of data points + are less than the threshold.""" + monkeypatch.setattr( + workflow, "u_nk_list", [u_nk[:40] for u_nk in workflow.u_nk_list] + ) + monkeypatch.setattr( + workflow, "dHdl_list", [dHdl[:40] for dHdl in workflow.dHdl_list] + ) workflow.preprocess(threshold=50) assert all([len(u_nk) == 40 for u_nk in workflow.u_nk_sample_list]) assert all([len(dHdl) == 40 for dHdl in workflow.dHdl_sample_list]) def test_no_u_nk_preprocess(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'u_nk_list', []) + monkeypatch.setattr(workflow, "u_nk_list", []) workflow.preprocess(threshold=50) assert len(workflow.u_nk_list) == 0 def test_no_dHdl_preprocess(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'dHdl_list', []) + monkeypatch.setattr(workflow, "dHdl_list", []) workflow.preprocess(threshold=50) assert len(workflow.dHdl_list) == 0 def test_single_estimator_mbar(self, workflow): - workflow.estimate(estimators='MBAR') + workflow.estimate(estimators="MBAR") assert len(workflow.estimator) == 1 - assert 'MBAR' in workflow.estimator + assert "MBAR" in workflow.estimator summary = workflow.generate_result() - assert np.isclose(summary['MBAR']['Stages']['TOTAL'], 2.946, 0.1) + assert np.isclose(summary["MBAR"]["Stages"]["TOTAL"], 2.946, 0.1) def test_single_estimator_ti(self, workflow): - workflow.estimate(estimators='TI') + workflow.estimate(estimators="TI") summary = workflow.generate_result() - assert np.isclose(summary['TI']['Stages']['TOTAL'], 2.946, 0.1) + assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 2.946, 0.1) def test_bar_convergence(self, workflow): - workflow.check_convergence(10, estimator='BAR') + workflow.check_convergence(10, estimator="BAR") assert len(workflow.convergence) == 10 def test_convergence_invalid_estimator(self, workflow): with pytest.raises(ValueError): - workflow.check_convergence(10, estimator='aaa') + workflow.check_convergence(10, estimator="aaa") def test_ti_convergence(self, workflow): - workflow.check_convergence(10, estimator='TI') + workflow.check_convergence(10, estimator="TI") assert len(workflow.convergence) == 10 def test_unprocessed_n_uk(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'u_nk_sample_list', - None) + monkeypatch.setattr(workflow, "u_nk_sample_list", None) workflow.estimate() assert len(workflow.estimator) == 3 - assert 'MBAR' in workflow.estimator + assert "MBAR" in workflow.estimator def test_unprocessed_dhdl(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'dHdl_sample_list', - None) - workflow.check_convergence(10, estimator='TI') + monkeypatch.setattr(workflow, "dHdl_sample_list", None) + workflow.check_convergence(10, estimator="TI") assert len(workflow.convergence) == 10 -class Test_automatic_amber(): - '''Test the full automatic workflow for load_ABFE from alchemtest.amber for - three stage transformation.''' + +class Test_automatic_amber: + """Test the full automatic workflow for load_ABFE from alchemtest.amber for + three stage transformation.""" @staticmethod - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") dir, _ = os.path.split( - os.path.dirname(load_bace_example()['data']['complex']['vdw'][0])) - - workflow = ABFE(units='kcal/mol', software='AMBER', dir=dir, - prefix='ti', suffix='bz2', T=298.0, outdirectory=str( - outdir)) + os.path.dirname(load_bace_example()["data"]["complex"]["vdw"][0]) + ) + + workflow = ABFE( + units="kcal/mol", + software="AMBER", + dir=dir, + prefix="ti", + suffix="bz2", + T=298.0, + outdirectory=str(outdir), + ) workflow.read() - workflow.estimate(estimators='TI') + workflow.estimate(estimators="TI") return workflow def test_summary(self, workflow): - '''Test if if the summary is right.''' + """Test if if the summary is right.""" summary = workflow.generate_result() - assert np.isclose(summary['TI']['Stages']['TOTAL'], 1.40405980473, 0.1) + assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 1.40405980473, 0.1) + def test_no_parser(): with pytest.raises(NotImplementedError): - workflow = ABFE(units='kcal/mol', software='aaa', - prefix='ti', suffix='bz2', T=298.0) + workflow = ABFE( + units="kcal/mol", software="aaa", prefix="ti", suffix="bz2", T=298.0 + ) diff --git a/src/alchemlyb/visualisation/__init__.py b/src/alchemlyb/visualisation/__init__.py index d58b367e..6955dcaf 100644 --- a/src/alchemlyb/visualisation/__init__.py +++ b/src/alchemlyb/visualisation/__init__.py @@ -1,4 +1,4 @@ +from .convergence import plot_convergence +from .dF_state import plot_dF_state from .mbar_matrix import plot_mbar_overlap_matrix from .ti_dhdl import plot_ti_dhdl -from .dF_state import plot_dF_state -from .convergence import plot_convergence \ No newline at end of file diff --git a/src/alchemlyb/visualisation/convergence.py b/src/alchemlyb/visualisation/convergence.py index fcef3e50..c1cc477a 100644 --- a/src/alchemlyb/visualisation/convergence.py +++ b/src/alchemlyb/visualisation/convergence.py @@ -1,92 +1,91 @@ import matplotlib.pyplot as plt -import pandas as pd -from matplotlib.font_manager import FontProperties as FP import numpy as np +from matplotlib.font_manager import FontProperties as FP from ..postprocessors.units import get_unit_converter + def plot_convergence(dataframe, units=None, final_error=None, ax=None): """Plot the forward and backward convergence. - The input could be the result from - :func:`~alchemlyb.convergence.forward_backward_convergence` or - :func:`~alchemlyb.convergence.fwdrev_cumavg_Rc`. The input should be a - :class:`pandas.DataFrame` which has column `Forward`, `Backward` and - :attr:`pandas.DataFrame.attrs` should compile with :ref:`note-on-units`. - The errorbar will be plotted if column `Forward_Error` and `Backward_Error` - is present. - - `Forward`: A column of free energy estimate from the first X% of data, - where optional `Forward_Error` column is the corresponding error. - - `Backward`: A column of free energy estimate from the last X% of data., - where optional `Backward_Error` column is the corresponding error. - - `final_error` is the error of the final value and is shown as the error band around the - final value. It can be provided in case an estimate is available that is more appropriate - than the default, which is the error of the last value in `Backward`. - - Parameters - ---------- - dataframe : Dataframe - Output Dataframe has column `Forward`, `Backward` or optionally - `Forward_Error`, `Backward_Error` see :ref:`plot_convergence `. - units : str - The unit of the estimate. The default is `None`, which is to use the - unit in the input. Setting this will change the output unit. - final_error : float - The error of the final value in ``units``. If not given, takes the last - error in `backward_error`. - ax : matplotlib.axes.Axes - Matplotlib axes object where the plot will be drawn on. If ``ax=None``, - a new axes will be generated. - - Returns - ------- - matplotlib.axes.Axes - An axes with the forward and backward convergence drawn. - - Note - ---- - The code is taken and modified from - `Alchemical Analysis `_. - - - .. versionchanged:: 1.0.0 - Keyword arg final_error for plotting a horizontal error bar. - The array input has been deprecated. - The units default to `None` which uses the units in the input. - - .. versionchanged:: 0.6.0 - data now takes in dataframe - - .. versionadded:: 0.4.0 + The input could be the result from + :func:`~alchemlyb.convergence.forward_backward_convergence` or + :func:`~alchemlyb.convergence.fwdrev_cumavg_Rc`. The input should be a + :class:`pandas.DataFrame` which has column `Forward`, `Backward` and + :attr:`pandas.DataFrame.attrs` should compile with :ref:`note-on-units`. + The errorbar will be plotted if column `Forward_Error` and `Backward_Error` + is present. + + `Forward`: A column of free energy estimate from the first X% of data, + where optional `Forward_Error` column is the corresponding error. + + `Backward`: A column of free energy estimate from the last X% of data., + where optional `Backward_Error` column is the corresponding error. + + `final_error` is the error of the final value and is shown as the error band around the + final value. It can be provided in case an estimate is available that is more appropriate + than the default, which is the error of the last value in `Backward`. + + Parameters + ---------- + dataframe : Dataframe + Output Dataframe has column `Forward`, `Backward` or optionally + `Forward_Error`, `Backward_Error` see :ref:`plot_convergence `. + units : str + The unit of the estimate. The default is `None`, which is to use the + unit in the input. Setting this will change the output unit. + final_error : float + The error of the final value in ``units``. If not given, takes the last + error in `backward_error`. + ax : matplotlib.axes.Axes + Matplotlib axes object where the plot will be drawn on. If ``ax=None``, + a new axes will be generated. + + Returns + ------- + matplotlib.axes.Axes + An axes with the forward and backward convergence drawn. + + Note + ---- + The code is taken and modified from + `Alchemical Analysis `_. + + + .. versionchanged:: 1.0.0 + Keyword arg final_error for plotting a horizontal error bar. + The array input has been deprecated. + The units default to `None` which uses the units in the input. + + .. versionchanged:: 0.6.0 + data now takes in dataframe + + .. versionadded:: 0.4.0 """ if units is not None: dataframe = get_unit_converter(units)(dataframe) - forward = dataframe['Forward'].to_numpy() - if 'Forward_Error' in dataframe: - forward_error = dataframe['Forward_Error'].to_numpy() + forward = dataframe["Forward"].to_numpy() + if "Forward_Error" in dataframe: + forward_error = dataframe["Forward_Error"].to_numpy() else: forward_error = np.zeros(len(forward)) - backward = dataframe['Backward'].to_numpy() - if 'Backward_Error' in dataframe: - backward_error = dataframe['Backward_Error'].to_numpy() + backward = dataframe["Backward"].to_numpy() + if "Backward_Error" in dataframe: + backward_error = dataframe["Backward_Error"].to_numpy() else: backward_error = np.zeros(len(backward)) - - if ax is None: # pragma: no cover + if ax is None: # pragma: no cover fig, ax = plt.subplots(figsize=(8, 6)) - plt.setp(ax.spines['bottom'], color='#D2B9D3', lw=3, zorder=-2) - plt.setp(ax.spines['left'], color='#D2B9D3', lw=3, zorder=-2) + plt.setp(ax.spines["bottom"], color="#D2B9D3", lw=3, zorder=-2) + plt.setp(ax.spines["left"], color="#D2B9D3", lw=3, zorder=-2) - for dire in ['top', 'right']: - ax.spines[dire].set_color('none') + for dire in ["top", "right"]: + ax.spines[dire].set_color("none") - ax.xaxis.set_ticks_position('bottom') - ax.yaxis.set_ticks_position('left') + ax.xaxis.set_ticks_position("bottom") + ax.yaxis.set_ticks_position("left") f_ts = np.linspace(0, 1, len(forward) + 1)[1:] r_ts = np.linspace(0, 1, len(backward) + 1)[1:] @@ -94,28 +93,54 @@ def plot_convergence(dataframe, units=None, final_error=None, ax=None): if final_error is None: final_error = backward_error[-1] - line0 = ax.fill_between([0, 1], backward[-1] - final_error, - backward[-1] + final_error, color='#D2B9D3', - zorder=1) - line1 = ax.errorbar(f_ts, forward, yerr=forward_error, color='#736AFF', - lw=3, zorder=2, marker='o', - mfc='w', mew=2.5, mec='#736AFF', ms=12,) - line2 = ax.errorbar(r_ts, backward, yerr=backward_error, color='#C11B17', - lw=3, zorder=3, marker='o', - mfc='w', mew=2.5, mec='#C11B17', ms=12, ) + line0 = ax.fill_between( + [0, 1], + backward[-1] - final_error, + backward[-1] + final_error, + color="#D2B9D3", + zorder=1, + ) + line1 = ax.errorbar( + f_ts, + forward, + yerr=forward_error, + color="#736AFF", + lw=3, + zorder=2, + marker="o", + mfc="w", + mew=2.5, + mec="#736AFF", + ms=12, + ) + line2 = ax.errorbar( + r_ts, + backward, + yerr=backward_error, + color="#C11B17", + lw=3, + zorder=3, + marker="o", + mfc="w", + mew=2.5, + mec="#C11B17", + ms=12, + ) xticks_spacing = len(r_ts) // 10 or 1 xticks = r_ts[::xticks_spacing] - plt.xticks(xticks, ['%.2f' % i for i in xticks], fontsize=10) + plt.xticks(xticks, ["%.2f" % i for i in xticks], fontsize=10) plt.yticks(fontsize=10) - ax.legend((line1[0], line2[0]), ('Forward', 'Reverse'), loc=9, - prop=FP(size=18), frameon=False) - ax.set_xlabel(r'Fraction of the simulation time', fontsize=16, - color='#151B54') - ax.set_ylabel(r'$\Delta G$ ({})'.format(units), fontsize=16, color='#151B54') - plt.tick_params(axis='x', color='#D2B9D3') - plt.tick_params(axis='y', color='#D2B9D3') + ax.legend( + (line1[0], line2[0]), + ("Forward", "Reverse"), + loc=9, + prop=FP(size=18), + frameon=False, + ) + ax.set_xlabel(r"Fraction of the simulation time", fontsize=16, color="#151B54") + ax.set_ylabel(r"$\Delta G$ ({})".format(units), fontsize=16, color="#151B54") + plt.tick_params(axis="x", color="#D2B9D3") + plt.tick_params(axis="y", color="#D2B9D3") return ax - - diff --git a/src/alchemlyb/visualisation/dF_state.py b/src/alchemlyb/visualisation/dF_state.py index 8f5a1409..e36fbc21 100644 --- a/src/alchemlyb/visualisation/dF_state.py +++ b/src/alchemlyb/visualisation/dF_state.py @@ -9,15 +9,17 @@ """ import matplotlib.pyplot as plt -from matplotlib.font_manager import FontProperties as FP import numpy as np +from matplotlib.font_manager import FontProperties as FP from ..estimators import TI, BAR, MBAR from ..postprocessors.units import get_unit_converter -def plot_dF_state(estimators, labels=None, colors=None, units=None, - orientation='portrait', nb=10): - '''Plot the dhdl of TI. + +def plot_dF_state( + estimators, labels=None, colors=None, units=None, orientation="portrait", nb=10 +): + """Plot the dhdl of TI. Parameters ---------- @@ -57,11 +59,13 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None, changing the figure legend. .. versionadded:: 0.4.0 - ''' + """ try: len(estimators) except TypeError: - estimators = [estimators, ] + estimators = [ + estimators, + ] formatted_data = [] for dhdl in estimators: @@ -69,10 +73,14 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None, len(dhdl) formatted_data.append(dhdl) except TypeError: - formatted_data.append([dhdl, ]) + formatted_data.append( + [ + dhdl, + ] + ) if units is None: - units = formatted_data[0][0].delta_f_.attrs['energy_unit'] + units = formatted_data[0][0].delta_f_.attrs["energy_unit"] estimators = formatted_data @@ -96,47 +104,69 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None, error_list.append(error) # Get the determine orientation - if orientation == 'landscape': + if orientation == "landscape": if max_length < 8: fig, ax = plt.subplots(figsize=(8, 6)) else: fig, ax = plt.subplots(figsize=(max_length, 6)) - axs = [ax, ] - xs = [np.arange(max_length), ] - elif orientation == 'portrait': + axs = [ + ax, + ] + xs = [ + np.arange(max_length), + ] + elif orientation == "portrait": if max_length < nb: - xs = [np.arange(max_length), ] + xs = [ + np.arange(max_length), + ] fig, ax = plt.subplots(figsize=(8, 6)) - axs = [ax, ] + axs = [ + ax, + ] else: xs = np.array_split(np.arange(max_length), max_length / nb + 1) fig, axs = plt.subplots(nrows=len(xs), figsize=(8, 6)) mnb = max([len(i) for i in xs]) else: - raise ValueError("Not recognising {}, only supports 'landscape' or 'portrait'.".format(orientation)) + raise ValueError( + "Not recognising {}, only supports 'landscape' or 'portrait'.".format( + orientation + ) + ) # Sort out the colors if colors is None: - colors_dict = {'TI': '#C45AEC', 'TI-CUBIC': '#33CC33', - 'DEXP': '#F87431', 'IEXP': '#FF3030', 'GINS': '#EAC117', - 'GDEL': '#347235', 'BAR': '#6698FF', 'UBAR': '#817339', - 'RBAR': '#C11B17', 'MBAR': '#F9B7FF'} + colors_dict = { + "TI": "#C45AEC", + "TI-CUBIC": "#33CC33", + "DEXP": "#F87431", + "IEXP": "#FF3030", + "GINS": "#EAC117", + "GDEL": "#347235", + "BAR": "#6698FF", + "UBAR": "#817339", + "RBAR": "#C11B17", + "MBAR": "#F9B7FF", + } colors = [] for dhdl in estimators: dhdl = dhdl[0] if isinstance(dhdl, TI): - colors.append(colors_dict['TI']) + colors.append(colors_dict["TI"]) elif isinstance(dhdl, BAR): - colors.append(colors_dict['BAR']) + colors.append(colors_dict["BAR"]) elif isinstance(dhdl, MBAR): - colors.append(colors_dict['MBAR']) + colors.append(colors_dict["MBAR"]) else: if len(colors) >= len(estimators): pass else: raise ValueError( - 'Number of colors ({}) should be larger than the number of data ({})'.format( - len(colors), len(estimators))) + "Number of colors ({}) should be larger than the number of data ({})".format( + len(colors), len(estimators) + ) + ) # Sort out the labels if labels is None: @@ -144,21 +174,23 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None, for dhdl in estimators: dhdl = dhdl[0] if isinstance(dhdl, TI): - labels.append('TI') + labels.append("TI") elif isinstance(dhdl, BAR): - labels.append('BAR') + labels.append("BAR") elif isinstance(dhdl, MBAR): - labels.append('MBAR') + labels.append("MBAR") else: if len(labels) == len(estimators): pass else: raise ValueError( - 'Length of labels ({}) should be the same as the number of data ({})'.format( - len(labels), len(estimators))) + "Length of labels ({}) should be the same as the number of data ({})".format( + len(labels), len(estimators) + ) + ) # Plot the figure - width = 1. / (len(estimators) + 1) + width = 1.0 / (len(estimators) + 1) elw = 30 * width ndx = 1 for x, ax in zip(xs, axs): @@ -166,35 +198,49 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None, for i, (dF, error) in enumerate(zip(dF_list, error_list)): y = [dF[j] for j in x] ye = [error[j] for j in x] - if orientation == 'landscape': + if orientation == "landscape": lw = 0.1 * elw - elif orientation == 'portrait': + elif orientation == "portrait": lw = 0.05 * elw - line = ax.bar(x + len(lines) * width, y, width, - color=colors[i], yerr=ye, lw=lw, - error_kw=dict(elinewidth=elw, ecolor='black', - capsize=0.5 * elw)) + line = ax.bar( + x + len(lines) * width, + y, + width, + color=colors[i], + yerr=ye, + lw=lw, + error_kw=dict(elinewidth=elw, ecolor="black", capsize=0.5 * elw), + ) lines += (line[0],) - for dir in ['left', 'right', 'top', 'bottom']: - if dir == 'left': + for dir in ["left", "right", "top", "bottom"]: + if dir == "left": ax.yaxis.set_ticks_position(dir) else: - ax.spines[dir].set_color('none') + ax.spines[dir].set_color("none") - if orientation == 'landscape': + if orientation == "landscape": plt.yticks(fontsize=8) - ax.set_xlim(x[0]-width, x[-1] + len(lines) * width) - plt.xticks(x + 0.5 * width * len(estimators), - tuple(['%d--%d' % (i, i + 1) for i in x]), fontsize=8) - elif orientation == 'portrait': + ax.set_xlim(x[0] - width, x[-1] + len(lines) * width) + plt.xticks( + x + 0.5 * width * len(estimators), + tuple(["%d--%d" % (i, i + 1) for i in x]), + fontsize=8, + ) + elif orientation == "portrait": plt.yticks(fontsize=10) ax.xaxis.set_ticks([]) for i in x + 0.5 * width * len(estimators): - ax.annotate(r'$\mathrm{%d-%d}$' % (i, i + 1), xy=(i, 0), - xycoords=('data', 'axes fraction'), xytext=(0, -2), - size=10, textcoords='offset points', va='top', - ha='center') - ax.set_xlim(x[0]-width, x[-1]+len(lines)*width + (mnb - len(x))) + ax.annotate( + r"$\mathrm{%d-%d}$" % (i, i + 1), + xy=(i, 0), + xycoords=("data", "axes fraction"), + xytext=(0, -2), + size=10, + textcoords="offset points", + va="top", + ha="center", + ) + ax.set_xlim(x[0] - width, x[-1] + len(lines) * width + (mnb - len(x))) ndx += 1 x = np.arange(max_length) @@ -202,18 +248,21 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None, for tick in ax.get_xticklines(): tick.set_visible(False) - if orientation == 'landscape': - leg = plt.legend(lines, labels, loc=3, ncol=2, prop=FP(size=10), - fancybox=True) - plt.title('The free energy change breakdown', fontsize=12) - plt.xlabel('States', fontsize=12, color='#151B54') - plt.ylabel(r'$\Delta G$ ({})'.format(units), fontsize=12, color='#151B54') - elif orientation == 'portrait': - leg = ax.legend(lines, labels, loc=0, ncol=2, - prop=FP(size=8), - title=r'$\Delta G$ ({})'.format(units) + - r'$\mathit{vs.}$ lambda pair', - fancybox=True) + if orientation == "landscape": + leg = plt.legend(lines, labels, loc=3, ncol=2, prop=FP(size=10), fancybox=True) + plt.title("The free energy change breakdown", fontsize=12) + plt.xlabel("States", fontsize=12, color="#151B54") + plt.ylabel(r"$\Delta G$ ({})".format(units), fontsize=12, color="#151B54") + elif orientation == "portrait": + leg = ax.legend( + lines, + labels, + loc=0, + ncol=2, + prop=FP(size=8), + title=r"$\Delta G$ ({})".format(units) + r"$\mathit{vs.}$ lambda pair", + fancybox=True, + ) leg.get_frame().set_alpha(0.5) return fig diff --git a/src/alchemlyb/visualisation/mbar_matrix.py b/src/alchemlyb/visualisation/mbar_matrix.py index 4b2bd952..6bdc068e 100644 --- a/src/alchemlyb/visualisation/mbar_matrix.py +++ b/src/alchemlyb/visualisation/mbar_matrix.py @@ -13,8 +13,9 @@ import matplotlib.pyplot as plt import numpy as np + def plot_mbar_overlap_matrix(matrix, skip_lambda_index=[], ax=None): - '''Plot the MBAR overlap matrix. + """Plot the MBAR overlap matrix. Parameters ---------- @@ -41,7 +42,7 @@ def plot_mbar_overlap_matrix(matrix, skip_lambda_index=[], ax=None): .. versionadded:: 0.4.0 - ''' + """ # Compute the size of the figure, if ax is not given. max_prob = matrix.max() size = len(matrix) @@ -49,25 +50,36 @@ def plot_mbar_overlap_matrix(matrix, skip_lambda_index=[], ax=None): fig, ax = plt.subplots(figsize=(size / 2, size / 2)) ax.set_xticks([]) ax.set_yticks([]) - ax.axis('off') + ax.axis("off") for i in range(size): if i != 0: - ax.axvline(x=i, ls='-', lw=0.5, color='k', alpha=0.25) - ax.axhline(y=i, ls='-', lw=0.5, color='k', alpha=0.25) + ax.axvline(x=i, ls="-", lw=0.5, color="k", alpha=0.25) + ax.axhline(y=i, ls="-", lw=0.5, color="k", alpha=0.25) for j in range(size): if matrix[j, i] < 0.005: - ii = '' + ii = "" elif matrix[j, i] > 0.995: - ii = '1.00' + ii = "1.00" else: - ii = ("{:.2f}".format(matrix[j, i])[1:]) + ii = "{:.2f}".format(matrix[j, i])[1:] alf = matrix[j, i] / max_prob - ax.fill_between([i, i + 1], [size - j, size - j], - [size - (j + 1), size - (j + 1)], color='k', - alpha=alf) - ax.annotate(ii, xy=(i, j), xytext=(i + 0.5, size - (j + 0.5)), - size=8, textcoords='data', va='center', - ha='center', color=('k' if alf < 0.5 else 'w')) + ax.fill_between( + [i, i + 1], + [size - j, size - j], + [size - (j + 1), size - (j + 1)], + color="k", + alpha=alf, + ) + ax.annotate( + ii, + xy=(i, j), + xytext=(i + 0.5, size - (j + 0.5)), + size=8, + textcoords="data", + va="center", + ha="center", + color=("k" if alf < 0.5 else "w"), + ) if skip_lambda_index: ks = [int(l) for l in skip_lambda_index] @@ -75,31 +87,48 @@ def plot_mbar_overlap_matrix(matrix, skip_lambda_index=[], ax=None): else: ks = range(size) for i in range(size): - ax.annotate(ks[i], xy=(i + 0.5, 1), xytext=(i + 0.5, size + 0.5), - size=10, textcoords=('data', 'data'), - va='center', ha='center', color='k') - ax.annotate(ks[i], xy=(-0.5, size - (size - 0.5)), - xytext=(-0.5, size - (i + 0.5)), - size=10, textcoords=('data', 'data'), - va='center', ha='center', color='k') - ax.annotate(r'$\lambda$', xy=(-0.5, size - (size - 0.5)), - xytext=(-0.5, size + 0.5), - size=10, textcoords=('data', 'data'), - va='center', ha='center', color='k') - ax.plot([0, size], [0, 0], 'k-', lw=4.0, solid_capstyle='butt') - ax.plot([size, size], [0, size], 'k-', lw=4.0, solid_capstyle='butt') - ax.plot([0, 0], [0, size], 'k-', lw=2.0, solid_capstyle='butt') - ax.plot([0, size], [size, size], 'k-', lw=2.0, solid_capstyle='butt') + ax.annotate( + ks[i], + xy=(i + 0.5, 1), + xytext=(i + 0.5, size + 0.5), + size=10, + textcoords=("data", "data"), + va="center", + ha="center", + color="k", + ) + ax.annotate( + ks[i], + xy=(-0.5, size - (size - 0.5)), + xytext=(-0.5, size - (i + 0.5)), + size=10, + textcoords=("data", "data"), + va="center", + ha="center", + color="k", + ) + ax.annotate( + r"$\lambda$", + xy=(-0.5, size - (size - 0.5)), + xytext=(-0.5, size + 0.5), + size=10, + textcoords=("data", "data"), + va="center", + ha="center", + color="k", + ) + ax.plot([0, size], [0, 0], "k-", lw=4.0, solid_capstyle="butt") + ax.plot([size, size], [0, size], "k-", lw=4.0, solid_capstyle="butt") + ax.plot([0, 0], [0, size], "k-", lw=2.0, solid_capstyle="butt") + ax.plot([0, size], [size, size], "k-", lw=2.0, solid_capstyle="butt") cx = np.repeat(range(size + 1), 2) cy = sorted(np.repeat(range(size + 1), 2), reverse=True) - ax.plot(cx[2:-1], cy[1:-2], 'k-', lw=2.0) - ax.plot(np.array(cx[2:-3]) + 1, cy[1:-4], 'k-', lw=2.0) - ax.plot(cx[1:-2], np.array(cy[:-3]) - 1, 'k-', lw=2.0) - ax.plot(cx[1:-4], np.array(cy[:-5]) - 2, 'k-', lw=2.0) + ax.plot(cx[2:-1], cy[1:-2], "k-", lw=2.0) + ax.plot(np.array(cx[2:-3]) + 1, cy[1:-4], "k-", lw=2.0) + ax.plot(cx[1:-2], np.array(cy[:-3]) - 1, "k-", lw=2.0) + ax.plot(cx[1:-4], np.array(cy[:-5]) - 2, "k-", lw=2.0) ax.set_xlim(-1, size) ax.set_ylim(0, size + 1) return ax - - diff --git a/src/alchemlyb/visualisation/ti_dhdl.py b/src/alchemlyb/visualisation/ti_dhdl.py index c071a97d..6dacb6dc 100644 --- a/src/alchemlyb/visualisation/ti_dhdl.py +++ b/src/alchemlyb/visualisation/ti_dhdl.py @@ -10,14 +10,14 @@ """ import matplotlib.pyplot as plt -from matplotlib.font_manager import FontProperties as FP import numpy as np +from matplotlib.font_manager import FontProperties as FP from ..postprocessors.units import get_unit_converter -def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, - ax=None): - '''Plot the dhdl of TI. + +def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, ax=None): + """Plot the dhdl of TI. Parameters ---------- @@ -55,7 +55,7 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, changing the figure legend. .. versionadded:: 0.4.0 - ''' + """ # Make it into a list # separate_dhdl method is used so that the input for the actual plotting # Function are a uniformed list of series object which only contains one @@ -69,7 +69,7 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, # Convert unit if units is None: - units = dhdl_list[0].attrs['energy_unit'] + units = dhdl_list[0].attrs["energy_unit"] new_unit = [] convert = get_unit_converter(units) @@ -80,11 +80,11 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, if ax is None: fig, ax = plt.subplots(figsize=(8, 6)) - ax.spines['bottom'].set_position('zero') - ax.spines['top'].set_color('none') - ax.spines['right'].set_color('none') - ax.xaxis.set_ticks_position('bottom') - ax.yaxis.set_ticks_position('left') + ax.spines["bottom"].set_position("zero") + ax.spines["top"].set_color("none") + ax.spines["right"].set_color("none") + ax.xaxis.set_ticks_position("bottom") + ax.yaxis.set_ticks_position("left") for k, spine in ax.spines.items(): spine.set_zorder(12.2) @@ -98,20 +98,24 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, else: if len(labels) == len(dhdl_list): lv_names2 = labels - else: # pragma: no cover + else: # pragma: no cover raise ValueError( - 'Length of labels ({}) should be the same as the number of data ({})'.format( - len(labels), len(dhdl_list))) + "Length of labels ({}) should be the same as the number of data ({})".format( + len(labels), len(dhdl_list) + ) + ) if colors is None: - colors = ['r', 'g', '#7F38EC', '#9F000F', 'b', 'y'] + colors = ["r", "g", "#7F38EC", "#9F000F", "b", "y"] else: if len(colors) >= len(dhdl_list): pass - else: # pragma: no cover + else: # pragma: no cover raise ValueError( - 'Number of colors ({}) should be larger than the number of data ({})'.format( - len(labels), len(dhdl_list))) + "Number of colors ({}) should be larger than the number of data ({})".format( + len(labels), len(dhdl_list) + ) + ) # Get the real data out xs, ndx, dx = [0], 0, 0.001 @@ -125,16 +129,22 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, for i in range(len(x) - 1): if i % 2 == 0: - ax.fill_between(x[i:i + 2] + ndx, 0, y[i:i + 2], - color=colors[ndx], alpha=1.0) + ax.fill_between( + x[i : i + 2] + ndx, 0, y[i : i + 2], color=colors[ndx], alpha=1.0 + ) else: - ax.fill_between(x[i:i + 2] + ndx, 0, y[i:i + 2], - color=colors[ndx], alpha=0.5) + ax.fill_between( + x[i : i + 2] + ndx, 0, y[i : i + 2], color=colors[ndx], alpha=0.5 + ) xlegend = [-100 * wnum for wnum in range(len(lv_names2))] - ax.plot(xlegend, [0 * wnum for wnum in xlegend], ls='-', - color=colors[ndx], - label=lv_names2[ndx]) + ax.plot( + xlegend, + [0 * wnum for wnum in xlegend], + ls="-", + color=colors[ndx], + label=lv_names2[ndx], + ) xs += (x + ndx).tolist()[1:] ndx += 1 @@ -159,7 +169,7 @@ def getInd(r=ri, z=[0]): if i in getInd(): xt.append(i) else: - xt.append('') + xt.append("") plt.xticks(xs[1:], xt[1:], fontsize=10) ax.yaxis.label.set_size(10) @@ -172,31 +182,46 @@ def getInd(r=ri, z=[0]): max_y *= 1.01 # Modified so that the x label won't conflict with the lambda label - min_y -= (max_y-min_y)*0.1 + min_y -= (max_y - min_y) * 0.1 ax.set_ylim(min_y, max_y) for i, j in zip(xs[1:], xt[1:]): ax.annotate( - ('%.2f' % (i - 1.0 if i > 1.0 else i) if not j == '' else ''), - xy=(i, 0), size=10, rotation=90, va='bottom', ha='center', - color='#151B54') + ("%.2f" % (i - 1.0 if i > 1.0 else i) if not j == "" else ""), + xy=(i, 0), + size=10, + rotation=90, + va="bottom", + ha="center", + color="#151B54", + ) if ndx > 1: lenticks = len(ax.get_ymajorticklabels()) - 1 - if min_y < 0: lenticks -= 1 + if min_y < 0: + lenticks -= 1 if lenticks < 5: # pragma: no cover from matplotlib.ticker import AutoMinorLocator as AML + ax.yaxis.set_minor_locator(AML()) - ax.grid(which='both', color='w', lw=0.25, axis='y', zorder=12) + ax.grid(which="both", color="w", lw=0.25, axis="y", zorder=12) ax.set_ylabel( - r'$\langle{\frac{\partial U}{\partial\lambda}}\rangle_{\lambda}$' + - '({})'.format(units), - fontsize=20, color='#151B54') - ax.annotate(r'$\mathit{\lambda}$', xy=(0, 0), xytext=(0.5, -0.05), size=18, - textcoords='axes fraction', va='top', ha='center', - color='#151B54') + r"$\langle{\frac{\partial U}{\partial\lambda}}\rangle_{\lambda}$" + + "({})".format(units), + fontsize=20, + color="#151B54", + ) + ax.annotate( + r"$\mathit{\lambda}$", + xy=(0, 0), + xytext=(0.5, -0.05), + size=18, + textcoords="axes fraction", + va="top", + ha="center", + color="#151B54", + ) lege = ax.legend(prop=FP(size=14), frameon=False, loc=1) for l in lege.legendHandles: l.set_linewidth(10) return ax - diff --git a/src/alchemlyb/workflows/__init__.py b/src/alchemlyb/workflows/__init__.py index 6b35d460..a6a156cf 100644 --- a/src/alchemlyb/workflows/__init__.py +++ b/src/alchemlyb/workflows/__init__.py @@ -1,4 +1,5 @@ __all__ = [ - 'base', + "base", ] + from .abfe import ABFE diff --git a/src/alchemlyb/workflows/abfe.py b/src/alchemlyb/workflows/abfe.py index 32e51a7c..9fef25c1 100644 --- a/src/alchemlyb/workflows/abfe.py +++ b/src/alchemlyb/workflows/abfe.py @@ -1,26 +1,31 @@ +import logging import os -from os.path import join from glob import glob -import pandas as pd -import numpy as np -import logging +from os.path import join + import matplotlib.pyplot as plt +import numpy as np +import pandas as pd from .base import WorkflowBase -from ..parsing import gmx, amber -from ..preprocessing.subsampling import decorrelate_dhdl, decorrelate_u_nk -from ..estimators import BAR, TI, FEP_ESTIMATORS, TI_ESTIMATORS +from .. import __version__ +from .. import concat +from ..convergence import forward_backward_convergence from ..estimators import AutoMBAR as MBAR -from ..visualisation import (plot_mbar_overlap_matrix, plot_ti_dhdl, - plot_dF_state, plot_convergence) +from ..estimators import BAR, TI, FEP_ESTIMATORS, TI_ESTIMATORS +from ..parsing import gmx, amber from ..postprocessors.units import get_unit_converter -from ..convergence import forward_backward_convergence -from .. import concat -from .. import __version__ +from ..preprocessing.subsampling import decorrelate_dhdl, decorrelate_u_nk +from ..visualisation import ( + plot_mbar_overlap_matrix, + plot_ti_dhdl, + plot_dF_state, + plot_convergence, +) class ABFE(WorkflowBase): - '''Workflow for absolute and relative binding free energy calculations. + """Workflow for absolute and relative binding free energy calculations. This workflow provides functionality similar to the ``alchemical-analysis.py`` script. It loads multiple input files from alchemical free energy calculations and computes the @@ -58,42 +63,50 @@ class ABFE(WorkflowBase): .. versionadded:: 1.0.0 - ''' - def __init__(self, T, units='kT', software='GROMACS', dir=os.path.curdir, - prefix='dhdl', suffix='xvg', - outdirectory=os.path.curdir): + """ + + def __init__( + self, + T, + units="kT", + software="GROMACS", + dir=os.path.curdir, + prefix="dhdl", + suffix="xvg", + outdirectory=os.path.curdir, + ): super().__init__(units, software, T, outdirectory) - self.logger = logging.getLogger('alchemlyb.workflows.ABFE') - self.logger.info('Initialise Alchemlyb ABFE Workflow') - self.logger.info(f'Alchemlyb Version: f{__version__}') - self.logger.info(f'Set Temperature to {T} K.') - self.logger.info(f'Set Software to {software}.') + self.logger = logging.getLogger("alchemlyb.workflows.ABFE") + self.logger.info("Initialise Alchemlyb ABFE Workflow") + self.logger.info(f"Alchemlyb Version: f{__version__}") + self.logger.info(f"Set Temperature to {T} K.") + self.logger.info(f"Set Software to {software}.") self.update_units(units) - self.logger.info(f'Finding files with prefix: {prefix}, suffix: ' - f'{suffix} under directory {dir} produced by ' - f'{software}') - self.file_list = glob(dir + '/**/' + prefix + '*' + suffix, - recursive=True) + self.logger.info( + f"Finding files with prefix: {prefix}, suffix: " + f"{suffix} under directory {dir} produced by " + f"{software}" + ) + self.file_list = glob(dir + "/**/" + prefix + "*" + suffix, recursive=True) - self.logger.info(f'Found {len(self.file_list)} xvg files.') - self.logger.info("Unsorted file list: \n %s", '\n'.join( - self.file_list)) + self.logger.info(f"Found {len(self.file_list)} xvg files.") + self.logger.info("Unsorted file list: \n %s", "\n".join(self.file_list)) - if software == 'GROMACS': - self.logger.info(f'Using {software} parser to read the data.') + if software == "GROMACS": + self.logger.info(f"Using {software} parser to read the data.") self._extract_u_nk = gmx.extract_u_nk self._extract_dHdl = gmx.extract_dHdl - elif software == 'AMBER': + elif software == "AMBER": self._extract_u_nk = amber.extract_u_nk self._extract_dHdl = amber.extract_dHdl else: - raise NotImplementedError(f'{software} parser not found.') + raise NotImplementedError(f"{software} parser not found.") def read(self, read_u_nk=True, read_dHdl=True): - '''Read the u_nk and dHdL data from the + """Read the u_nk and dHdL data from the :attr:`~alchemlyb.workflows.ABFE.file_list` Parameters @@ -109,7 +122,7 @@ def read(self, read_u_nk=True, read_dHdl=True): A list of :class:`pandas.DataFrame` of u_nk. dHdl_list : list A list of :class:`pandas.DataFrame` of dHdl. - ''' + """ self.u_nk_sample_list = None self.dHdl_sample_list = None @@ -119,46 +132,46 @@ def read(self, read_u_nk=True, read_dHdl=True): if read_u_nk: try: u_nk = self._extract_u_nk(file, T=self.T) - self.logger.info( - f'Reading {len(u_nk)} lines of u_nk from {file}') + self.logger.info(f"Reading {len(u_nk)} lines of u_nk from {file}") u_nk_list.append(u_nk) except Exception as exc: - msg = f'Error reading u_nk from {file}.' + msg = f"Error reading u_nk from {file}." self.logger.error(msg) raise OSError(msg) from exc if read_dHdl: try: dhdl = self._extract_dHdl(file, T=self.T) - self.logger.info( - f'Reading {len(dhdl)} lines of dhdl from {file}') + self.logger.info(f"Reading {len(dhdl)} lines of dhdl from {file}") dHdl_list.append(dhdl) except Exception as exc: - msg = f'Error reading dHdl from {file}.' + msg = f"Error reading dHdl from {file}." self.logger.error(msg) raise OSError(msg) from exc # Sort the files according to the state if read_u_nk: - self.logger.info('Sort files according to the u_nk.') + self.logger.info("Sort files according to the u_nk.") column_names = u_nk_list[0].columns.values.tolist() - index_list = sorted(range(len(self.file_list)), - key=lambda x: column_names.index( - u_nk_list[x].reset_index( - 'time').index.values[0])) + index_list = sorted( + range(len(self.file_list)), + key=lambda x: column_names.index( + u_nk_list[x].reset_index("time").index.values[0] + ), + ) elif read_dHdl: - self.logger.info('Sort files according to the dHdl.') - index_list = sorted(range(len(self.file_list)), - key=lambda x: - dHdl_list[x].reset_index( - 'time').index.values[0]) + self.logger.info("Sort files according to the dHdl.") + index_list = sorted( + range(len(self.file_list)), + key=lambda x: dHdl_list[x].reset_index("time").index.values[0], + ) else: self.u_nk_list = [] self.dHdl_list = [] return self.file_list = [self.file_list[i] for i in index_list] - self.logger.info("Sorted file list: \n%s", '\n'.join(self.file_list)) + self.logger.info("Sorted file list: \n%s", "\n".join(self.file_list)) if read_u_nk: self.u_nk_list = [u_nk_list[i] for i in index_list] else: @@ -169,11 +182,19 @@ def read(self, read_u_nk=True, read_dHdl=True): else: self.dHdl_list = [] - - def run(self, skiptime=0, uncorr='dE', threshold=50, - estimators=('MBAR', 'BAR', 'TI'), overlap='O_MBAR.pdf', - breakdown=True, forwrev=None, *args, **kwargs): - ''' The method for running the automatic analysis. + def run( + self, + skiptime=0, + uncorr="dE", + threshold=50, + estimators=("MBAR", "BAR", "TI"), + overlap="O_MBAR.pdf", + breakdown=True, + forwrev=None, + *args, + **kwargs, + ): + """The method for running the automatic analysis. Parameters ---------- @@ -214,29 +235,32 @@ def run(self, skiptime=0, uncorr='dE', threshold=50, The summary of the convergence results. See :func:`~alchemlyb.convergence.forward_backward_convergence` for further explanation. - ''' + """ use_FEP = False use_TI = False if estimators is not None: if isinstance(estimators, str): - estimators = [estimators, ] + estimators = [ + estimators, + ] for estimator in estimators: if estimator in FEP_ESTIMATORS: use_FEP = True elif estimator in TI_ESTIMATORS: use_TI = True else: - msg = f"Estimator {estimator} is not supported. Choose one from " \ - f"{FEP_ESTIMATORS + TI_ESTIMATORS}." + msg = ( + f"Estimator {estimator} is not supported. Choose one from " + f"{FEP_ESTIMATORS + TI_ESTIMATORS}." + ) self.logger.error(msg) raise ValueError(msg) self.read(use_FEP, use_TI) if uncorr is not None: - self.preprocess(skiptime=skiptime, uncorr=uncorr, - threshold=threshold) + self.preprocess(skiptime=skiptime, uncorr=uncorr, threshold=threshold) if estimators is not None: self.estimate(estimators) self.generate_result() @@ -251,31 +275,30 @@ def run(self, skiptime=0, uncorr='dE', threshold=50, plt.close(ax.figure) fig = self.plot_dF_state() plt.close(fig) - fig = self.plot_dF_state(dF_state='dF_state_long.pdf', - orientation='landscape') + fig = self.plot_dF_state( + dF_state="dF_state_long.pdf", orientation="landscape" + ) plt.close(fig) if forwrev is not None: - ax = self.check_convergence(forwrev, estimator='MBAR', - dF_t='dF_t.pdf') + ax = self.check_convergence(forwrev, estimator="MBAR", dF_t="dF_t.pdf") plt.close(ax.figure) - def update_units(self, units=None): - '''Update the unit. + """Update the unit. Parameters ---------- units : {'kcal/mol', 'kJ/mol', 'kT'} The unit used for printing and plotting results. - ''' + """ if units is not None: - self.logger.info(f'Set unit to {units}.') + self.logger.info(f"Set unit to {units}.") self.units = units or None - def preprocess(self, skiptime=0, uncorr='dE', threshold=50): - '''Preprocess the data by removing the equilibration time and + def preprocess(self, skiptime=0, uncorr="dE", threshold=50): + """Preprocess the data by removing the equilibration time and decorrelate the date. Parameters @@ -296,54 +319,65 @@ def preprocess(self, skiptime=0, uncorr='dE', threshold=50): The list of u_nk after decorrelation. dHdl_sample_list : list The list of dHdl after decorrelation. - ''' - self.logger.info(f'Start preprocessing with skiptime of {skiptime} ' - f'uncorrelation method of {uncorr} and threshold of ' - f'{threshold}') + """ + self.logger.info( + f"Start preprocessing with skiptime of {skiptime} " + f"uncorrelation method of {uncorr} and threshold of " + f"{threshold}" + ) if len(self.u_nk_list) > 0: self.logger.info( - f'Processing the u_nk data set with skiptime of {skiptime}.') + f"Processing the u_nk data set with skiptime of {skiptime}." + ) self.u_nk_sample_list = [] for index, u_nk in enumerate(self.u_nk_list): # Find the starting frame - u_nk = u_nk[u_nk.index.get_level_values('time') >= skiptime] + u_nk = u_nk[u_nk.index.get_level_values("time") >= skiptime] subsample = decorrelate_u_nk(u_nk, uncorr, remove_burnin=True) if len(subsample) < threshold: - self.logger.warning(f'Number of u_nk {len(subsample)} ' - f'for state {index} is less than the ' - f'threshold {threshold}.') - self.logger.info(f'Take all the u_nk for state {index}.') + self.logger.warning( + f"Number of u_nk {len(subsample)} " + f"for state {index} is less than the " + f"threshold {threshold}." + ) + self.logger.info(f"Take all the u_nk for state {index}.") self.u_nk_sample_list.append(u_nk) else: - self.logger.info(f'Take {len(subsample)} uncorrelated ' - f'u_nk for state {index}.') + self.logger.info( + f"Take {len(subsample)} uncorrelated " + f"u_nk for state {index}." + ) self.u_nk_sample_list.append(subsample) else: - self.logger.info('No u_nk data being subsampled') + self.logger.info("No u_nk data being subsampled") if len(self.dHdl_list) > 0: self.dHdl_sample_list = [] for index, dHdl in enumerate(self.dHdl_list): - dHdl = dHdl[dHdl.index.get_level_values('time') >= skiptime] + dHdl = dHdl[dHdl.index.get_level_values("time") >= skiptime] subsample = decorrelate_dhdl(dHdl, remove_burnin=True) if len(subsample) < threshold: - self.logger.warning(f'Number of dHdl {len(subsample)} for ' - f'state {index} is less than the ' - f'threshold {threshold}.') - self.logger.info(f'Take all the dHdl for state {index}.') + self.logger.warning( + f"Number of dHdl {len(subsample)} for " + f"state {index} is less than the " + f"threshold {threshold}." + ) + self.logger.info(f"Take all the dHdl for state {index}.") self.dHdl_sample_list.append(dHdl) else: - self.logger.info(f'Take {len(subsample)} uncorrelated ' - f'dHdl for state {index}.') + self.logger.info( + f"Take {len(subsample)} uncorrelated " + f"dHdl for state {index}." + ) self.dHdl_sample_list.append(subsample) else: - self.logger.info('No dHdl data being subsampled') + self.logger.info("No dHdl data being subsampled") - def estimate(self, estimators=('MBAR', 'BAR', 'TI'), **kwargs): - '''Estimate the free energy using the selected estimator. + def estimate(self, estimators=("MBAR", "BAR", "TI"), **kwargs): + """Estimate the free energy using the selected estimator. Parameters ---------- @@ -368,10 +402,10 @@ def estimate(self, estimators=('MBAR', 'BAR', 'TI'), **kwargs): behavior of :class:`~alchemlyb.estimators.MBAR`. (:code:`estimate(estimators='MBAR', method='adaptive')`) - ''' + """ # Make estimators into a tuple if isinstance(estimators, str): - estimators = (estimators, ) + estimators = (estimators,) for estimator in estimators: if estimator not in (FEP_ESTIMATORS + TI_ESTIMATORS): @@ -379,42 +413,38 @@ def estimate(self, estimators=('MBAR', 'BAR', 'TI'), **kwargs): self.logger.error(msg) raise ValueError(msg) - self.logger.info( - f"Start running estimator: {','.join(estimators)}.") + self.logger.info(f"Start running estimator: {','.join(estimators)}.") self.estimator = {} # Use unprocessed data if preprocess is not performed. - if 'TI' in estimators: + if "TI" in estimators: if self.dHdl_sample_list is not None: dHdl = concat(self.dHdl_sample_list) else: dHdl = concat(self.dHdl_list) - self.logger.warning('dHdl has not been preprocessed.') - self.logger.info( - f'A total {len(dHdl)} lines of dHdl is used.') + self.logger.warning("dHdl has not been preprocessed.") + self.logger.info(f"A total {len(dHdl)} lines of dHdl is used.") - if 'BAR' in estimators or 'MBAR' in estimators: + if "BAR" in estimators or "MBAR" in estimators: if self.u_nk_sample_list is not None: u_nk = concat(self.u_nk_sample_list) else: u_nk = concat(self.u_nk_list) - self.logger.warning('u_nk has not been preprocessed.') - self.logger.info( - f'A total {len(u_nk)} lines of u_nk is used.') + self.logger.warning("u_nk has not been preprocessed.") + self.logger.info(f"A total {len(u_nk)} lines of u_nk is used.") for estimator in estimators: - if estimator == 'MBAR': - self.logger.info('Run MBAR estimator.') + if estimator == "MBAR": + self.logger.info("Run MBAR estimator.") self.estimator[estimator] = MBAR(**kwargs).fit(u_nk) - elif estimator == 'BAR': - self.logger.info('Run BAR estimator.') + elif estimator == "BAR": + self.logger.info("Run BAR estimator.") self.estimator[estimator] = BAR(**kwargs).fit(u_nk) - elif estimator == 'TI': - self.logger.info('Run TI estimator.') + elif estimator == "TI": + self.logger.info("Run TI estimator.") self.estimator[estimator] = TI(**kwargs).fit(dHdl) - def generate_result(self): - '''Summarise the result into a dataframe. + """Summarise the result into a dataframe. Returns ------- @@ -460,38 +490,37 @@ def generate_result(self): ---------- summary : Dataframe The summary of the free energy estimate. - ''' + """ # Write estimate - self.logger.info('Summarise the estimate into a dataframe.') + self.logger.info("Summarise the estimate into a dataframe.") # Make the header name - self.logger.info('Generate the row names.') + self.logger.info("Generate the row names.") estimator_names = list(self.estimator.keys()) num_states = len(self.estimator[estimator_names[0]].states_) - data_dict = {'name': [], - 'state': []} + data_dict = {"name": [], "state": []} for i in range(num_states - 1): - data_dict['name'].append(str(i) + ' -- ' + str(i+1)) - data_dict['state'].append('States') + data_dict["name"].append(str(i) + " -- " + str(i + 1)) + data_dict["state"].append("States") try: u_nk = self.u_nk_list[0] - stages = u_nk.reset_index('time').index.names - self.logger.info('use the stage name from u_nk') + stages = u_nk.reset_index("time").index.names + self.logger.info("use the stage name from u_nk") except: dHdl = self.dHdl_list[0] - stages = dHdl.reset_index('time').index.names - self.logger.info('use the stage name from dHdl') + stages = dHdl.reset_index("time").index.names + self.logger.info("use the stage name from dHdl") for stage in stages: - data_dict['name'].append(stage.split('-')[0]) - data_dict['state'].append('Stages') - data_dict['name'].append('TOTAL') - data_dict['state'].append('Stages') + data_dict["name"].append(stage.split("-")[0]) + data_dict["state"].append("Stages") + data_dict["name"].append("TOTAL") + data_dict["state"].append("Stages") col_names = [] for estimator_name, estimator in self.estimator.items(): - self.logger.info(f'Read the results from estimator {estimator_name}') + self.logger.info(f"Read the results from estimator {estimator_name}") # Do the unit conversion delta_f_ = estimator.delta_f_ @@ -499,26 +528,26 @@ def generate_result(self): # Write the estimator header col_names.append(estimator_name) - col_names.append(estimator_name + '_Error') + col_names.append(estimator_name + "_Error") data_dict[estimator_name] = [] - data_dict[estimator_name + '_Error'] = [] + data_dict[estimator_name + "_Error"] = [] for index in range(1, num_states): - data_dict[estimator_name].append( - delta_f_.iloc[index-1, index]) - data_dict[estimator_name + '_Error'].append( - d_delta_f_.iloc[index - 1, index]) + data_dict[estimator_name].append(delta_f_.iloc[index - 1, index]) + data_dict[estimator_name + "_Error"].append( + d_delta_f_.iloc[index - 1, index] + ) - self.logger.info(f'Generate the staged result from estimator {estimator_name}') + self.logger.info( + f"Generate the staged result from estimator {estimator_name}" + ) for index, stage in enumerate(stages): if len(stages) == 1: start = 0 end = len(estimator.states_) - 1 else: # Get the start and the end of the state - lambda_min = min( - [state[index] for state in estimator.states_]) - lambda_max = max( - [state[index] for state in estimator.states_]) + lambda_min = min([state[index] for state in estimator.states_]) + lambda_max = max([state[index] for state in estimator.states_]) if lambda_min == lambda_max: # Deal with the case where a certain lambda is used but # not perturbed @@ -529,35 +558,39 @@ def generate_result(self): start = list(reversed(states)).index(lambda_min) start = num_states - start - 1 end = states.index(lambda_max) - self.logger.info( - f'Stage {stage} is from state {start} to state {end}.') + self.logger.info(f"Stage {stage} is from state {start} to state {end}.") # This assumes that the indexes are sorted as the # preprocessing should sort the index of the df. result = delta_f_.iloc[start, end] - if estimator_name != 'BAR': + if estimator_name != "BAR": error = d_delta_f_.iloc[start, end] else: - error = np.sqrt(sum( - [d_delta_f_.iloc[start, start+1]**2 - for i in range(start, end + 1)])) + error = np.sqrt( + sum( + [ + d_delta_f_.iloc[start, start + 1] ** 2 + for i in range(start, end + 1) + ] + ) + ) data_dict[estimator_name].append(result) - data_dict[estimator_name + '_Error'].append(error) + data_dict[estimator_name + "_Error"].append(error) # Total result # This assumes that the indexes are sorted as the # preprocessing should sort the index of the df. result = delta_f_.iloc[0, -1] - if estimator_name != 'BAR': + if estimator_name != "BAR": error = d_delta_f_.iloc[0, -1] else: - error = np.sqrt(sum( - [d_delta_f_.iloc[i, i + 1] ** 2 - for i in range(num_states - 1)])) + error = np.sqrt( + sum([d_delta_f_.iloc[i, i + 1] ** 2 for i in range(num_states - 1)]) + ) data_dict[estimator_name].append(result) - data_dict[estimator_name + '_Error'].append(error) + data_dict[estimator_name + "_Error"].append(error) summary = pd.DataFrame.from_dict(data_dict) - summary = summary.set_index(['state', 'name']) + summary = summary.set_index(["state", "name"]) # Make sure that the columns are in the right order summary = summary[col_names] # Remove the name of the index column to make it prettier @@ -567,11 +600,11 @@ def generate_result(self): converter = get_unit_converter(self.units) summary = converter(summary) self.summary = summary - self.logger.info(f'Write results:\n{summary.to_string()}') + self.logger.info(f"Write results:\n{summary.to_string()}") return summary - def plot_overlap_matrix(self, overlap='O_MBAR.pdf', ax=None): - '''Plot the overlap matrix for MBAR estimator using + def plot_overlap_matrix(self, overlap="O_MBAR.pdf", ax=None): + """Plot the overlap matrix for MBAR estimator using :func:`~alchemlyb.visualisation.plot_mbar_overlap_matrix`. Parameters @@ -586,21 +619,20 @@ def plot_overlap_matrix(self, overlap='O_MBAR.pdf', ax=None): ------- matplotlib.axes.Axes An axes with the overlap matrix drawn. - ''' - self.logger.info('Plot overlap matrix.') - if 'MBAR' in self.estimator: - ax = plot_mbar_overlap_matrix(self.estimator['MBAR'].overlap_matrix, - ax=ax) + """ + self.logger.info("Plot overlap matrix.") + if "MBAR" in self.estimator: + ax = plot_mbar_overlap_matrix(self.estimator["MBAR"].overlap_matrix, ax=ax) ax.figure.savefig(join(self.out, overlap)) - self.logger.info(f'Plot overlap matrix to {self.out} under {overlap}.') + self.logger.info(f"Plot overlap matrix to {self.out} under {overlap}.") return ax else: - self.logger.warning('MBAR estimator not found. ' - 'Overlap matrix not plotted.') + self.logger.warning( + "MBAR estimator not found. " "Overlap matrix not plotted." + ) - def plot_ti_dhdl(self, dhdl_TI='dhdl_TI.pdf', labels=None, colors=None, - ax=None): - '''Plot the dHdl for TI estimator using + def plot_ti_dhdl(self, dhdl_TI="dhdl_TI.pdf", labels=None, colors=None, ax=None): + """Plot the dHdl for TI estimator using :func:`~alchemlyb.visualisation.plot_ti_dhdl`. Parameters @@ -620,20 +652,31 @@ def plot_ti_dhdl(self, dhdl_TI='dhdl_TI.pdf', labels=None, colors=None, ------- matplotlib.axes.Axes An axes with the TI dhdl drawn. - ''' - self.logger.info('Plot TI dHdl.') - if 'TI' in self.estimator: - ax = plot_ti_dhdl(self.estimator['TI'], units=self.units, - labels=labels, colors=colors, ax=ax) + """ + self.logger.info("Plot TI dHdl.") + if "TI" in self.estimator: + ax = plot_ti_dhdl( + self.estimator["TI"], + units=self.units, + labels=labels, + colors=colors, + ax=ax, + ) ax.figure.savefig(join(self.out, dhdl_TI)) - self.logger.info(f'Plot TI dHdl to {dhdl_TI} under {self.out}.') + self.logger.info(f"Plot TI dHdl to {dhdl_TI} under {self.out}.") return ax else: - raise ValueError('No TI data available in estimators.') - - def plot_dF_state(self, dF_state='dF_state.pdf', labels=None, colors=None, - orientation='portrait', nb=10): - '''Plot the dF states using + raise ValueError("No TI data available in estimators.") + + def plot_dF_state( + self, + dF_state="dF_state.pdf", + labels=None, + colors=None, + orientation="portrait", + nb=10, + ): + """Plot the dF states using :func:`~alchemlyb.visualisation.plot_dF_state`. Parameters @@ -653,18 +696,24 @@ def plot_dF_state(self, dF_state='dF_state.pdf', labels=None, colors=None, ------- matplotlib.figure.Figure An Figure with the dF states drawn. - ''' - self.logger.info('Plot dF states.') - fig = plot_dF_state(self.estimator.values(), labels=labels, colors=colors, - units=self.units, - orientation=orientation, nb=nb) + """ + self.logger.info("Plot dF states.") + fig = plot_dF_state( + self.estimator.values(), + labels=labels, + colors=colors, + units=self.units, + orientation=orientation, + nb=nb, + ) fig.savefig(join(self.out, dF_state)) - self.logger.info(f'Plot dF state to {dF_state} under {self.out}.') + self.logger.info(f"Plot dF state to {dF_state} under {self.out}.") return fig - def check_convergence(self, forwrev, estimator='MBAR', dF_t='dF_t.pdf', - ax=None, **kwargs): - '''Compute the forward and backward convergence using + def check_convergence( + self, forwrev, estimator="MBAR", dF_t="dF_t.pdf", ax=None, **kwargs + ): + """Compute the forward and backward convergence using :func:`~alchemlyb.convergence.forward_backward_convergence`and plot with :func:`~alchemlyb.visualisation.plot_convergence`. @@ -701,59 +750,63 @@ def check_convergence(self, forwrev, estimator='MBAR', dF_t='dF_t.pdf', of :class:`~alchemlyb.estimators.MBAR`. (:code:`check_convergence(10, estimator='MBAR', method='adaptive')`) - ''' - self.logger.info('Start convergence analysis.') - self.logger.info('Checking data availability.') + """ + self.logger.info("Start convergence analysis.") + self.logger.info("Checking data availability.") if estimator in FEP_ESTIMATORS: if self.u_nk_sample_list is not None: u_nk_list = self.u_nk_sample_list - self.logger.info('Subsampled u_nk is available.') + self.logger.info("Subsampled u_nk is available.") else: if self.u_nk_list is not None: u_nk_list = self.u_nk_list - self.logger.info('Subsampled u_nk not available, ' - 'use original data instead.') + self.logger.info( + "Subsampled u_nk not available, " "use original data instead." + ) else: - msg = f"u_nk is needed for the f{estimator} estimator. " \ - f"If the dataset only has dHdl, " \ - f"run ABFE.check_convergence(estimator='TI') to " \ - f"use a TI estimator." + msg = ( + f"u_nk is needed for the f{estimator} estimator. " + f"If the dataset only has dHdl, " + f"run ABFE.check_convergence(estimator='TI') to " + f"use a TI estimator." + ) self.logger.error(msg) raise ValueError(msg) - convergence = forward_backward_convergence(u_nk_list, - estimator=estimator, - num=forwrev, **kwargs) + convergence = forward_backward_convergence( + u_nk_list, estimator=estimator, num=forwrev, **kwargs + ) elif estimator in TI_ESTIMATORS: - self.logger.warning('No valid FEP estimator or dataset found. ' - 'Fallback to TI.') + self.logger.warning( + "No valid FEP estimator or dataset found. " "Fallback to TI." + ) if self.dHdl_sample_list is not None: dHdl_list = self.dHdl_sample_list - self.logger.info('Subsampled dHdl is available.') + self.logger.info("Subsampled dHdl is available.") else: if self.dHdl_list is not None: dHdl_list = self.dHdl_list - self.logger.info('Subsampled dHdl not available, ' - 'use original data instead.') + self.logger.info( + "Subsampled dHdl not available, " "use original data instead." + ) else: - self.logger.error( - f'dHdl is needed for the f{estimator} estimator.') - raise ValueError( - f'dHdl is needed for the f{estimator} estimator.') - convergence = forward_backward_convergence(dHdl_list, - estimator=estimator, - num=forwrev, **kwargs) + self.logger.error(f"dHdl is needed for the f{estimator} estimator.") + raise ValueError(f"dHdl is needed for the f{estimator} estimator.") + convergence = forward_backward_convergence( + dHdl_list, estimator=estimator, num=forwrev, **kwargs + ) else: - msg = f"Estimator {estimator} is not supported. Choose one from " \ - f"{FEP_ESTIMATORS+TI_ESTIMATORS}." + msg = ( + f"Estimator {estimator} is not supported. Choose one from " + f"{FEP_ESTIMATORS + TI_ESTIMATORS}." + ) self.logger.error(msg) raise ValueError(msg) self.convergence = get_unit_converter(self.units)(convergence) - self.logger.info(f'Plot convergence analysis to {dF_t} under {self.out}.') + self.logger.info(f"Plot convergence analysis to {dF_t} under {self.out}.") - ax = plot_convergence(self.convergence, - units=self.units, ax=ax) + ax = plot_convergence(self.convergence, units=self.units, ax=ax) ax.figure.savefig(join(self.out, dF_t)) return ax diff --git a/src/alchemlyb/workflows/base.py b/src/alchemlyb/workflows/base.py index 1728f7f0..1b4cbe41 100644 --- a/src/alchemlyb/workflows/base.py +++ b/src/alchemlyb/workflows/base.py @@ -2,7 +2,8 @@ import pandas as pd -class WorkflowBase(): + +class WorkflowBase: """The base class for the Workflow. This is the base class for the creation of new Workflow. The @@ -37,9 +38,10 @@ class WorkflowBase(): .. versionadded:: 0.7.0 """ - def __init__(self, units='kT', software='Gromacs', T=298, out='./', *args, - **kwargs): + def __init__( + self, units="kT", software="Gromacs", T=298, out="./", *args, **kwargs + ): self.T = T self.software = software self.unit = units @@ -47,7 +49,7 @@ def __init__(self, units='kT', software='Gromacs', T=298, out='./', *args, self.out = out def run(self, *args, **kwargs): - """ Run the workflow in an automatic fashion. + """Run the workflow in an automatic fashion. This method would execute the :func:`~alchemlyb.workflows.WorkflowBase.read`, @@ -88,7 +90,7 @@ def run(self, *args, **kwargs): self.plot(*args, **kwargs) def read(self, *args, **kwargs): - """ The function that reads the files in `file_list` and parse them + """The function that reads the files in `file_list` and parse them into u_nk and dHdl files. Attributes @@ -104,7 +106,7 @@ def read(self, *args, **kwargs): self.dHdl_list = [] def preprocess(self, *args, **kwargs): - """ The function that subsample the u_nk and dHdl in `u_nk_list` and + """The function that subsample the u_nk and dHdl in `u_nk_list` and `dHdl_list`. Attributes @@ -120,7 +122,7 @@ def preprocess(self, *args, **kwargs): self.u_nk_sample_list = [] def estimate(self, *args, **kwargs): - """ The function that runs the estimator based on `u_nk_sample_list` + """The function that runs the estimator based on `u_nk_sample_list` and `dHdl_sample_list`. Attributes @@ -133,7 +135,7 @@ def estimate(self, *args, **kwargs): self.result = pd.DataFrame() def check_convergence(self, *args, **kwargs): - """ The function for doing convergence analysis. + """The function for doing convergence analysis. Attributes ---------- @@ -145,7 +147,5 @@ def check_convergence(self, *args, **kwargs): self.convergence = pd.DataFrame() def plot(self, *args, **kwargs): - """ The function for producing any plots. - - """ + """The function for producing any plots.""" pass From 654e99930f6406aa45cc91bd492797124ed43d08 Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Tue, 6 Dec 2022 10:18:11 +0000 Subject: [PATCH 18/21] ci --- .github/workflows/ci.yaml | 2 +- devtools/conda-envs/test_env.yaml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index caa14bdb..0c7abcd7 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -63,7 +63,7 @@ jobs: - name: Run tests run: | - pytest -v -n 2 --cov=alchemlyb --cov-report=xml --color=yes src/alchemlyb/tests + pytest -v -n 2 --black --cov=alchemlyb --cov-report=xml --color=yes src/alchemlyb/tests env: MPLBACKEND: agg diff --git a/devtools/conda-envs/test_env.yaml b/devtools/conda-envs/test_env.yaml index 009c3c3a..608b8843 100644 --- a/devtools/conda-envs/test_env.yaml +++ b/devtools/conda-envs/test_env.yaml @@ -14,4 +14,5 @@ dependencies: - pytest - pytest-cov - pytest-xdist +- pytest-black - codecov From f07fb18b09bb4e11af3ac381fb8b304a8225c550 Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Tue, 6 Dec 2022 10:27:08 +0000 Subject: [PATCH 19/21] change log --- CHANGES | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGES b/CHANGES index c8f0d233..c1f1e1f0 100644 --- a/CHANGES +++ b/CHANGES @@ -13,10 +13,13 @@ The rules for this file: * release numbers follow "Semantic Versioning" https://semver.org ------------------------------------------------------------------------------ -??/??/2022 DrDomenicoMarson +??/??/2022 DrDomenicoMarson, xiki-tempula * 1.0.1 +Enhancements + - Blackfy the codebase (PR #280). + Fixes - Remove most of the iloc in the tests (issue #202, PR #254). - AMBER parser now raises ValueError when the initial simulation time From e950af244efc7e5cb4d8df63040a18a7f53f698d Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Tue, 6 Dec 2022 10:42:32 +0000 Subject: [PATCH 20/21] black --- src/alchemlyb/convergence/convergence.py | 125 +++--- src/alchemlyb/estimators/__init__.py | 2 +- src/alchemlyb/estimators/bar_.py | 52 ++- src/alchemlyb/estimators/base.py | 9 +- src/alchemlyb/estimators/mbar_.py | 106 +++-- src/alchemlyb/estimators/ti_.py | 47 +- src/alchemlyb/parsing/__init__.py | 21 +- src/alchemlyb/parsing/amber.py | 186 ++++---- src/alchemlyb/parsing/gmx.py | 156 ++++--- src/alchemlyb/parsing/gomc.py | 68 +-- src/alchemlyb/parsing/namd.py | 170 ++++--- src/alchemlyb/parsing/util.py | 28 +- src/alchemlyb/postprocessors/__init__.py | 2 +- src/alchemlyb/postprocessors/units.py | 53 +-- src/alchemlyb/preprocessing/__init__.py | 24 +- src/alchemlyb/preprocessing/subsampling.py | 173 ++++--- src/alchemlyb/tests/conftest.py | 1 - src/alchemlyb/tests/parsing/test_amber.py | 100 +++-- src/alchemlyb/tests/parsing/test_gmx.py | 252 ++++++----- src/alchemlyb/tests/parsing/test_gomc.py | 35 +- src/alchemlyb/tests/parsing/test_namd.py | 247 ++++++---- src/alchemlyb/tests/parsing/test_util.py | 85 ++-- src/alchemlyb/tests/test_convergence.py | 3 +- src/alchemlyb/tests/test_fep_estimators.py | 1 - src/alchemlyb/tests/test_import.py | 3 +- src/alchemlyb/tests/test_preprocessing.py | 5 +- src/alchemlyb/tests/test_version.py | 4 +- src/alchemlyb/tests/test_workflow.py | 17 +- src/alchemlyb/tests/test_workflow_ABFE.py | 456 +++++++++++-------- src/alchemlyb/visualisation/__init__.py | 4 +- src/alchemlyb/visualisation/convergence.py | 201 +++++---- src/alchemlyb/visualisation/dF_state.py | 171 ++++--- src/alchemlyb/visualisation/mbar_matrix.py | 99 +++-- src/alchemlyb/visualisation/ti_dhdl.py | 103 +++-- src/alchemlyb/workflows/__init__.py | 3 +- src/alchemlyb/workflows/abfe.py | 495 ++++++++++++--------- src/alchemlyb/workflows/base.py | 22 +- 37 files changed, 2053 insertions(+), 1476 deletions(-) diff --git a/src/alchemlyb/convergence/convergence.py b/src/alchemlyb/convergence/convergence.py index 372bd176..654fea2d 100644 --- a/src/alchemlyb/convergence/convergence.py +++ b/src/alchemlyb/convergence/convergence.py @@ -3,17 +3,16 @@ import logging from warnings import warn -import pandas as pd import numpy as np +import pandas as pd -from ..estimators import BAR, TI, FEP_ESTIMATORS, TI_ESTIMATORS -from ..estimators import AutoMBAR as MBAR from .. import concat +from ..estimators import FEP_ESTIMATORS, TI_ESTIMATORS from ..postprocessors.units import to_kT -def forward_backward_convergence(df_list, estimator='MBAR', num=10, **kwargs): - '''Forward and backward convergence of the free energy estimate. +def forward_backward_convergence(df_list, estimator="MBAR", num=10, **kwargs): + """Forward and backward convergence of the free energy estimate. Generate the free energy estimate as a function of time in both directions, with the specified number of equally spaced points in the time @@ -69,16 +68,17 @@ def forward_backward_convergence(df_list, estimator='MBAR', num=10, **kwargs): The default for using ``estimator='MBAR'`` was changed from :class:`~alchemlyb.estimators.MBAR` to :class:`~alchemlyb.estimators.AutoMBAR`. - ''' - logger = logging.getLogger('alchemlyb.convergence.' - 'forward_backward_convergence') - logger.info('Start convergence analysis.') - logger.info('Check data availability.') + """ + logger = logging.getLogger("alchemlyb.convergence." "forward_backward_convergence") + logger.info("Start convergence analysis.") + logger.info("Check data availability.") if estimator.upper() != estimator: - warn("Using lower-case strings for the 'estimator' kwarg in " - "convergence.forward_backward_convergence() is deprecated in " - "1.0.0 and only upper case will be accepted in 2.0.0", - DeprecationWarning) + warn( + "Using lower-case strings for the 'estimator' kwarg in " + "convergence.forward_backward_convergence() is deprecated in " + "1.0.0 and only upper case will be accepted in 2.0.0", + DeprecationWarning, + ) estimator = estimator.upper() if estimator not in (FEP_ESTIMATORS + TI_ESTIMATORS): @@ -88,61 +88,77 @@ def forward_backward_convergence(df_list, estimator='MBAR', num=10, **kwargs): else: # select estimator class by name estimator_fit = globals()[estimator](**kwargs).fit - logger.info(f'Use {estimator} estimator for convergence analysis.') + logger.info(f"Use {estimator} estimator for convergence analysis.") - logger.info('Begin forward analysis') + logger.info("Begin forward analysis") forward_list = [] forward_error_list = [] for i in range(1, num + 1): - logger.info('Forward analysis: {:.2f}%'.format(100 * i / num)) + logger.info("Forward analysis: {:.2f}%".format(100 * i / num)) sample = [] for data in df_list: - sample.append(data[:len(data) // num * i]) + sample.append(data[: len(data) // num * i]) sample = concat(sample) result = estimator_fit(sample) forward_list.append(result.delta_f_.iloc[0, -1]) - if estimator.lower() == 'bar': - error = np.sqrt(sum( - [result.d_delta_f_.iloc[i, i + 1] ** 2 - for i in range(len(result.d_delta_f_) - 1)])) + if estimator.lower() == "bar": + error = np.sqrt( + sum( + [ + result.d_delta_f_.iloc[i, i + 1] ** 2 + for i in range(len(result.d_delta_f_) - 1) + ] + ) + ) forward_error_list.append(error) else: forward_error_list.append(result.d_delta_f_.iloc[0, -1]) - logger.info('{:.2f} +/- {:.2f} kT'.format(forward_list[-1], - forward_error_list[-1])) + logger.info( + "{:.2f} +/- {:.2f} kT".format(forward_list[-1], forward_error_list[-1]) + ) - logger.info('Begin backward analysis') + logger.info("Begin backward analysis") backward_list = [] backward_error_list = [] for i in range(1, num + 1): - logger.info('Backward analysis: {:.2f}%'.format(100 * i / num)) + logger.info("Backward analysis: {:.2f}%".format(100 * i / num)) sample = [] for data in df_list: - sample.append(data[-len(data) // num * i:]) + sample.append(data[-len(data) // num * i :]) sample = concat(sample) result = estimator_fit(sample) backward_list.append(result.delta_f_.iloc[0, -1]) - if estimator.lower() == 'bar': - error = np.sqrt(sum( - [result.d_delta_f_.iloc[i, i + 1] ** 2 - for i in range(len(result.d_delta_f_) - 1)])) + if estimator.lower() == "bar": + error = np.sqrt( + sum( + [ + result.d_delta_f_.iloc[i, i + 1] ** 2 + for i in range(len(result.d_delta_f_) - 1) + ] + ) + ) backward_error_list.append(error) else: backward_error_list.append(result.d_delta_f_.iloc[0, -1]) - logger.info('{:.2f} +/- {:.2f} kT'.format(backward_list[-1], - backward_error_list[-1])) + logger.info( + "{:.2f} +/- {:.2f} kT".format(backward_list[-1], backward_error_list[-1]) + ) convergence = pd.DataFrame( - {'Forward': forward_list, - 'Forward_Error': forward_error_list, - 'Backward': backward_list, - 'Backward_Error': backward_error_list, - 'data_fraction': [i / num for i in range(1, num + 1)]}) + { + "Forward": forward_list, + "Forward_Error": forward_error_list, + "Backward": backward_list, + "Backward_Error": backward_error_list, + "data_fraction": [i / num for i in range(1, num + 1)], + } + ) convergence.attrs = df_list[0].attrs return convergence + def _cummean(vals, out_length): - '''The cumulative mean of an array. + """The cumulative mean of an array. This function computes the cumulative mean and shapes the result to the desired length. @@ -167,18 +183,19 @@ def _cummean(vals, out_length): .. versionadded:: 1.0.0 - ''' + """ in_length = len(vals) if in_length < out_length: out_length = in_length block = in_length // out_length - reshape = vals[: block*out_length].reshape(block, out_length) + reshape = vals[: block * out_length].reshape(block, out_length) mean = np.mean(reshape, axis=0) - result = np.cumsum(mean) / np.arange(1, out_length+1) + result = np.cumsum(mean) / np.arange(1, out_length + 1) return result + def fwdrev_cumavg_Rc(series, precision=0.01, tol=2): - '''Generate the convergence criteria :math:`R_c` for a single simulation. + """Generate the convergence criteria :math:`R_c` for a single simulation. The input will be :class:`pandas.Series` generated by :func:`~alchemlyb.preprocessing.subsampling.decorrelate_u_nk` or @@ -241,7 +258,7 @@ def fwdrev_cumavg_Rc(series, precision=0.01, tol=2): .. _`equation 16`: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8397498/#FD16 - ''' + """ series = to_kT(series) array = series.to_numpy() out_length = int(1 / precision) @@ -250,9 +267,12 @@ def fwdrev_cumavg_Rc(series, precision=0.01, tol=2): length = len(g_forward) convergence = pd.DataFrame( - {'Forward': g_forward, - 'Backward': g_backward, - 'data_fraction': [i / length for i in range(1, length + 1)]}) + { + "Forward": g_forward, + "Backward": g_backward, + "data_fraction": [i / length for i in range(1, length + 1)], + } + ) convergence.attrs = series.attrs # Final value @@ -270,8 +290,9 @@ def fwdrev_cumavg_Rc(series, precision=0.01, tol=2): # the same as this branch will be triggered. return 1.0, convergence + def A_c(series_list, precision=0.01, tol=2): - '''Generate the ensemble convergence criteria :math:`A_c` for a set of simulations. + """Generate the ensemble convergence criteria :math:`A_c` for a set of simulations. The input is a :class:`list` of :class:`pandas.Series` generated by :func:`~alchemlyb.preprocessing.subsampling.decorrelate_u_nk` or @@ -317,11 +338,11 @@ def A_c(series_list, precision=0.01, tol=2): .. _`equation 18`: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8397498/#FD18 - ''' - logger = logging.getLogger('alchemlyb.convergence.A_c') + """ + logger = logging.getLogger("alchemlyb.convergence.A_c") n_R_c = len(series_list) R_c_list = [fwdrev_cumavg_Rc(series, precision, tol)[0] for series in series_list] - logger.info(f'R_c list: {R_c_list}') + logger.info(f"R_c list: {R_c_list}") # Integrate the R_c_list <= R_c over the range of 0 to 1 array_01 = np.hstack((R_c_list, [0, 1])) sorted_array = np.sort(np.unique(array_01)) @@ -330,6 +351,6 @@ def A_c(series_list, precision=0.01, tol=2): if i == 0: continue else: - d_R_c = sorted_array[-i] - sorted_array[-i-1] + d_R_c = sorted_array[-i] - sorted_array[-i - 1] result += d_R_c * sum(R_c_list <= element) / n_R_c return result diff --git a/src/alchemlyb/estimators/__init__.py b/src/alchemlyb/estimators/__init__.py index ca48015b..4b4e7771 100644 --- a/src/alchemlyb/estimators/__init__.py +++ b/src/alchemlyb/estimators/__init__.py @@ -1,5 +1,5 @@ -from .mbar_ import MBAR, AutoMBAR from .bar_ import BAR +from .mbar_ import MBAR, AutoMBAR from .ti_ import TI FEP_ESTIMATORS = [MBAR.__name__, AutoMBAR.__name__, BAR.__name__] diff --git a/src/alchemlyb/estimators/bar_.py b/src/alchemlyb/estimators/bar_.py index 3a7150b2..7bf39bc7 100644 --- a/src/alchemlyb/estimators/bar_.py +++ b/src/alchemlyb/estimators/bar_.py @@ -1,11 +1,11 @@ import numpy as np import pandas as pd - -from sklearn.base import BaseEstimator from pymbar import BAR as BAR_ +from sklearn.base import BaseEstimator from .base import _EstimatorMixOut + class BAR(BaseEstimator, _EstimatorMixOut): """Bennett acceptance ratio (BAR). @@ -57,7 +57,13 @@ class BAR(BaseEstimator, _EstimatorMixOut): """ - def __init__(self, maximum_iterations=10000, relative_tolerance=1.0e-7, method='false-position', verbose=False): + def __init__( + self, + maximum_iterations=10000, + relative_tolerance=1.0e-7, + method="false-position", + verbose=False, + ): self.maximum_iterations = maximum_iterations self.relative_tolerance = relative_tolerance @@ -87,7 +93,10 @@ def fit(self, u_nk): # group u_nk by lambda states groups = u_nk.groupby(level=u_nk.index.names[1:]) - N_k = [(len(groups.get_group(i)) if i in groups.groups else 0) for i in u_nk.columns] + N_k = [ + (len(groups.get_group(i)) if i in groups.groups else 0) + for i in u_nk.columns + ] # Now get free energy differences and their uncertainties for each step deltas = np.array([]) @@ -96,19 +105,22 @@ def fit(self, u_nk): # get us from lambda step k uk = groups.get_group(self._states_[k]) # get w_F - w_f = uk.iloc[:, k+1] - uk.iloc[:, k] + w_f = uk.iloc[:, k + 1] - uk.iloc[:, k] # get us from lambda step k+1 - uk1 = groups.get_group(self._states_[k+1]) + uk1 = groups.get_group(self._states_[k + 1]) # get w_R - w_r = uk1.iloc[:, k] - uk1.iloc[:, k+1] + w_r = uk1.iloc[:, k] - uk1.iloc[:, k + 1] # now determine df and ddf using pymbar.BAR - df, ddf = BAR_(w_f, w_r, - method=self.method, - maximum_iterations=self.maximum_iterations, - relative_tolerance=self.relative_tolerance, - verbose=self.verbose) + df, ddf = BAR_( + w_f, + w_r, + method=self.method, + maximum_iterations=self.maximum_iterations, + relative_tolerance=self.relative_tolerance, + verbose=self.verbose, + ) deltas = np.append(deltas, df) d_deltas = np.append(d_deltas, ddf**2) @@ -121,14 +133,14 @@ def fit(self, u_nk): out = [] dout = [] for i in range(len(deltas) - j): - out.append(deltas[i:i + j + 1].sum()) + out.append(deltas[i : i + j + 1].sum()) # See https://github.com/alchemistry/alchemlyb/pull/60#issuecomment-430720742 # Error estimate generated by BAR ARE correlated # Use the BAR uncertainties between two neighbour states if j == 0: - dout.append(d_deltas[i:i + j + 1].sum()) + dout.append(d_deltas[i : i + j + 1].sum()) # Other uncertainties are unknown at this point else: dout.append(np.nan) @@ -137,14 +149,14 @@ def fit(self, u_nk): ad_delta += np.diagflat(np.array(dout), k=j + 1) # yield standard delta_f_ free energies between each state - self._delta_f_ = pd.DataFrame(adelta - adelta.T, - columns=self._states_, - index=self._states_) + self._delta_f_ = pd.DataFrame( + adelta - adelta.T, columns=self._states_, index=self._states_ + ) # yield standard deviation d_delta_f_ between each state - self._d_delta_f_ = pd.DataFrame(np.sqrt(ad_delta + ad_delta.T), - columns=self._states_, - index=self._states_) + self._d_delta_f_ = pd.DataFrame( + np.sqrt(ad_delta + ad_delta.T), columns=self._states_, index=self._states_ + ) self._delta_f_.attrs = u_nk.attrs self._d_delta_f_.attrs = u_nk.attrs diff --git a/src/alchemlyb/estimators/base.py b/src/alchemlyb/estimators/base.py index e6b1b8be..93f3da8a 100644 --- a/src/alchemlyb/estimators/base.py +++ b/src/alchemlyb/estimators/base.py @@ -1,9 +1,11 @@ -class _EstimatorMixOut(): - '''This class creates view for the d_delta_f_, delta_f_, states_ for the - estimator class to consume.''' +class _EstimatorMixOut: + """This class creates view for the d_delta_f_, delta_f_, states_ for the + estimator class to consume.""" + _d_delta_f_ = None _delta_f_ = None _states_ = None + @property def d_delta_f_(self): return self._d_delta_f_ @@ -15,4 +17,3 @@ def delta_f_(self): @property def states_(self): return self._states_ - \ No newline at end of file diff --git a/src/alchemlyb/estimators/mbar_.py b/src/alchemlyb/estimators/mbar_.py index d34434e9..4759d687 100644 --- a/src/alchemlyb/estimators/mbar_.py +++ b/src/alchemlyb/estimators/mbar_.py @@ -1,12 +1,12 @@ -import numpy as np -import pandas as pd import logging -from sklearn.base import BaseEstimator +import pandas as pd import pymbar +from sklearn.base import BaseEstimator from .base import _EstimatorMixOut + class MBAR(BaseEstimator, _EstimatorMixOut): """Multi-state Bennett acceptance ratio (MBAR). @@ -62,14 +62,20 @@ class MBAR(BaseEstimator, _EstimatorMixOut): `delta_f_`, `d_delta_f_`, `states_` are view of the original object. """ - def __init__(self, maximum_iterations=10000, relative_tolerance=1.0e-7, - initial_f_k=None, method='hybr', verbose=False): + def __init__( + self, + maximum_iterations=10000, + relative_tolerance=1.0e-7, + initial_f_k=None, + method="hybr", + verbose=False, + ): self.maximum_iterations = maximum_iterations self.relative_tolerance = relative_tolerance self.initial_f_k = initial_f_k self.method = method self.verbose = verbose - self.logger = logging.getLogger('alchemlyb.estimators.MBAR') + self.logger = logging.getLogger("alchemlyb.estimators.MBAR") # handle for pymbar.MBAR object self._mbar = None @@ -90,22 +96,24 @@ def fit(self, u_nk): u_nk = u_nk.sort_index(level=u_nk.index.names[1:]) groups = u_nk.groupby(level=u_nk.index.names[1:]) - N_k = [(len(groups.get_group(i)) if i in groups.groups else 0) for i in - u_nk.columns] + N_k = [ + (len(groups.get_group(i)) if i in groups.groups else 0) + for i in u_nk.columns + ] self._states_ = u_nk.columns.values.tolist() # Prepare the solver_protocol as stated in https://github.com/choderalab/pymbar/issues/419#issuecomment-803714103 - solver_options = {"maximum_iterations": self.maximum_iterations, - "verbose": self.verbose} - solver_protocol = {"method": self.method, - "options": solver_options} + solver_options = { + "maximum_iterations": self.maximum_iterations, + "verbose": self.verbose, + } + solver_protocol = {"method": self.method, "options": solver_options} self._mbar, out = self._do_MBAR(u_nk, N_k, solver_protocol) - free_energy_differences = [pd.DataFrame(i, - columns=self._states_, - index=self._states_) for i in - out] + free_energy_differences = [ + pd.DataFrame(i, columns=self._states_, index=self._states_) for i in out + ] (self._delta_f_, self._d_delta_f_, self.theta_) = free_energy_differences @@ -118,15 +126,20 @@ def predict(self, u_ln): pass def _do_MBAR(self, u_nk, N_k, solver_protocol): - mbar = pymbar.MBAR(u_nk.T, N_k, - relative_tolerance=self.relative_tolerance, - initial_f_k=self.initial_f_k, - solver_protocol=(solver_protocol,)) - self.logger.info("Solved MBAR equations with method %r and " - "maximum_iterations=%d, relative_tolerance=%g", - solver_protocol['method'], - solver_protocol['options']['maximum_iterations'], - self.relative_tolerance) + mbar = pymbar.MBAR( + u_nk.T, + N_k, + relative_tolerance=self.relative_tolerance, + initial_f_k=self.initial_f_k, + solver_protocol=(solver_protocol,), + ) + self.logger.info( + "Solved MBAR equations with method %r and " + "maximum_iterations=%d, relative_tolerance=%g", + solver_protocol["method"], + solver_protocol["options"]["maximum_iterations"], + self.relative_tolerance, + ) # set attributes out = mbar.getFreeEnergyDifferences(return_theta=True) return mbar, out @@ -145,7 +158,7 @@ def overlap_matrix(self): --------- pymbar.mbar.MBAR.computeOverlap """ - return self._mbar.computeOverlap()['matrix'] + return self._mbar.computeOverlap()["matrix"] class AutoMBAR(MBAR): @@ -188,31 +201,42 @@ class AutoMBAR(MBAR): .. versionchanged:: 1.0.0 AutoMBAR accepts the `method` argument. """ - def __init__(self, maximum_iterations=10000, relative_tolerance=1.0e-7, - initial_f_k=None, verbose=False, method=None): - super().__init__(maximum_iterations=maximum_iterations, - relative_tolerance=relative_tolerance, - initial_f_k=initial_f_k, - verbose=verbose, method=method) - self.logger = logging.getLogger('alchemlyb.estimators.AutoMBAR') + + def __init__( + self, + maximum_iterations=10000, + relative_tolerance=1.0e-7, + initial_f_k=None, + verbose=False, + method=None, + ): + super().__init__( + maximum_iterations=maximum_iterations, + relative_tolerance=relative_tolerance, + initial_f_k=initial_f_k, + verbose=verbose, + method=method, + ) + self.logger = logging.getLogger("alchemlyb.estimators.AutoMBAR") def _do_MBAR(self, u_nk, N_k, solver_protocol): if solver_protocol["method"] is None: - self.logger.info('Initialise the automatic routine of the MBAR ' - 'estimator.') + self.logger.info( + "Initialise the automatic routine of the MBAR " "estimator." + ) # Try the fastest method first try: - self.logger.info('Trying the hybr method.') - solver_protocol["method"] = 'hybr' + self.logger.info("Trying the hybr method.") + solver_protocol["method"] = "hybr" mbar, out = super()._do_MBAR(u_nk, N_k, solver_protocol) except pymbar.utils.ParameterError: try: - self.logger.info('Trying the adaptive method.') - solver_protocol["method"] = 'adaptive' + self.logger.info("Trying the adaptive method.") + solver_protocol["method"] = "adaptive" mbar, out = super()._do_MBAR(u_nk, N_k, solver_protocol) except pymbar.utils.ParameterError: - self.logger.info('Trying the BFGS method.') - solver_protocol["method"] = 'BFGS' + self.logger.info("Trying the BFGS method.") + solver_protocol["method"] = "BFGS" mbar, out = super()._do_MBAR(u_nk, N_k, solver_protocol) return mbar, out else: diff --git a/src/alchemlyb/estimators/ti_.py b/src/alchemlyb/estimators/ti_.py index e01b8f72..bef379cc 100644 --- a/src/alchemlyb/estimators/ti_.py +++ b/src/alchemlyb/estimators/ti_.py @@ -1,10 +1,10 @@ import numpy as np import pandas as pd - from sklearn.base import BaseEstimator from .base import _EstimatorMixOut + class TI(BaseEstimator, _EstimatorMixOut): """Thermodynamic integration (TI). @@ -71,43 +71,49 @@ def fit(self, dHdl): dl = means.reset_index()[means.index.names[:]].diff().iloc[1:].values # apply trapezoid rule to obtain DF between each adjacent state - deltas = (dl * (means.iloc[:-1].values + means.iloc[1:].values)/2).sum(axis=1) + deltas = (dl * (means.iloc[:-1].values + means.iloc[1:].values) / 2).sum(axis=1) # build matrix of deltas between each state - adelta = np.zeros((len(deltas)+1, len(deltas)+1)) + adelta = np.zeros((len(deltas) + 1, len(deltas) + 1)) ad_delta = np.zeros_like(adelta) for j in range(len(deltas)): out = [] dout = [] for i in range(len(deltas) - j): - out.append(deltas[i] + deltas[i+1:i+j+1].sum()) + out.append(deltas[i] + deltas[i + 1 : i + j + 1].sum()) # Define additional zero lambda a = [0.0] * len(l_types) # Define dl series' with additional zero lambda on the left and right - dll = np.insert(dl[i:i + j + 1], 0, [a], axis=0) - dlr = np.append(dl[i:i + j + 1], [a], axis=0) + dll = np.insert(dl[i : i + j + 1], 0, [a], axis=0) + dlr = np.append(dl[i : i + j + 1], [a], axis=0) # Get a series of the form: x1, x1 + x2, ..., x(n-1) + x(n), x(n) dllr = dll + dlr # Append deviation of free energy difference between state i and i+j+1 - dout.append((dllr ** 2 * variances.iloc[i:i + j + 2].values / 4).sum(axis=1).sum()) - adelta += np.diagflat(np.array(out), k=j+1) - ad_delta += np.diagflat(np.array(dout), k=j+1) + dout.append( + (dllr**2 * variances.iloc[i : i + j + 2].values / 4) + .sum(axis=1) + .sum() + ) + adelta += np.diagflat(np.array(out), k=j + 1) + ad_delta += np.diagflat(np.array(dout), k=j + 1) # yield standard delta_f_ free energies between each state - self._delta_f_ = pd.DataFrame(adelta - adelta.T, - columns=means.index.values, - index=means.index.values) + self._delta_f_ = pd.DataFrame( + adelta - adelta.T, columns=means.index.values, index=means.index.values + ) self.dhdl = means # yield standard deviation d_delta_f_ between each state - self._d_delta_f_ = pd.DataFrame(np.sqrt(ad_delta + ad_delta.T), - columns=variances.index.values, - index=variances.index.values) + self._d_delta_f_ = pd.DataFrame( + np.sqrt(ad_delta + ad_delta.T), + columns=variances.index.values, + index=variances.index.values, + ) self._states_ = means.index.values.tolist() @@ -135,7 +141,9 @@ def separate_dhdl(self): """ if len(self.dhdl.index.names) == 1: name = self.dhdl.columns[0] - return [self.dhdl[name], ] + return [ + self.dhdl[name], + ] dhdl_list = [] # get the lambda names l_types = self.dhdl.index.names @@ -143,14 +151,14 @@ def separate_dhdl(self): # Fix issue #148, where for pandas == 1.3.0 # lambdas = self.dhdl.reset_index()[list(l_types)] lambdas = self.dhdl.reset_index()[l_types] - diff = lambdas.diff().to_numpy(dtype='bool') + diff = lambdas.diff().to_numpy(dtype="bool") # diff will give the first row as NaN so need to fix that diff[0, :] = diff[1, :] # Make sure that the start point is set to true as well diff[:-1, :] = diff[:-1, :] | diff[1:, :] for i in range(len(l_types)): - if any(diff[:,i]): - new = self.dhdl.iloc[diff[:,i], i] + if any(diff[:, i]): + new = self.dhdl.iloc[diff[:, i], i] # drop all other index for l in l_types: if l != l_types[i]: @@ -158,4 +166,3 @@ def separate_dhdl(self): new.attrs = self.dhdl.attrs dhdl_list.append(new) return dhdl_list - diff --git a/src/alchemlyb/parsing/__init__.py b/src/alchemlyb/parsing/__init__.py index 60165ac9..dc048732 100644 --- a/src/alchemlyb/parsing/__init__.py +++ b/src/alchemlyb/parsing/__init__.py @@ -1,33 +1,38 @@ from functools import wraps + def _init_attrs(func): - '''Add temperature to the parsed dataframe. + """Add temperature to the parsed dataframe. The temperature is added to the dataframe as dataframe.attrs['temperature'] and the energy unit is initiated as dataframe.attrs['energy_unit'] = 'kT'. - ''' + """ + @wraps(func) def wrapper(outfile, T, *args, **kwargs): dataframe = func(outfile, T, *args, **kwargs) if dataframe is not None: - dataframe.attrs['temperature'] = T - dataframe.attrs['energy_unit'] = 'kT' + dataframe.attrs["temperature"] = T + dataframe.attrs["energy_unit"] = "kT" return dataframe + return wrapper def _init_attrs_dict(func): - '''Add temperature and energy units to the parsed dataframes. + """Add temperature and energy units to the parsed dataframes. The temperature is added to the dataframe as dataframe.attrs['temperature'] and the energy unit is initiated as dataframe.attrs['energy_unit'] = 'kT'. - ''' + """ + @wraps(func) def wrapper(outfile, T, *args, **kwargs): dict_with_df = func(outfile, T, *args, **kwargs) for k in dict_with_df.keys(): if dict_with_df[k] is not None: - dict_with_df[k].attrs['temperature'] = T - dict_with_df[k].attrs['energy_unit'] = 'kT' + dict_with_df[k].attrs["temperature"] = T + dict_with_df[k].attrs["energy_unit"] = "kT" return dict_with_df + return wrapper diff --git a/src/alchemlyb/parsing/amber.py b/src/alchemlyb/parsing/amber.py index 8e23ada1..d129a064 100644 --- a/src/alchemlyb/parsing/amber.py +++ b/src/alchemlyb/parsing/amber.py @@ -11,21 +11,21 @@ """ -import re import logging +import re -import pandas as pd import numpy as np +import pandas as pd -from .util import anyopen from . import _init_attrs_dict +from .util import anyopen from ..postprocessors.units import R_kJmol, kJ2kcal logger = logging.getLogger("alchemlyb.parsers.Amber") k_b = R_kJmol * kJ2kcal -_FP_RE = r'[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?' +_FP_RE = r"[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?" def convert_to_pandas(file_datum): @@ -39,10 +39,13 @@ def convert_to_pandas(file_datum): data_dic["lambdas"].append(file_datum.clambda) frame_time = file_datum.t0 + (frame_index + 1) * file_datum.dt * file_datum.ntpr data_dic["time"].append(frame_time) - df = pd.DataFrame(data_dic["dHdl"], columns=["dHdl"], - index=pd.Index(data_dic["time"], name='time', dtype='Float64')) + df = pd.DataFrame( + data_dic["dHdl"], + columns=["dHdl"], + index=pd.Index(data_dic["time"], name="time", dtype="Float64"), + ) df["lambdas"] = data_dic["lambdas"][0] - df = df.reset_index().set_index(['time'] + ['lambdas']) + df = df.reset_index().set_index(["time"] + ["lambdas"]) return df @@ -59,7 +62,7 @@ def _pre_gen(it, first): return -class SectionParser(): +class SectionParser: """ A simple parser to extract data values from sections. """ @@ -68,7 +71,7 @@ def __init__(self, filename): """Opens a file according to its file type.""" self.filename = filename try: - self.fileh = anyopen(self.filename, 'r') + self.fileh = anyopen(self.filename, "r") except: logger.exception("Cannot open file %s", filename) raise @@ -93,7 +96,7 @@ def skip_after(self, pattern): break return Found_pattern - def extract_section(self, start, end, fields, limit=None, extra=''): + def extract_section(self, start, end, fields, limit=None, extra=""): """ Extract data values (int, float) in fields from a section marked with start and end regexes. Do not read further than @@ -109,15 +112,15 @@ def extract_section(self, start, end, fields, limit=None, extra=''): if inside: if re.search(end, line): break - lines.append(line.rstrip('\n')) - line = ''.join(lines) + lines.append(line.rstrip("\n")) + line = "".join(lines) result = [] for field in fields: - match = re.search(fr' {field}\s*=\s*(\*+|{_FP_RE}|\d+)', line) + match = re.search(rf" {field}\s*=\s*(\*+|{_FP_RE}|\d+)", line) if match: value = match.group(1) - if '*' in value: # catch fortran format overflow - result.append(float('Inf')) + if "*" in value: # catch fortran format overflow + result.append(float("Inf")) else: try: result.append(int(value)) @@ -146,12 +149,21 @@ def __exit__(self, typ, value, traceback): self.close() -class FEData(): +class FEData: """A simple struct container to collect data from individual files.""" - __slots__ = ['clambda', 't0', 'dt', 'T', 'ntpr', 'gradients', - 'mbar_energies', - 'have_mbar', 'mbar_lambdas', 'mbar_lambda_idx'] + __slots__ = [ + "clambda", + "t0", + "dt", + "T", + "ntpr", + "gradients", + "mbar_energies", + "have_mbar", + "mbar_lambdas", + "mbar_lambda_idx", + ] def __init__(self): self.clambda = -1.0 @@ -170,7 +182,7 @@ def file_validation(outfile): """ Function that validate and parse an AMBER output file. :exc:`ValueError` are risen if inconsinstencies in the input file are found. - + Parameters ---------- outfile : str @@ -189,76 +201,81 @@ def file_validation(outfile): if not line: logger.error("The file %s does not contain any data, it's empty.", outfile) - raise ValueError(f'file {outfile} does not contain any data.') + raise ValueError(f"file {outfile} does not contain any data.") - if not secp.skip_after('^ 2. CONTROL DATA FOR THE RUN'): + if not secp.skip_after("^ 2. CONTROL DATA FOR THE RUN"): logger.error('No "CONTROL DATA" section found in file %s.', outfile) raise ValueError(f'no "CONTROL DATA" section found in file {outfile}') - ntpr, = secp.extract_section('^Nature and format of output:', '^$', - ['ntpr']) - nstlim, dt = secp.extract_section('Molecular dynamics:', '^$', - ['nstlim', 'dt']) - T, = secp.extract_section('temperature regulation:', '^$', - ['temp0']) + (ntpr,) = secp.extract_section("^Nature and format of output:", "^$", ["ntpr"]) + nstlim, dt = secp.extract_section("Molecular dynamics:", "^$", ["nstlim", "dt"]) + (T,) = secp.extract_section("temperature regulation:", "^$", ["temp0"]) if not T: logger.error('No valid "temp0" record found in file %s.', outfile) raise ValueError(f'no valid "temp0" record found in file {outfile}') - clambda, = secp.extract_section('^Free energy options:', '^$', - ['clambda'], '^---') + (clambda,) = secp.extract_section( + "^Free energy options:", "^$", ["clambda"], "^---" + ) if clambda is None: - logger.error('No free energy section found in file %s, "clambda" was None.', outfile) - raise ValueError(f'no free energy section found in file {outfile}') + logger.error( + 'No free energy section found in file %s, "clambda" was None.', outfile + ) + raise ValueError(f"no free energy section found in file {outfile}") mbar_ndata = 0 - have_mbar, mbar_ndata, mbar_states = secp.extract_section('^FEP MBAR options:', - '^$', - ['ifmbar', - 'bar_intervall', - 'mbar_states'], - '^---') + have_mbar, mbar_ndata, mbar_states = secp.extract_section( + "^FEP MBAR options:", + "^$", + ["ifmbar", "bar_intervall", "mbar_states"], + "^---", + ) if have_mbar: mbar_ndata = int(nstlim / mbar_ndata) mbar_lambdas = _process_mbar_lambdas(secp) file_datum.mbar_lambdas = mbar_lambdas - clambda_str = f'{clambda:6.4f}' + clambda_str = f"{clambda:6.4f}" if clambda_str not in mbar_lambdas: - logger.warning('WARNING: lamba %s not contained in set of ' - 'MBAR lambas: %s\nNot using MBAR.', - clambda_str, ', '.join(mbar_lambdas)) + logger.warning( + "WARNING: lamba %s not contained in set of " + "MBAR lambas: %s\nNot using MBAR.", + clambda_str, + ", ".join(mbar_lambdas), + ) have_mbar = False else: mbar_nlambda = len(mbar_lambdas) if mbar_nlambda != mbar_states: logger.error( - 'the number of lambda windows read (%s)' - 'is different from what expected (%d)', - ','.join(mbar_lambdas), mbar_states) + "the number of lambda windows read (%s)" + "is different from what expected (%d)", + ",".join(mbar_lambdas), + mbar_states, + ) raise ValueError( - f'the number of lambda windows read ({mbar_nlambda})' - f' is different from what expected ({mbar_states})') + f"the number of lambda windows read ({mbar_nlambda})" + f" is different from what expected ({mbar_states})" + ) mbar_lambda_idx = mbar_lambdas.index(clambda_str) file_datum.mbar_lambda_idx = mbar_lambda_idx for _ in range(mbar_nlambda): file_datum.mbar_energies.append([]) - if not secp.skip_after('^ 3. ATOMIC '): + if not secp.skip_after("^ 3. ATOMIC "): logger.error('No "ATOMIC" section found in the file %s.', outfile) raise ValueError(f'no "ATOMIC" section found in file {outfile}') - t0, = secp.extract_section('^ begin time', '^$', ['coords']) + (t0,) = secp.extract_section("^ begin time", "^$", ["coords"]) if t0 is None: - logger.error('No starting simulation time in file %s.', outfile) - raise ValueError(f'No starting simulation time in file {outfile}') + logger.error("No starting simulation time in file %s.", outfile) + raise ValueError(f"No starting simulation time in file {outfile}") - if not secp.skip_after('^ 4. RESULTS'): + if not secp.skip_after("^ 4. RESULTS"): logger.error('No "RESULTS" section found in the file %s.', outfile) raise ValueError(f'no "RESULTS" section found in file {outfile}') - file_datum.clambda = clambda file_datum.t0 = t0 file_datum.dt = dt @@ -293,13 +310,13 @@ def extract(outfile, T): """ - beta = 1/(k_b * T) + beta = 1 / (k_b * T) file_datum = file_validation(outfile) if not np.isclose(T, file_datum.T, atol=0.01): - msg = f'The temperature read from the input file ({file_datum.T:.2f} K)' - msg += f' is different from the temperature passed as parameter ({T:.2f} K)' + msg = f"The temperature read from the input file ({file_datum.T:.2f} K)" + msg += f" is different from the temperature passed as parameter ({T:.2f} K)" logger.error(msg) raise ValueError(msg) @@ -311,18 +328,19 @@ def extract(outfile, T): old_nstep = -1 for line in secp: if " A V E R A G E S O V E R" in line: - _ = secp.skip_after('^|=========================================') - elif line.startswith(' NSTEP'): - nstep, dvdl = secp.extract_section('^ NSTEP', '^ ---', - ['NSTEP', 'DV/DL'], - extra=line) + _ = secp.skip_after("^|=========================================") + elif line.startswith(" NSTEP"): + nstep, dvdl = secp.extract_section( + "^ NSTEP", "^ ---", ["NSTEP", "DV/DL"], extra=line + ) if nstep != old_nstep and dvdl is not None and nstep is not None: file_datum.gradients.append(dvdl) nensec += 1 old_nstep = nstep - elif line.startswith('MBAR Energy analysis') and file_datum.have_mbar: - mbar = secp.extract_section('^MBAR', '^ ---', file_datum.mbar_lambdas, - extra=line) + elif line.startswith("MBAR Energy analysis") and file_datum.have_mbar: + mbar = secp.extract_section( + "^MBAR", "^ ---", file_datum.mbar_lambdas, extra=line + ) if None in mbar: msg = "Something strange parsing the following MBAR section." @@ -335,40 +353,48 @@ def extract(outfile, T): if energy > 0.0: high_E_cnt += 1 - file_datum.mbar_energies[lmbda].append(beta * (energy - reference_energy)) - elif line == ' 5. TIMINGS\n': + file_datum.mbar_energies[lmbda].append( + beta * (energy - reference_energy) + ) + elif line == " 5. TIMINGS\n": finished = True break if high_E_cnt: - logger.warning('%i MBAR energ%s > 0.0 kcal/mol', - high_E_cnt, 'ies are' if high_E_cnt > 1 else 'y is') + logger.warning( + "%i MBAR energ%s > 0.0 kcal/mol", + high_E_cnt, + "ies are" if high_E_cnt > 1 else "y is", + ) if not finished: - logger.warning('WARNING: file %s is a prematurely terminated run', outfile) + logger.warning("WARNING: file %s is a prematurely terminated run", outfile) if file_datum.have_mbar: mbar_time = [ file_datum.t0 + (frame_index + 1) * file_datum.dt * file_datum.ntpr - for frame_index in range(len(file_datum.mbar_energies[0]))] + for frame_index in range(len(file_datum.mbar_energies[0])) + ] mbar_df = pd.DataFrame( file_datum.mbar_energies, index=np.array(file_datum.mbar_lambdas, dtype=np.float64), columns=pd.MultiIndex.from_arrays( - [mbar_time, np.repeat(file_datum.clambda, len(mbar_time))], names=['time', 'lambdas']) - ).T + [mbar_time, np.repeat(file_datum.clambda, len(mbar_time))], + names=["time", "lambdas"], + ), + ).T else: logger.info('WARNING: No MBAR energies found! "u_nk" entry will be None') mbar_df = None if not nensec: - logger.warning('WARNING: File %s does not contain any dV/dl data', outfile) + logger.warning("WARNING: File %s does not contain any dV/dl data", outfile) dHdl_df = None else: - logger.info('Read %s dV/dl data points in file %s', nensec, outfile) + logger.info("Read %s dV/dl data points in file %s", nensec, outfile) dHdl_df = convert_to_pandas(file_datum) - dHdl_df['dHdl'] *= beta + dHdl_df["dHdl"] *= beta return {"u_nk": mbar_df, "dHdl": dHdl_df} @@ -395,7 +421,7 @@ def extract_dHdl(outfile, T): """ extracted = extract(outfile, T) - return extracted['dHdl'] + return extracted["dHdl"] def extract_u_nk(outfile, T): @@ -421,7 +447,7 @@ def extract_u_nk(outfile, T): """ extracted = extract(outfile, T) - return extracted['u_nk'] + return extracted["u_nk"] def _process_mbar_lambdas(secp): @@ -441,15 +467,15 @@ def _process_mbar_lambdas(secp): mbar_lambdas = [] for line in secp: - if line.startswith(' MBAR - lambda values considered:'): + if line.startswith(" MBAR - lambda values considered:"): in_mbar = True continue if in_mbar: - if line.startswith(' Extra'): + if line.startswith(" Extra"): break - if 'total' in line: + if "total" in line: data = line.split() mbar_lambdas.extend(data[2:]) else: diff --git a/src/alchemlyb/parsing/gmx.py b/src/alchemlyb/parsing/gmx.py index 00267c66..a9f83498 100644 --- a/src/alchemlyb/parsing/gmx.py +++ b/src/alchemlyb/parsing/gmx.py @@ -1,15 +1,16 @@ """Parsers for extracting alchemical data from `Gromacs `_ output files. """ -import pandas as pd import numpy as np +import pandas as pd -from .util import anyopen from . import _init_attrs +from .util import anyopen from ..postprocessors.units import R_kJmol k_b = R_kJmol + @_init_attrs def extract_u_nk(xvg, T, filter=True): r"""Return reduced potentials `u_nk` from a Hamiltonian differences XVG file. @@ -60,9 +61,9 @@ def extract_u_nk(xvg, T, filter=True): """ h_col_match = r"\xD\f{}H \xl\f{}" - pv_col_match = 'pV' - u_col_match = ['Total Energy', 'Potential Energy'] - beta = 1/(k_b * T) + pv_col_match = "pV" + u_col_match = ["Total Energy", "Potential Energy"] + beta = 1 / (k_b * T) state, lambdas, statevec = _extract_state(xvg) @@ -82,7 +83,11 @@ def extract_u_nk(xvg, T, filter=True): pv = df[pv_cols[0]] # gromacs also gives us total/potential energy U directly; need this for reduced potential - u_cols = [col for col in df.columns if any(single_u_col_match in col for single_u_col_match in u_col_match)] + u_cols = [ + col + for col in df.columns + if any(single_u_col_match in col for single_u_col_match in u_col_match) + ] u = None if u_cols: u = df[u_cols[0]] @@ -90,7 +95,7 @@ def extract_u_nk(xvg, T, filter=True): u_k = dict() cols = list() for col in dH: - u_col = eval(col.split('to')[1]) + u_col = eval(col.split("to")[1]) # calculate reduced potential u_k = dH + pV + U u_k[u_col] = beta * dH[col].values if pv_cols: @@ -99,8 +104,9 @@ def extract_u_nk(xvg, T, filter=True): u_k[u_col] += beta * u.values cols.append(u_col) - u_k = pd.DataFrame(u_k, columns=cols, - index=pd.Index(times.values, name='time', dtype='Float64')) + u_k = pd.DataFrame( + u_k, columns=cols, index=pd.Index(times.values, name="time", dtype="Float64") + ) # create columns for each lambda, indicating state each row sampled from # if state is None run as expanded ensemble data or REX @@ -108,8 +114,8 @@ def extract_u_nk(xvg, T, filter=True): # if thermodynamic state is specified map thermodynamic # state data to lambda values, else (for REX) # define state based on the legend - if 'Thermodynamic state' in df: - ts_index = df.columns.get_loc('Thermodynamic state') + if "Thermodynamic state" in df: + ts_index = df.columns.get_loc("Thermodynamic state") thermo_state = df[df.columns[ts_index]] for i, l in enumerate(lambdas): v = [] @@ -128,13 +134,14 @@ def extract_u_nk(xvg, T, filter=True): u_k[l] = statevec # set up new multi-index - newind = ['time'] + lambdas + newind = ["time"] + lambdas u_k = u_k.reset_index().set_index(newind) - u_k.name = 'u_nk' + u_k.name = "u_nk" return u_k + @_init_attrs def extract_dHdl(xvg, T, filter=True): r"""Return gradients `dH/dl` from a Hamiltonian differences XVG file. @@ -182,7 +189,7 @@ def extract_dHdl(xvg, T, filter=True): parsed and is turned on by default. """ - beta = 1/(k_b * T) + beta = 1 / (k_b * T) headers = _get_headers(xvg) state, lambdas, statevec = _extract_state(xvg, headers) @@ -204,10 +211,13 @@ def extract_dHdl(xvg, T, filter=True): # rename columns to not include the word 'lambda', since we use this for # index below - cols = [l.split('-')[0] for l in lambdas] + cols = [l.split("-")[0] for l in lambdas] - dHdl = pd.DataFrame(dHdl.values, columns=cols, - index=pd.Index(times.values, name='time', dtype='Float64')) + dHdl = pd.DataFrame( + dHdl.values, + columns=cols, + index=pd.Index(times.values, name="time", dtype="Float64"), + ) # create columns for each lambda, indicating state each row sampled from # if state is None run as expanded ensemble data or REX @@ -215,8 +225,8 @@ def extract_dHdl(xvg, T, filter=True): # if thermodynamic state is specified map thermodynamic # state data to lambda values, else (for REX) # define state based on the legend - if 'Thermodynamic state' in df: - ts_index = df.columns.get_loc('Thermodynamic state') + if "Thermodynamic state" in df: + ts_index = df.columns.get_loc("Thermodynamic state") thermo_state = df[df.columns[ts_index]] for i, l in enumerate(lambdas): v = [] @@ -235,10 +245,10 @@ def extract_dHdl(xvg, T, filter=True): dHdl[l] = statevec # set up new multi-index - newind = ['time'] + lambdas - dHdl= dHdl.reset_index().set_index(newind) + newind = ["time"] + lambdas + dHdl = dHdl.reset_index().set_index(newind) - dHdl.name='dH/dl' + dHdl.name = "dH/dl" return dHdl @@ -289,34 +299,44 @@ def _extract_state(xvg, headers=None): state = None if headers is None: headers = _get_headers(xvg) - subtitle = _get_value_by_key(headers, 'subtitle') - if subtitle and 'state' in subtitle: - state = int(subtitle.split('state')[1].split(':')[0]) - lambdas = [word.strip(')(,') for word in subtitle.split() if 'lambda' in word] - statevec = eval(subtitle.strip().split(' = ')[-1].strip('"')) + subtitle = _get_value_by_key(headers, "subtitle") + if subtitle and "state" in subtitle: + state = int(subtitle.split("state")[1].split(":")[0]) + lambdas = [word.strip(")(,") for word in subtitle.split() if "lambda" in word] + statevec = eval(subtitle.strip().split(" = ")[-1].strip('"')) # if expanded ensemble data is used the state variable will never be assigned # parsing expanded ensemble data if state is None: lambdas = [] statevec = [] - for line in headers['_raw_lines']: - if ('legend' in line) and ('lambda' in line): - lambdas.append([word.strip(')(,') for word in line.split() if 'lambda' in word][0]) - if ('legend' in line) and (' to ' in line): - statevec.append(([float(i) for i in line.strip().split(' to ')[-1].strip('"()').split(',')])) + for line in headers["_raw_lines"]: + if ("legend" in line) and ("lambda" in line): + lambdas.append( + [word.strip(")(,") for word in line.split() if "lambda" in word][0] + ) + if ("legend" in line) and (" to " in line): + statevec.append( + ( + [ + float(i) + for i in line.strip() + .split(" to ")[-1] + .strip('"()') + .split(",") + ] + ) + ) return state, lambdas, statevec def _extract_legend(xvg): - """Extract information on state sampled for REX simulations. - - """ + """Extract information on state sampled for REX simulations.""" state_legend = {} - with anyopen(xvg, 'r') as f: + with anyopen(xvg, "r") as f: for line in f: - if ('legend' in line) and ('lambda' in line): + if ("legend" in line) and ("lambda" in line): state_legend[line.split()[4]] = float(line.split()[6].strip('"')) return state_legend @@ -344,31 +364,49 @@ def _extract_dataframe(xvg, headers=None, filter=filter): if headers is None: headers = _get_headers(xvg) - xaxis = _get_value_by_key(headers, 'xaxis', 'label') - names = [_get_value_by_key(headers, 's{}'.format(x), 'legend') for x in - range(len(headers)) if 's{}'.format(x) in headers] + xaxis = _get_value_by_key(headers, "xaxis", "label") + names = [ + _get_value_by_key(headers, "s{}".format(x), "legend") + for x in range(len(headers)) + if "s{}".format(x) in headers + ] cols = [xaxis] + names # march through column names, mark duplicates when found - cols = [col + "{}[duplicated]".format(i) if col in cols[:i] else col - for i, col, in enumerate(cols)] + cols = [ + col + "{}[duplicated]".format(i) if col in cols[:i] else col + for i, col, in enumerate(cols) + ] - header_cnt = len(headers['_raw_lines']) + header_cnt = len(headers["_raw_lines"]) if not filter: # assumes clean files - df = pd.read_csv(xvg, sep=r"\s+", header=None, skiprows=header_cnt, - na_filter=True, memory_map=True, names=cols, - dtype=np.float64, - float_precision='high') + df = pd.read_csv( + xvg, + sep=r"\s+", + header=None, + skiprows=header_cnt, + na_filter=True, + memory_map=True, + names=cols, + dtype=np.float64, + float_precision="high", + ) else: - df = pd.read_csv(xvg, sep=r"\s+", header=None, skiprows=header_cnt, - memory_map=True, on_bad_lines='skip') + df = pd.read_csv( + xvg, + sep=r"\s+", + header=None, + skiprows=header_cnt, + memory_map=True, + on_bad_lines="skip", + ) # If names=cols is passed to read_csv, rows with more than the # designated columns will be truncated and used instead of discarded. df.rename(columns={i: name for i, name in enumerate(cols)}, inplace=True) # If dtype=np.float64 and float_precision='high' are passed to read_csv, # 12.345.56 and - cannot be read. - df = df.apply(pd.to_numeric, errors='coerce') + df = df.apply(pd.to_numeric, errors="coerce") # drop duplicate df.dropna(inplace=True) @@ -423,7 +461,7 @@ def _parse_header(line, headers={}, depth=2): else: break - next_t["_val"] = ''.join(s[1:]).rstrip().strip('"') + next_t["_val"] = "".join(s[1:]).rstrip().strip('"') def _get_headers(xvg): @@ -484,17 +522,17 @@ def _get_headers(xvg): headers: dict """ - with anyopen(xvg, 'r') as f: - headers = { '_raw_lines': [] } + with anyopen(xvg, "r") as f: + headers = {"_raw_lines": []} for line in f: line = line.strip() if len(line) == 0: continue - if line.startswith('@'): + if line.startswith("@"): _parse_header(line, headers) - headers['_raw_lines'].append(line) - elif line.startswith('#'): - headers['_raw_lines'].append(line) + headers["_raw_lines"].append(line) + elif line.startswith("#"): + headers["_raw_lines"].append(line) continue # assuming to start a body section else: @@ -522,8 +560,8 @@ def _get_value_by_key(headers, key1, key2=None): val = None if key1 in headers: if key2 is not None and key2 in headers[key1]: - val = headers[key1][key2]['_val'] + val = headers[key1][key2]["_val"] else: - val = headers[key1]['_val'] + val = headers[key1]["_val"] return val diff --git a/src/alchemlyb/parsing/gomc.py b/src/alchemlyb/parsing/gomc.py index 7cf03af4..90124687 100644 --- a/src/alchemlyb/parsing/gomc.py +++ b/src/alchemlyb/parsing/gomc.py @@ -3,12 +3,13 @@ """ import pandas as pd -from .util import anyopen from . import _init_attrs +from .util import anyopen from ..postprocessors.units import R_kJmol k_b = R_kJmol + @_init_attrs def extract_u_nk(filename, T): """Return reduced potentials `u_nk` from a Hamiltonian differences dat file. @@ -34,9 +35,9 @@ def extract_u_nk(filename, T): dh_col_match = "dU/dL" h_col_match = "DelE" - pv_col_match = 'PV' - u_col_match = ['Total_En'] - beta = 1/(k_b * T) + pv_col_match = "PV" + u_col_match = ["Total_En"] + beta = 1 / (k_b * T) state, lambdas, statevec = _extract_state(filename) @@ -56,7 +57,11 @@ def extract_u_nk(filename, T): pv = df[pv_cols[0]] # GOMC also gives us total energy U directly; need this for reduced potential - u_cols = [col for col in df.columns if any(single_u_col_match in col for single_u_col_match in u_col_match)] + u_cols = [ + col + for col in df.columns + if any(single_u_col_match in col for single_u_col_match in u_col_match) + ] u = None if u_cols: u = df[u_cols[0]] @@ -64,7 +69,7 @@ def extract_u_nk(filename, T): u_k = dict() cols = list() for col in dH: - u_col = eval(col.split('->')[1][:-1]) + u_col = eval(col.split("->")[1][:-1]) # calculate reduced potential u_k = dH + pV + U u_k[u_col] = beta * dH[col].values if pv_cols: @@ -73,8 +78,9 @@ def extract_u_nk(filename, T): u_k[u_col] += beta * u.values cols.append(u_col) - u_k = pd.DataFrame(u_k, columns=cols, - index=pd.Index(times.values, name='time', dtype='Float64')) + u_k = pd.DataFrame( + u_k, columns=cols, index=pd.Index(times.values, name="time", dtype="Float64") + ) # Need to modify the lambda name cols = [l + "-lambda" for l in lambdas] @@ -83,13 +89,14 @@ def extract_u_nk(filename, T): u_k[l] = statevec[i] # set up new multi-index - newind = ['time'] + cols + newind = ["time"] + cols u_k = u_k.reset_index().set_index(newind) - u_k.name = 'u_nk' + u_k.name = "u_nk" return u_k + @_init_attrs def extract_dHdl(filename, T): """Return gradients `dH/dl` from a Hamiltonian differences free energy file. @@ -112,7 +119,7 @@ def extract_dHdl(filename, T): the constants used by the corresponding MD engine. """ - beta = 1/(k_b * T) + beta = 1 / (k_b * T) state, lambdas, statevec = _extract_state(filename) @@ -131,8 +138,11 @@ def extract_dHdl(filename, T): # make dimensionless dHdl *= beta - dHdl = pd.DataFrame(dHdl.values, columns=lambdas, - index=pd.Index(times.values, name='time', dtype='Float64')) + dHdl = pd.DataFrame( + dHdl.values, + columns=lambdas, + index=pd.Index(times.values, name="time", dtype="Float64"), + ) # Need to modify the lambda name cols = [l + "-lambda" for l in lambdas] @@ -141,10 +151,10 @@ def extract_dHdl(filename, T): dHdl[l] = statevec[i] # set up new multi-index - newind = ['time'] + cols - dHdl= dHdl.reset_index().set_index(newind) + newind = ["time"] + cols + dHdl = dHdl.reset_index().set_index(newind) - dHdl.name='dH/dl' + dHdl.name = "dH/dl" return dHdl @@ -180,33 +190,29 @@ def extract(filename, T): def _extract_state(filename): - """Extract information on state sampled, names of lambdas. - - """ + """Extract information on state sampled, names of lambdas.""" state = None - with anyopen(filename, 'r') as f: + with anyopen(filename, "r") as f: for line in f: - if ('#' in line) and ('State' in line): - state = int(line.split('State')[1].split(':')[0]) + if ("#" in line) and ("State" in line): + state = int(line.split("State")[1].split(":")[0]) # GOMC always print these two fields - lambdas = ['Coulomb', 'VDW'] - statevec = eval(line.strip().split(' = ')[-1]) + lambdas = ["Coulomb", "VDW"] + statevec = eval(line.strip().split(" = ")[-1]) break return state, lambdas, statevec def _extract_dataframe(filename): - """Extract a DataFrame from free energy data. - - """ + """Extract a DataFrame from free energy data.""" dh_col_match = "dU/dL" h_col_match = "DelE" - pv_col_match = 'PV' - u_col_match = 'Total_En' + pv_col_match = "PV" + u_col_match = "Total_En" xaxis = "time" - with anyopen(filename, 'r') as f: + with anyopen(filename, "r") as f: names = [] rows = [] for line in f: @@ -214,7 +220,7 @@ def _extract_dataframe(filename): if len(line) == 0: # avoid parsing empty line continue - elif line.startswith('#T'): + elif line.startswith("#T"): # this line has state information. No need to be parsed continue elif line.startswith("#Steps"): diff --git a/src/alchemlyb/parsing/namd.py b/src/alchemlyb/parsing/namd.py index c4181b1d..1647467c 100644 --- a/src/alchemlyb/parsing/namd.py +++ b/src/alchemlyb/parsing/namd.py @@ -1,13 +1,15 @@ """Parsers for extracting alchemical data from `NAMD `_ output files. """ -import pandas as pd -import numpy as np +import logging from os.path import basename from re import split -import logging -from .util import anyopen + +import numpy as np +import pandas as pd + from . import _init_attrs +from .util import anyopen from ..postprocessors.units import R_kJmol, kJ2kcal logger = logging.getLogger("alchemlyb.parsers.NAMD") @@ -21,12 +23,12 @@ def _filename_sort_key(s): This means that unlike with the standard Python sorted() function, "foo9" < "foo10". """ - return [int(t) if t.isdigit() else t.lower() for t in split(r'(\d+)', basename(s))] + return [int(t) if t.isdigit() else t.lower() for t in split(r"(\d+)", basename(s))] def _get_lambdas(fep_files): """Retrieves all lambda values included in the FEP files provided. - + We have to do this in order to tolerate truncated and restarted fepout files. The IDWS lambda is not present at the termination of the window, presumably for backwards compatibility with ParseFEP and probably other things. @@ -48,25 +50,25 @@ def _get_lambdas(fep_files): endpoint_windows = [] for fep_file in sorted(fep_files, key=_filename_sort_key): - with anyopen(fep_file, 'r') as f: + with anyopen(fep_file, "r") as f: for line in f: l = line.strip().split() # We might not have a #NEW line so make the best guess - if l[0] == '#NEW': + if l[0] == "#NEW": lambda1, lambda2 = float(l[6]), float(l[8]) - lambda_idws = float(l[10]) if 'LAMBDA_IDWS' in l else None - elif l[0] == '#Free': + lambda_idws = float(l[10]) if "LAMBDA_IDWS" in l else None + elif l[0] == "#Free": lambda1, lambda2, lambda_idws = float(l[7]), float(l[8]), None else: # We only care about lines with lambda values. No need to # do all that other processing below for every line - continue # pragma: no cover + continue # pragma: no cover # Keep track of whether the lambda values are increasing or decreasing, so we can return # a sorted list of the lambdas in the correct order. # If it changes during parsing of this set of fepout files, then we know something is wrong - + # Keep track of endpoints separately since in IDWS runs there must be one of opposite direction if 0.0 in (lambda1, lambda2) or 1.0 in (lambda1, lambda2): endpoint_windows.append((lambda1, lambda2)) @@ -78,23 +80,35 @@ def _get_lambdas(fep_files): is_ascending.add(lambda1 > lambda_idws) if len(is_ascending) > 1: - raise ValueError(f'Lambda values change direction in {fep_file}, relative to the other files: {lambda1} -> {lambda2} (IDWS: {lambda_idws})') + raise ValueError( + f"Lambda values change direction in {fep_file}, relative to the other files: {lambda1} -> {lambda2} (IDWS: {lambda_idws})" + ) # Make sure the lambda2 values are consistent if lambda1 in lambda_fwd_map and lambda_fwd_map[lambda1] != lambda2: - logger.error(f'fwd: lambda1 {lambda1} has lambda2 {lambda_fwd_map[lambda1]} in {fep_file} but it has already been {lambda2}') - raise ValueError('More than one lambda2 value for a particular lambda1') + logger.error( + f"fwd: lambda1 {lambda1} has lambda2 {lambda_fwd_map[lambda1]} in {fep_file} but it has already been {lambda2}" + ) + raise ValueError( + "More than one lambda2 value for a particular lambda1" + ) lambda_fwd_map[lambda1] = lambda2 # Make sure the lambda_idws values are consistent if lambda_idws is not None: - if lambda1 in lambda_bwd_map and lambda_bwd_map[lambda1] != lambda_idws: - logger.error(f'bwd: lambda1 {lambda1} has lambda_idws {lambda_bwd_map[lambda1]} but it has already been {lambda_idws}') - raise ValueError('More than one lambda_idws value for a particular lambda1') + if ( + lambda1 in lambda_bwd_map + and lambda_bwd_map[lambda1] != lambda_idws + ): + logger.error( + f"bwd: lambda1 {lambda1} has lambda_idws {lambda_bwd_map[lambda1]} but it has already been {lambda_idws}" + ) + raise ValueError( + "More than one lambda_idws value for a particular lambda1" + ) lambda_bwd_map[lambda1] = lambda_idws - is_ascending = next(iter(is_ascending)) all_lambdas = set() @@ -147,7 +161,7 @@ def extract_u_nk(fep_files, T): `fep_files` can now be a list of filenames. """ - beta = 1/(k_b * T) + beta = 1 / (k_b * T) # lists to get times and work values of each window win_ts = [] @@ -156,7 +170,7 @@ def extract_u_nk(fep_files, T): win_de_back = [] # create dataframe for results - u_nk = pd.DataFrame(columns=['time','fep-lambda']) + u_nk = pd.DataFrame(columns=["time", "fep-lambda"]) # boolean flag to parse data after equil time parsing = False @@ -176,32 +190,36 @@ def extract_u_nk(fep_files, T): for fep_file in sorted(fep_files, key=_filename_sort_key): # Note we have not set parsing=False because we could be continuing one window across # more than one fepout file - with anyopen(fep_file, 'r') as f: + with anyopen(fep_file, "r") as f: has_idws = False for line in f: l = line.strip().split() # We don't know if IDWS was enabled just from the #Free line, and we might not have # a #NEW line in this file, so we have to check for the existence of FepE_back lines # We rely on short-circuit evaluation to avoid the string comparison most of the time - if has_idws is False and l[0] == 'FepE_back:': + if has_idws is False and l[0] == "FepE_back:": has_idws = True # New window, get IDWS lambda if any # We keep track of lambdas from the #NEW line and if they disagree with the #Free line # within the same file, then complain. This can happen if truncated fepout files # are presented in the wrong order. - if l[0] == '#NEW': + if l[0] == "#NEW": if parsing: - logger.error(f'Window with lambda1: {lambda1_at_start} lambda2: {lambda2_at_start} lambda_idws: {lambda_idws_at_start} appears truncated') - logger.error(f'because a new window was encountered in {fep_file} before the previous one finished.') - raise ValueError('New window begun after truncated window') + logger.error( + f"Window with lambda1: {lambda1_at_start} lambda2: {lambda2_at_start} lambda_idws: {lambda_idws_at_start} appears truncated" + ) + logger.error( + f"because a new window was encountered in {fep_file} before the previous one finished." + ) + raise ValueError("New window begun after truncated window") lambda1_at_start, lambda2_at_start = float(l[6]), float(l[8]) - lambda_idws_at_start = float(l[10]) if 'LAMBDA_IDWS' in l else None + lambda_idws_at_start = float(l[10]) if "LAMBDA_IDWS" in l else None has_idws = True if lambda_idws_at_start is not None else False # this line marks end of window; dump data into dataframe - if l[0] == '#Free': + if l[0] == "#Free": # extract lambda values for finished window # lambda1 = sampling lambda (row), lambda2 = comparison lambda (col) lambda1 = float(l[7]) @@ -210,17 +228,25 @@ def extract_u_nk(fep_files, T): # If the lambdas are not what we thought they would be, raise an exception to ensure the calculation # fails. This can happen if fepouts where one window spans multiple fepouts are processed out of order # NB: There is no way to tell if lambda_idws changed because it isn't in the '#Free' line that ends a window - if lambda1_at_start is not None \ - and (lambda1, lambda2) != (lambda1_at_start, lambda2_at_start): - logger.error(f"Lambdas changed unexpectedly while processing {fep_file}") - logger.error(f"l1, l2: {lambda1_at_start}, {lambda2_at_start} changed to {lambda1}, {lambda2}") + if lambda1_at_start is not None and (lambda1, lambda2) != ( + lambda1_at_start, + lambda2_at_start, + ): + logger.error( + f"Lambdas changed unexpectedly while processing {fep_file}" + ) + logger.error( + f"l1, l2: {lambda1_at_start}, {lambda2_at_start} changed to {lambda1}, {lambda2}" + ) logger.error(line) - raise ValueError("Inconsistent lambda values within the same window") + raise ValueError( + "Inconsistent lambda values within the same window" + ) # As we are at the end of a window, convert last window's work and times values to np arrays # (with energy unit kT since they were kcal/mol in the fepouts) - win_de_arr = beta * np.asarray(win_de) # dE values - win_ts_arr = np.asarray(win_ts) # timesteps + win_de_arr = beta * np.asarray(win_de) # dE values + win_ts_arr = np.asarray(win_ts) # timesteps # This handles the special case where there are IDWS energies but no lambda_idws value in the # current .fepout file. This can happen when the NAMD firsttimestep is not 0, because NAMD only emits @@ -236,10 +262,16 @@ def extract_u_nk(fep_files, T): # Test for the highly pathological case where the first window is both incomplete and has IDWS # data but no lambda_idws value. if l1_idx == 0: - raise ValueError(f'IDWS data present in first window but lambda_idws not included; no way to infer the correct lambda_idws') + raise ValueError( + f"IDWS data present in first window but lambda_idws not included; no way to infer the correct lambda_idws" + ) lambda_idws_at_start = all_lambdas[l1_idx - 1] - logger.warning(f'Warning: {fep_file} has IDWS data but lambda_idws not included.') - logger.warning(f' lambda1 = {lambda1}, lambda2 = {lambda2}; inferring lambda_idws to be {lambda_idws_at_start}') + logger.warning( + f"Warning: {fep_file} has IDWS data but lambda_idws not included." + ) + logger.warning( + f" lambda1 = {lambda1}, lambda2 = {lambda2}; inferring lambda_idws to be {lambda_idws_at_start}" + ) if lambda_idws_at_start is not None: # Mimic classic DWS data @@ -248,22 +280,28 @@ def extract_u_nk(fep_files, T): win_de_back_arr = beta * np.asarray(win_de_back) n = min(len(win_de_back_arr), len(win_de_arr)) - tempDF = pd.DataFrame({ - 'time': win_ts_arr[:n], - 'fep-lambda': np.full(n,lambda1), - lambda1: 0, - lambda2: win_de_arr[:n], - lambda_idws_at_start: win_de_back_arr[:n]}) + tempDF = pd.DataFrame( + { + "time": win_ts_arr[:n], + "fep-lambda": np.full(n, lambda1), + lambda1: 0, + lambda2: win_de_arr[:n], + lambda_idws_at_start: win_de_back_arr[:n], + } + ) # print(f"{fep_file}: IDWS window {lambda1} {lambda2} {lambda_idws_at_start}") else: # print(f"{fep_file}: Forward-only window {lambda1} {lambda2}") # create dataframe of times and work values # this window's data goes in row LAMBDA1 and column LAMBDA2 - tempDF = pd.DataFrame({ - 'time': win_ts_arr, - 'fep-lambda': np.full(len(win_de_arr), lambda1), - lambda1: 0, - lambda2: win_de_arr}) + tempDF = pd.DataFrame( + { + "time": win_ts_arr, + "fep-lambda": np.full(len(win_de_arr), lambda1), + lambda1: 0, + lambda2: win_de_arr, + } + ) # join the new window's df to existing df u_nk = pd.concat([u_nk, tempDF], sort=False) @@ -275,38 +313,42 @@ def extract_u_nk(fep_files, T): win_ts_back = [] parsing = False has_idws = False - lambda1_at_start, lambda2_at_start, lambda_idws_at_start = None, None, None + lambda1_at_start, lambda2_at_start, lambda_idws_at_start = ( + None, + None, + None, + ) # append work value from 'dE' column of fepout file if parsing: - if l[0] == 'FepEnergy:': + if l[0] == "FepEnergy:": win_de.append(float(l[6])) win_ts.append(float(l[1])) - elif l[0] == 'FepE_back:': + elif l[0] == "FepE_back:": win_de_back.append(float(l[6])) win_ts_back.append(float(l[1])) # Turn parsing on after line 'STARTING COLLECTION OF ENSEMBLE AVERAGE' - if '#STARTING' in l: + if "#STARTING" in l: parsing = True - if len(win_de) != 0 or len(win_de_back) != 0: # pragma: no cover - logger.warning('Trailing data without footer line (\"#Free energy...\"). Interrupted run?') - raise ValueError('Last window is truncated') - + if len(win_de) != 0 or len(win_de_back) != 0: # pragma: no cover + logger.warning( + 'Trailing data without footer line ("#Free energy..."). Interrupted run?' + ) + raise ValueError("Last window is truncated") if lambda2 in (0.0, 1.0): # this excludes the IDWS case where a dataframe already exists for both endpoints # create last dataframe for fep-lambda at last LAMBDA2 - tempDF = pd.DataFrame({ - 'time': win_ts_arr, - 'fep-lambda': lambda2}) + tempDF = pd.DataFrame({"time": win_ts_arr, "fep-lambda": lambda2}) u_nk = pd.concat([u_nk, tempDF], sort=True) - u_nk.set_index(['time','fep-lambda'], inplace=True) + u_nk.set_index(["time", "fep-lambda"], inplace=True) return u_nk + def extract(fep_files, T): """Return reduced potentials `u_nk` from NAMD fepout file(s). @@ -342,4 +384,6 @@ def extract(fep_files, T): .. versionadded:: 1.0.0 """ - return {"u_nk": extract_u_nk(fep_files, T)} # NOTE: maybe we should also have 'dHdl': None + return { + "u_nk": extract_u_nk(fep_files, T) + } # NOTE: maybe we should also have 'dHdl': None diff --git a/src/alchemlyb/parsing/util.py b/src/alchemlyb/parsing/util.py index f8259aa6..28e5a568 100644 --- a/src/alchemlyb/parsing/util.py +++ b/src/alchemlyb/parsing/util.py @@ -1,23 +1,24 @@ """Collection of utilities used by many parsers. """ -import os -from os import PathLike -from typing import IO, Optional, Union import bz2 import gzip +import os +from os import PathLike +from typing import IO, Union + def bz2_open(filename, mode): - mode += 't' if mode in ['r','w','a','x'] else '' + mode += "t" if mode in ["r", "w", "a", "x"] else "" return bz2.open(filename, mode) def gzip_open(filename, mode): - mode += 't' if mode in ['r','w','a','x'] else '' + mode += "t" if mode in ["r", "w", "a", "x"] else "" return gzip.open(filename, mode) -def anyopen(datafile: Union[PathLike, IO], mode='r', compression=None): +def anyopen(datafile: Union[PathLike, IO], mode="r", compression=None): """Return a file stream for file or stream, even if compressed. Supports files compressed with bzip2 (.bz2) and gzip (.gz) compression @@ -59,16 +60,15 @@ def anyopen(datafile: Union[PathLike, IO], mode='r', compression=None): """ # opener for each type of file - extensions = {'.bz2': bz2_open, - '.gz': gzip_open} + extensions = {".bz2": bz2_open, ".gz": gzip_open} # compression selections available - compressions = {'bzip2': bz2_open, - 'gzip': gzip_open} + compressions = {"bzip2": bz2_open, "gzip": gzip_open} # if `datafile` is a stream - if ((hasattr(datafile, 'read') and any((i in mode for i in ('r',)))) or - (hasattr(datafile, 'write') and any((i in mode for i in ('w', 'a', 'x'))))): + if (hasattr(datafile, "read") and any((i in mode for i in ("r",)))) or ( + hasattr(datafile, "write") and any((i in mode for i in ("w", "a", "x"))) + ): # if no compression specified, just pass the stream through if compression is None: return datafile @@ -76,7 +76,9 @@ def anyopen(datafile: Union[PathLike, IO], mode='r', compression=None): compressor = compressions[compression] return compressor(datafile, mode=mode) else: - raise ValueError("`datafile` is a stream, but specified `compression` '{compression}' is not supported") + raise ValueError( + "`datafile` is a stream, but specified `compression` '{compression}' is not supported" + ) # otherwise, treat as a file # allow compression to override any extension on the file diff --git a/src/alchemlyb/postprocessors/__init__.py b/src/alchemlyb/postprocessors/__init__.py index 6e769ac4..932d2b06 100644 --- a/src/alchemlyb/postprocessors/__init__.py +++ b/src/alchemlyb/postprocessors/__init__.py @@ -1,3 +1,3 @@ __all__ = [ - 'units', + "units", ] diff --git a/src/alchemlyb/postprocessors/units.py b/src/alchemlyb/postprocessors/units.py index f5e1984d..510b4465 100644 --- a/src/alchemlyb/postprocessors/units.py +++ b/src/alchemlyb/postprocessors/units.py @@ -12,8 +12,9 @@ #: in :mod:`scipy.constants` R_kJmol = R / 1000 + def to_kT(df, T=None): - """ Convert the unit of a DataFrame to `kT`. + """Convert the unit of a DataFrame to `kT`. If temperature `T` is not provided, the DataFrame need to have attribute `temperature` and `energy_unit`. Otherwise, the temperature of the output @@ -33,28 +34,28 @@ def to_kT(df, T=None): """ new_df = df.copy() if T is not None: - new_df.attrs['temperature'] = T - elif 'temperature' not in df.attrs: - raise TypeError('Attribute temperature not found in the input ' - 'Dataframe.') + new_df.attrs["temperature"] = T + elif "temperature" not in df.attrs: + raise TypeError("Attribute temperature not found in the input " "Dataframe.") - if 'energy_unit' not in df.attrs: - raise TypeError('Attribute energy_unit not found in the input ' - 'Dataframe.') + if "energy_unit" not in df.attrs: + raise TypeError("Attribute energy_unit not found in the input " "Dataframe.") - if df.attrs['energy_unit'] == 'kT': + if df.attrs["energy_unit"] == "kT": return new_df - elif df.attrs['energy_unit'] == 'kJ/mol': - new_df /= R_kJmol * df.attrs['temperature'] - new_df.attrs['energy_unit'] = 'kT' + elif df.attrs["energy_unit"] == "kJ/mol": + new_df /= R_kJmol * df.attrs["temperature"] + new_df.attrs["energy_unit"] = "kT" return new_df - elif df.attrs['energy_unit'] == 'kcal/mol': - new_df /= R_kJmol * df.attrs['temperature'] * kJ2kcal - new_df.attrs['energy_unit'] = 'kT' + elif df.attrs["energy_unit"] == "kcal/mol": + new_df /= R_kJmol * df.attrs["temperature"] * kJ2kcal + new_df.attrs["energy_unit"] = "kT" return new_df else: - raise ValueError('energy_unit {} can only be kT, kJ/mol or ' \ - 'kcal/mol.'.format(df.attrs['energy_unit'])) + raise ValueError( + "energy_unit {} can only be kT, kJ/mol or " + "kcal/mol.".format(df.attrs["energy_unit"]) + ) def to_kcalmol(df, T=None): @@ -77,10 +78,11 @@ def to_kcalmol(df, T=None): `df` converted. """ kt_df = to_kT(df, T) - kt_df *= R_kJmol * df.attrs['temperature'] * kJ2kcal - kt_df.attrs['energy_unit'] = 'kcal/mol' + kt_df *= R_kJmol * df.attrs["temperature"] * kJ2kcal + kt_df.attrs["energy_unit"] = "kcal/mol" return kt_df + def to_kJmol(df, T=None): """Convert the unit of a DataFrame to kJ/mol. @@ -101,12 +103,13 @@ def to_kJmol(df, T=None): `df` converted. """ kt_df = to_kT(df, T) - kt_df *= R_kJmol * df.attrs['temperature'] - kt_df.attrs['energy_unit'] = 'kJ/mol' + kt_df *= R_kJmol * df.attrs["temperature"] + kt_df.attrs["energy_unit"] = "kJ/mol" return kt_df + def get_unit_converter(units): - """ Obtain the converter according to the unit string. + """Obtain the converter according to the unit string. If `units` is 'kT', the `to_kT` converter is returned. If `units` is 'kJ/mol', the `to_kJmol` converter is returned. If `units` is 'kcal/mol', @@ -125,12 +128,12 @@ def get_unit_converter(units): .. versionadded:: 0.5.0 """ - converters = {'kT': to_kT, 'kJ/mol': to_kJmol, - 'kcal/mol': to_kcalmol} + converters = {"kT": to_kT, "kJ/mol": to_kJmol, "kcal/mol": to_kcalmol} try: convert = converters[units] except KeyError: raise ValueError( f"Energy unit {units} is not supported, " - f"choose one of {list(converters.keys())}") + f"choose one of {list(converters.keys())}" + ) return convert diff --git a/src/alchemlyb/preprocessing/__init__.py b/src/alchemlyb/preprocessing/__init__.py index 6b759482..223c942e 100644 --- a/src/alchemlyb/preprocessing/__init__.py +++ b/src/alchemlyb/preprocessing/__init__.py @@ -3,16 +3,22 @@ preparing data for estimators. """ -from .subsampling import slicing, dhdl2series, u_nk2series, decorrelate_dhdl, decorrelate_u_nk -from .subsampling import statistical_inefficiency from .subsampling import equilibrium_detection +from .subsampling import ( + slicing, + dhdl2series, + u_nk2series, + decorrelate_dhdl, + decorrelate_u_nk, +) +from .subsampling import statistical_inefficiency __all__ = [ - 'slicing', - 'statistical_inefficiency', - 'equilibrium_detection', - 'decorrelate_dhdl', - 'decorrelate_u_nk', - 'dhdl2series', - 'u_nk2series' + "slicing", + "statistical_inefficiency", + "equilibrium_detection", + "decorrelate_dhdl", + "decorrelate_u_nk", + "dhdl2series", + "u_nk2series", ] diff --git a/src/alchemlyb/preprocessing/subsampling.py b/src/alchemlyb/preprocessing/subsampling.py index 3b8bd8af..653adbca 100644 --- a/src/alchemlyb/preprocessing/subsampling.py +++ b/src/alchemlyb/preprocessing/subsampling.py @@ -4,13 +4,18 @@ import warnings import pandas as pd -from pymbar.timeseries import (statisticalInefficiency, - detectEquilibration, - subsampleCorrelatedData, ) +from pymbar.timeseries import ( + statisticalInefficiency, + detectEquilibration, + subsampleCorrelatedData, +) + from .. import pass_attrs -def decorrelate_u_nk(df, method='dE', drop_duplicates=True, - sort=True, remove_burnin=False, **kwargs): + +def decorrelate_u_nk( + df, method="dE", drop_duplicates=True, sort=True, remove_burnin=False, **kwargs +): """Subsample an u_nk DataFrame based on the selected method. The method can be either 'all' (obtained as a sum over all energy @@ -57,8 +62,8 @@ def decorrelate_u_nk(df, method='dE', drop_duplicates=True, deprecate the 'dhdl'. """ - kwargs['drop_duplicates'] = drop_duplicates - kwargs['sort'] = sort + kwargs["drop_duplicates"] = drop_duplicates + kwargs["sort"] = sort series = u_nk2series(df, method) @@ -67,8 +72,10 @@ def decorrelate_u_nk(df, method='dE', drop_duplicates=True, else: return statistical_inefficiency(df, series, **kwargs) -def decorrelate_dhdl(df, drop_duplicates=True, sort=True, - remove_burnin=False, **kwargs): + +def decorrelate_dhdl( + df, drop_duplicates=True, sort=True, remove_burnin=False, **kwargs +): """Subsample a dhdl DataFrame. This is a wrapper function around the function :func:`~alchemlyb.preprocessing.subsampling.statistical_inefficiency` and @@ -111,8 +118,8 @@ def decorrelate_dhdl(df, drop_duplicates=True, sort=True, """ - kwargs['drop_duplicates'] = drop_duplicates - kwargs['sort'] = sort + kwargs["drop_duplicates"] = drop_duplicates + kwargs["sort"] = sort series = dhdl2series(df) @@ -121,8 +128,9 @@ def decorrelate_dhdl(df, drop_duplicates=True, sort=True, else: return statistical_inefficiency(df, series, **kwargs) + @pass_attrs -def u_nk2series(df, method='dE'): +def u_nk2series(df, method="dE"): """Convert an u_nk DataFrame into a series based on the selected method for subsampling. @@ -152,18 +160,22 @@ def u_nk2series(df, method='dE'): # deprecation: remove in 3.0.0 # (the deprecations should show up in the calling functions) - if method == 'dhdl': - warnings.warn("Method 'dhdl' has been deprecated, using 'dE' instead. " - "'dhdl' will be removed in alchemlyb 3.0.0.", - category=DeprecationWarning, - stacklevel=2) - method = 'dE' - elif method == 'dhdl_all': - warnings.warn("Method 'dhdl_all' has been deprecated, using 'all' instead. " - "'dhdl_all' will be removed in alchemlyb 3.0.0.", - category=DeprecationWarning, - stacklevel=2) - method = 'all' + if method == "dhdl": + warnings.warn( + "Method 'dhdl' has been deprecated, using 'dE' instead. " + "'dhdl' will be removed in alchemlyb 3.0.0.", + category=DeprecationWarning, + stacklevel=2, + ) + method = "dE" + elif method == "dhdl_all": + warnings.warn( + "Method 'dhdl_all' has been deprecated, using 'all' instead. " + "'dhdl_all' will be removed in alchemlyb 3.0.0.", + category=DeprecationWarning, + stacklevel=2, + ) + method = "all" # Check if the input is u_nk try: @@ -172,11 +184,11 @@ def u_nk2series(df, method='dE'): key = key[0] df[key] except KeyError: - raise ValueError('The input should be u_nk') + raise ValueError("The input should be u_nk") - if method == 'all': + if method == "all": series = df.sum(axis=1) - elif method == 'dE': + elif method == "dE": # Using the same logic as alchemical-analysis key = df.index.values[0][1:] if len(key) == 1: @@ -192,13 +204,12 @@ def u_nk2series(df, method='dE'): else: series = df.iloc[:, index - 1] else: - raise ValueError( - 'Decorrelation method {} not found.'.format(method)) + raise ValueError("Decorrelation method {} not found.".format(method)) return series @pass_attrs -def dhdl2series(df, method='all'): +def dhdl2series(df, method="all"): """Convert a dhdl DataFrame to a series for subsampling. The series is generated by summing over all energy components (axis 1 of @@ -235,13 +246,15 @@ def dhdl2series(df, method='all'): def _check_multiple_times(df): if isinstance(df, pd.Series): - return df.sort_index(axis=0).reset_index('time', name='').duplicated('time').any() + return ( + df.sort_index(axis=0).reset_index("time", name="").duplicated("time").any() + ) else: - return df.sort_index(axis=0).reset_index('time').duplicated('time').any() + return df.sort_index(axis=0).reset_index("time").duplicated("time").any() def _check_sorted(df): - return df.reset_index(0)['time'].is_monotonic_increasing + return df.reset_index(0)["time"].is_monotonic_increasing def _drop_duplicates(df, series=None): @@ -265,34 +278,44 @@ def _drop_duplicates(df, series=None): """ if isinstance(df, pd.Series): # remove the duplicate based on time - drop_duplicates_series = df.reset_index('time', name=''). \ - drop_duplicates('time') + drop_duplicates_series = df.reset_index("time", name="").drop_duplicates("time") # Rest the time index - lambda_names = ['time', ] + lambda_names = [ + "time", + ] lambda_names.extend(drop_duplicates_series.index.names) - df = drop_duplicates_series.set_index('time', append=True). \ - reorder_levels(lambda_names) + df = drop_duplicates_series.set_index("time", append=True).reorder_levels( + lambda_names + ) else: # remove the duplicate based on time - drop_duplicates_df = df.reset_index('time').drop_duplicates('time') + drop_duplicates_df = df.reset_index("time").drop_duplicates("time") # Rest the time index - lambda_names = ['time', ] + lambda_names = [ + "time", + ] lambda_names.extend(drop_duplicates_df.index.names) - df = drop_duplicates_df.set_index('time', append=True). \ - reorder_levels(lambda_names) + df = drop_duplicates_df.set_index("time", append=True).reorder_levels( + lambda_names + ) # Do the same withing with the series if series is not None: # remove the duplicate based on time - drop_duplicates_series = series.reset_index('time', name=''). \ - drop_duplicates('time') + drop_duplicates_series = series.reset_index("time", name="").drop_duplicates( + "time" + ) # Rest the time index - lambda_names = ['time', ] + lambda_names = [ + "time", + ] lambda_names.extend(drop_duplicates_series.index.names) - series = drop_duplicates_series.set_index('time', append=True). \ - reorder_levels(lambda_names) + series = drop_duplicates_series.set_index("time", append=True).reorder_levels( + lambda_names + ) return df, series + def _sort_by_time(df, series=None): """Sort the ``df`` by time which could be Dataframe or Series, if series is provided, sort the series as well. @@ -311,12 +334,13 @@ def _sort_by_time(df, series=None): series : Series Formatted Series. """ - df = df.sort_index(level='time') + df = df.sort_index(level="time") if series is not None: - series = series.sort_index(level='time') + series = series.sort_index(level="time") return df, series + def _prepare_input(df, series, drop_duplicates, sort): """Prepare and check the input to be used for statistical_inefficiency or equilibrium_detection. @@ -341,7 +365,8 @@ def _prepare_input(df, series, drop_duplicates, sort): raise KeyError( "Duplicate time values found; statistical inefficiency " "only works on a single, contiguous, " - "and sorted timeseries.") + "and sorted timeseries." + ) if not _check_sorted(df): if sort: @@ -349,16 +374,17 @@ def _prepare_input(df, series, drop_duplicates, sort): else: raise KeyError( "Statistical inefficiency only works as expected if " - "values are sorted by time, increasing.") + "values are sorted by time, increasing." + ) if series is not None: - if (len(series) != len(df) or - not all( - series.reset_index()['time'] == df.reset_index()['time'])): - raise ValueError( - "series and data must be sampled at the same times") + if len(series) != len(df) or not all( + series.reset_index()["time"] == df.reset_index()["time"] + ): + raise ValueError("series and data must be sampled at the same times") return df, series + def slicing(df, lower=None, upper=None, step=None, force=False): """Subsample a DataFrame using simple slicing. @@ -390,16 +416,25 @@ def slicing(df, lower=None, upper=None, step=None, force=False): raise KeyError("DataFrame rows must be sorted by time, increasing.") if not force and _check_multiple_times(df): - raise KeyError("Duplicate time values found; it's generally advised " - "to use slicing on DataFrames with unique time values " - "for each row. Use `force=True` to ignore this error.") + raise KeyError( + "Duplicate time values found; it's generally advised " + "to use slicing on DataFrames with unique time values " + "for each row. Use `force=True` to ignore this error." + ) return df -def statistical_inefficiency(df, series=None, lower=None, upper=None, - step=None, conservative=True, - drop_duplicates=False, sort=False): +def statistical_inefficiency( + df, + series=None, + lower=None, + upper=None, + step=None, + conservative=True, + drop_duplicates=False, + sort=False, +): """Subsample a DataFrame based on the calculated statistical inefficiency of a timeseries. @@ -480,8 +515,7 @@ def statistical_inefficiency(df, series=None, lower=None, upper=None, statinef = statisticalInefficiency(series, fast=False) # use the subsampleCorrelatedData function to get the subsample index - indices = subsampleCorrelatedData(series, g=statinef, - conservative=conservative) + indices = subsampleCorrelatedData(series, g=statinef, conservative=conservative) df = df.iloc[indices] else: df = slicing(df, lower=lower, upper=upper, step=step) @@ -489,8 +523,15 @@ def statistical_inefficiency(df, series=None, lower=None, upper=None, return df -def equilibrium_detection(df, series=None, lower=None, upper=None, step=None, - drop_duplicates=False, sort=False): +def equilibrium_detection( + df, + series=None, + lower=None, + upper=None, + step=None, + drop_duplicates=False, + sort=False, +): """Subsample a DataFrame using automated equilibrium detection on a timeseries. This function uses the :mod:`pymbar` implementation of the *simple diff --git a/src/alchemlyb/tests/conftest.py b/src/alchemlyb/tests/conftest.py index b5b485e9..2a9530b1 100644 --- a/src/alchemlyb/tests/conftest.py +++ b/src/alchemlyb/tests/conftest.py @@ -198,7 +198,6 @@ def amber_simplesolvated_charge_dHdl(amber_simplesolvated): @pytest.fixture def amber_simplesolvated_vdw_dHdl(amber_simplesolvated): - return [ amber.extract_dHdl(filename, T=298.0) for filename in amber_simplesolvated["vdw"] diff --git a/src/alchemlyb/tests/parsing/test_amber.py b/src/alchemlyb/tests/parsing/test_amber.py index c1fe137d..0d186cc7 100644 --- a/src/alchemlyb/tests/parsing/test_amber.py +++ b/src/alchemlyb/tests/parsing/test_amber.py @@ -2,28 +2,30 @@ """ import logging + +import pandas as pd import pytest +from alchemtest.amber import load_bace_example +from alchemtest.amber import load_bace_improper +from alchemtest.amber import load_simplesolvated +from alchemtest.amber import load_testfiles from numpy.testing import assert_allclose -import pandas as pd +from alchemlyb.parsing.amber import extract from alchemlyb.parsing.amber import extract_dHdl from alchemlyb.parsing.amber import extract_u_nk -from alchemlyb.parsing.amber import extract -from alchemtest.amber import load_simplesolvated -from alchemtest.amber import load_bace_example -from alchemtest.amber import load_bace_improper -from alchemtest.amber import load_testfiles ################################################################################## ################ Check the parser behaviour with problematic files ################################################################################## + @pytest.fixture(name="testfiles", scope="module") def fixture_testfiles(): - """ Returns the testfiles data dictionary """ + """Returns the testfiles data dictionary""" bunch = load_testfiles() - return bunch['data'] + return bunch["data"] def test_file_not_found(): @@ -77,10 +79,10 @@ def test_no_control_data(caplog, testfiles): def test_no_free_energy_info(caplog, testfiles): """Test if we raise an exception if there is no free energy section""" filename = testfiles["no_free_energy_info"][0] - with pytest.raises(ValueError, match='no free energy section found'): + with pytest.raises(ValueError, match="no free energy section found"): with caplog.at_level(logging.ERROR): _ = extract(str(filename), T=298.0) - assert 'No free energy section found' in caplog.text + assert "No free energy section found" in caplog.text def test_no_useful_data(caplog, testfiles): @@ -89,7 +91,7 @@ def test_no_useful_data(caplog, testfiles): with pytest.raises(ValueError, match="does not contain any data"): with caplog.at_level(logging.ERROR): _ = extract(str(filename), T=298.0) - assert 'does not contain any data' in caplog.text + assert "does not contain any data" in caplog.text def test_no_temp0_set(caplog, testfiles): @@ -119,16 +121,16 @@ def test_long_and_wrong_number_MBAR(caplog, testfiles): with pytest.raises(ValueError, match="the number of lambda windows read"): with caplog.at_level(logging.ERROR): _ = extract_u_nk(str(filename), T=300.0) - assert 'the number of lambda windows read' in caplog.text + assert "the number of lambda windows read" in caplog.text def test_no_starting_time(caplog, testfiles): """Test if raise an exception if the starting time is not read""" filename = testfiles["no_starting_simulation_time"][0] - with pytest.raises(ValueError, match='No starting simulation time in file'): + with pytest.raises(ValueError, match="No starting simulation time in file"): with caplog.at_level(logging.ERROR): _ = extract(str(filename), T=298.0) - assert 'No starting simulation time in file' in caplog.text + assert "No starting simulation time in file" in caplog.text def test_parse_without_spaces_around_equal(testfiles): @@ -138,23 +140,24 @@ def test_parse_without_spaces_around_equal(testfiles): """ filename = testfiles["no_spaces_around_equal"][0] df_dict = extract(str(filename), T=298.0) - assert isinstance(df_dict['dHdl'], pd.DataFrame) + assert isinstance(df_dict["dHdl"], pd.DataFrame) ################################################################################## ################ Check the parser behaviour with standard single files ################################################################################## + @pytest.fixture(name="single_u_nk", scope="module") def fixture_single_u_nk(): """return a single file to check u_unk parsing""" - return load_bace_example().data['complex']['vdw'][0] + return load_bace_example().data["complex"]["vdw"][0] @pytest.fixture(name="single_dHdl", scope="module") def fixture_single_dHdl(): """return a single file to check dHdl parsing""" - return load_simplesolvated().data['charge'][0] + return load_simplesolvated().data["charge"][0] def test_dHdl_time_reading(single_dHdl): @@ -175,18 +178,18 @@ def test_extract_with_both_data(single_u_nk): """Test that dHdl and u_nk have the correct form when extracted from files with the single "extract" funcion.""" df_dict = extract(single_u_nk, T=298.0) - assert df_dict['dHdl'].index.names == ('time', 'lambdas') - assert df_dict['dHdl'].shape == (500, 1) - assert df_dict['u_nk'].index.names == ('time', 'lambdas') + assert df_dict["dHdl"].index.names == ("time", "lambdas") + assert df_dict["dHdl"].shape == (500, 1) + assert df_dict["u_nk"].index.names == ("time", "lambdas") def test_extract_with_only_dhdl_data(single_dHdl): """Test that parsing with the extract function a file - with just dHdl gives the correct results""" + with just dHdl gives the correct results""" df_dict = extract(single_dHdl, T=298.0) - assert df_dict['dHdl'].index.names == ('time', 'lambdas') - assert df_dict['dHdl'].shape == (500, 1) - assert df_dict['u_nk'] is None + assert df_dict["dHdl"].index.names == ("time", "lambdas") + assert df_dict["dHdl"].shape == (500, 1) + assert df_dict["u_nk"] is None def test_wrong_T_should_raise_warning(single_dHdl, T=300.0): @@ -195,24 +198,21 @@ def test_wrong_T_should_raise_warning(single_dHdl, T=300.0): read from the AMBER file gives a warning """ with pytest.raises( - ValueError, - match="is different from the temperature passed as parameter"): + ValueError, match="is different from the temperature passed as parameter" + ): _ = extract(single_dHdl, T=T) - ################################################################### ################ Check the behaviour on proper datasets ################################################################### -@pytest.mark.parametrize("filename", - [filename - for leg in load_simplesolvated()['data'].values() - for filename in leg]) -def test_dHdl(filename, - names=('time', 'lambdas'), - shape=(500, 1)): +@pytest.mark.parametrize( + "filename", + [filename for leg in load_simplesolvated()["data"].values() for filename in leg], +) +def test_dHdl(filename, names=("time", "lambdas"), shape=(500, 1)): """Test that dHdl has the correct form when extracted from files.""" dHdl = extract_dHdl(filename, T=298.0) @@ -220,27 +220,33 @@ def test_dHdl(filename, assert dHdl.shape == shape -@pytest.mark.parametrize("mbar_filename", - [mbar_filename - for leg in load_bace_example()['data']['complex'].values() - for mbar_filename in leg]) -def test_u_nk(mbar_filename, - names=('time', 'lambdas')): +@pytest.mark.parametrize( + "mbar_filename", + [ + mbar_filename + for leg in load_bace_example()["data"]["complex"].values() + for mbar_filename in leg + ], +) +def test_u_nk(mbar_filename, names=("time", "lambdas")): """Test the u_nk has the correct form when extracted from files""" u_nk = extract_u_nk(mbar_filename, T=298.0) assert u_nk.index.names == names -@pytest.mark.parametrize("improper_filename", - [improper_filename - for leg in load_bace_improper()['data'].values() - for improper_filename in leg]) -def test_u_nk_improper(improper_filename, - names=('time', 'lambdas')): +@pytest.mark.parametrize( + "improper_filename", + [ + improper_filename + for leg in load_bace_improper()["data"].values() + for improper_filename in leg + ], +) +def test_u_nk_improper(improper_filename, names=("time", "lambdas")): """Test the u_nk has the correct form when extracted from files""" try: u_nk = extract_u_nk(improper_filename, T=298.0) assert u_nk.index.names == names except Exception: - assert '0.5626' in improper_filename + assert "0.5626" in improper_filename diff --git a/src/alchemlyb/tests/parsing/test_gmx.py b/src/alchemlyb/tests/parsing/test_gmx.py index d85ad1bf..1959c925 100644 --- a/src/alchemlyb/tests/parsing/test_gmx.py +++ b/src/alchemlyb/tests/parsing/test_gmx.py @@ -3,118 +3,144 @@ """ import bz2 -import pytest -from alchemlyb.parsing.gmx import extract_dHdl, extract_u_nk, extract +import pytest from alchemtest.gmx import load_benzene -from alchemtest.gmx import load_expanded_ensemble_case_1, load_expanded_ensemble_case_2, load_expanded_ensemble_case_3 -from alchemtest.gmx import load_water_particle_with_total_energy +from alchemtest.gmx import ( + load_expanded_ensemble_case_1, + load_expanded_ensemble_case_2, + load_expanded_ensemble_case_3, +) from alchemtest.gmx import load_water_particle_with_potential_energy +from alchemtest.gmx import load_water_particle_with_total_energy from alchemtest.gmx import load_water_particle_without_energy from numpy.testing import assert_almost_equal +from alchemlyb.parsing.gmx import extract_dHdl, extract_u_nk, extract + def test_dHdl(): - """Test that dHdl has the correct form when extracted from files. - - """ + """Test that dHdl has the correct form when extracted from files.""" dataset = load_benzene() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: dHdl = extract_dHdl(filename, T=300) - assert dHdl.index.names == ['time', 'fep-lambda'] + assert dHdl.index.names == ["time", "fep-lambda"] assert dHdl.shape == (4001, 1) -def test_u_nk(): - """Test that u_nk has the correct form when extracted from files. - """ +def test_u_nk(): + """Test that u_nk has the correct form when extracted from files.""" dataset = load_benzene() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: u_nk = extract_u_nk(filename, T=300) - assert u_nk.index.names == ['time', 'fep-lambda'] - if leg == 'Coulomb': + assert u_nk.index.names == ["time", "fep-lambda"] + if leg == "Coulomb": assert u_nk.shape == (4001, 5) - elif leg == 'VDW': + elif leg == "VDW": assert u_nk.shape == (4001, 16) -def test_u_nk_case1(): - """Test that u_nk has the correct form when extracted from expanded ensemble files (case 1). - """ +def test_u_nk_case1(): + """Test that u_nk has the correct form when extracted from expanded ensemble files (case 1).""" dataset = load_expanded_ensemble_case_1() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: u_nk = extract_u_nk(filename, T=300, filter=False) - assert u_nk.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda'] + assert u_nk.index.names == [ + "time", + "fep-lambda", + "coul-lambda", + "vdw-lambda", + "restraint-lambda", + ] assert u_nk.shape == (50001, 28) -def test_dHdl_case1(): - """Test that dHdl has the correct form when extracted from expanded ensemble files (case 1). - """ +def test_dHdl_case1(): + """Test that dHdl has the correct form when extracted from expanded ensemble files (case 1).""" dataset = load_expanded_ensemble_case_1() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: dHdl = extract_dHdl(filename, T=300, filter=False) - assert dHdl.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda'] + assert dHdl.index.names == [ + "time", + "fep-lambda", + "coul-lambda", + "vdw-lambda", + "restraint-lambda", + ] assert dHdl.shape == (50001, 4) -def test_u_nk_case2(): - """Test that u_nk has the correct form when extracted from expanded ensemble files (case 2). - """ +def test_u_nk_case2(): + """Test that u_nk has the correct form when extracted from expanded ensemble files (case 2).""" dataset = load_expanded_ensemble_case_2() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: u_nk = extract_u_nk(filename, T=300, filter=False) - assert u_nk.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda'] + assert u_nk.index.names == [ + "time", + "fep-lambda", + "coul-lambda", + "vdw-lambda", + "restraint-lambda", + ] assert u_nk.shape == (25001, 28) -def test_u_nk_case3(): - """Test that u_nk has the correct form when extracted from REX files (case 3). - """ +def test_u_nk_case3(): + """Test that u_nk has the correct form when extracted from REX files (case 3).""" dataset = load_expanded_ensemble_case_3() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: u_nk = extract_u_nk(filename, T=300, filter=False) - assert u_nk.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda'] + assert u_nk.index.names == [ + "time", + "fep-lambda", + "coul-lambda", + "vdw-lambda", + "restraint-lambda", + ] assert u_nk.shape == (2500, 28) -def test_dHdl_case3(): - """Test that dHdl has the correct form when extracted from REX files (case 3). - """ +def test_dHdl_case3(): + """Test that dHdl has the correct form when extracted from REX files (case 3).""" dataset = load_expanded_ensemble_case_3() - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: dHdl = extract_dHdl(filename, T=300, filter=False) - assert dHdl.index.names == ['time', 'fep-lambda', 'coul-lambda', 'vdw-lambda', 'restraint-lambda'] + assert dHdl.index.names == [ + "time", + "fep-lambda", + "coul-lambda", + "vdw-lambda", + "restraint-lambda", + ] assert dHdl.shape == (2500, 4) -def test_u_nk_with_total_energy(): - """Test that the reduced potential is calculated correctly when the total energy is given. - """ +def test_u_nk_with_total_energy(): + """Test that the reduced potential is calculated correctly when the total energy is given.""" # Load dataset dataset = load_water_particle_with_total_energy() @@ -124,15 +150,16 @@ def test_u_nk_with_total_energy(): # Check one specific value in the dataframe assert_almost_equal( - extract_u_nk(dataset['data']['AllStates'][0], T=300).loc[0][(0.0,0.0)].values[0], + extract_u_nk(dataset["data"]["AllStates"][0], T=300) + .loc[0][(0.0, 0.0)] + .values[0], -11211.577658852531, - decimal=6 + decimal=6, ) -def test_u_nk_with_potential_energy(): - """Test that the reduced potential is calculated correctly when the potential energy is given. - """ +def test_u_nk_with_potential_energy(): + """Test that the reduced potential is calculated correctly when the potential energy is given.""" # Load dataset dataset = load_water_particle_with_potential_energy() @@ -142,16 +169,16 @@ def test_u_nk_with_potential_energy(): # Check one specific value in the dataframe assert_almost_equal( - extract_u_nk(dataset['data']['AllStates'][0], T=300).loc[0][(0.0,0.0)].values[0], + extract_u_nk(dataset["data"]["AllStates"][0], T=300) + .loc[0][(0.0, 0.0)] + .values[0], -15656.557252200757, - decimal=6 + decimal=6, ) def test_u_nk_without_energy(): - """Test that the reduced potential is calculated correctly when no energy is given. - - """ + """Test that the reduced potential is calculated correctly when no energy is given.""" # Load dataset dataset = load_water_particle_without_energy() @@ -161,105 +188,114 @@ def test_u_nk_without_energy(): # Check one specific value in the dataframe assert_almost_equal( - extract_u_nk(dataset['data']['AllStates'][0], T=300).loc[0][(0.0,0.0)].values[0], + extract_u_nk(dataset["data"]["AllStates"][0], T=300) + .loc[0][(0.0, 0.0)] + .values[0], 0.0, - decimal=6 + decimal=6, ) def _diag_sum(dataset): - """Calculate the sum of diagonal elements (i, i) - - """ + """Calculate the sum of diagonal elements (i, i)""" # Initialize the sum variable ds = 0.0 - for leg in dataset['data']: - for filename in dataset['data'][leg]: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: u_nk = extract_u_nk(filename, T=300) # Calculate the sum of diagonal elements: for i, lambda_ in enumerate(u_nk.columns): - #18.6 is the time step - ds += u_nk.loc[i*186/10][lambda_].values[0] + # 18.6 is the time step + ds += u_nk.loc[i * 186 / 10][lambda_].values[0] return ds + def test_extract_u_nk_unit(): - '''Test if extract_u_nk assign the attr correctly''' + """Test if extract_u_nk assign the attr correctly""" dataset = load_benzene() - u_nk = extract_u_nk(dataset['data']['Coulomb'][0], 310) - assert u_nk.attrs['temperature'] == 310 - assert u_nk.attrs['energy_unit'] == 'kT' + u_nk = extract_u_nk(dataset["data"]["Coulomb"][0], 310) + assert u_nk.attrs["temperature"] == 310 + assert u_nk.attrs["energy_unit"] == "kT" + def test_extract_dHdl_unit(): - '''Test if extract_u_nk assign the attr correctly''' + """Test if extract_u_nk assign the attr correctly""" dataset = load_benzene() - dhdl = extract_dHdl(dataset['data']['Coulomb'][0], 310) - assert dhdl.attrs['temperature'] == 310 - assert dhdl.attrs['energy_unit'] == 'kT' + dhdl = extract_dHdl(dataset["data"]["Coulomb"][0], 310) + assert dhdl.attrs["temperature"] == 310 + assert dhdl.attrs["energy_unit"] == "kT" + def test_calling_extract(): - '''Test if the extract function is working''' + """Test if the extract function is working""" dataset = load_benzene() - df_dict = extract(dataset['data']['Coulomb'][0], 310) - assert df_dict['dHdl'].attrs['temperature'] == 310 - assert df_dict['dHdl'].attrs['energy_unit'] == 'kT' - assert df_dict['u_nk'].attrs['temperature'] == 310 - assert df_dict['u_nk'].attrs['energy_unit'] == 'kT' - -class TestRobustGMX(): - '''Test dropping the row that is wrong in different way''' + df_dict = extract(dataset["data"]["Coulomb"][0], 310) + assert df_dict["dHdl"].attrs["temperature"] == 310 + assert df_dict["dHdl"].attrs["energy_unit"] == "kT" + assert df_dict["u_nk"].attrs["temperature"] == 310 + assert df_dict["u_nk"].attrs["energy_unit"] == "kT" + + +class TestRobustGMX: + """Test dropping the row that is wrong in different way""" + @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def data(): - dhdl = extract_dHdl(load_benzene()['data']['Coulomb'][0], 310) - with bz2.open(load_benzene()['data']['Coulomb'][0], "rt") as bz_file: + dhdl = extract_dHdl(load_benzene()["data"]["Coulomb"][0], 310) + with bz2.open(load_benzene()["data"]["Coulomb"][0], "rt") as bz_file: text = bz_file.read() return text, len(dhdl) def test_sanity(self, data, tmp_path): - '''Test if the test routine is working.''' + """Test if the test routine is working.""" text, length = data - new_text = tmp_path / 'text.xvg' + new_text = tmp_path / "text.xvg" new_text.write_text(text) dhdl = extract_dHdl(new_text, 310) assert len(dhdl) == length def test_truncated_row(self, data, tmp_path): - '''Test the case where the last row has been truncated.''' + """Test the case where the last row has been truncated.""" text, length = data - new_text = tmp_path / 'text.xvg' - new_text.write_text(text + '40010.0 27.0\n') + new_text = tmp_path / "text.xvg" + new_text.write_text(text + "40010.0 27.0\n") dhdl = extract_dHdl(new_text, 310, filter=True) assert len(dhdl) == length def test_truncated_number(self, data, tmp_path): - '''Test the case where the last row has been truncated and a - has - been left.''' + """Test the case where the last row has been truncated and a - has + been left.""" text, length = data - new_text = tmp_path / 'text.xvg' - new_text.write_text(text + '40010.0 27.0 -\n') + new_text = tmp_path / "text.xvg" + new_text.write_text(text + "40010.0 27.0 -\n") dhdl = extract_dHdl(new_text, 310, filter=True) assert len(dhdl) == length def test_weirdnumber(self, data, tmp_path): - '''Test the case where the last number has been appended a weird - number.''' + """Test the case where the last number has been appended a weird + number.""" text, length = data - new_text = tmp_path / 'text.xvg' + new_text = tmp_path / "text.xvg" # Note the 27.040010.0 which is the sum of 27.0 and 40010.0 - new_text.write_text(text + '40010.0 27.040010.0 27.0 0.0 6.7 13.5 20.2 27.0 0.7 27.0 0.0 6.7 ' - '13.5 20.2 27.0 0.7\n') + new_text.write_text( + text + "40010.0 27.040010.0 27.0 0.0 6.7 13.5 20.2 27.0 0.7 27.0 0.0 6.7 " + "13.5 20.2 27.0 0.7\n" + ) dhdl = extract_dHdl(new_text, 310, filter=True) assert len(dhdl) == length def test_too_many_cols(self, data, tmp_path): - '''Test the case where the row has too many columns.''' + """Test the case where the row has too many columns.""" text, length = data - new_text = tmp_path / 'text.xvg' - new_text.write_text(text + - '40010.0 27.0 0.0 6.7 13.5 20.2 27.0 0.7 27.0 0.0 6.7 13.5 20.2 27.0 0.7\n') + new_text = tmp_path / "text.xvg" + new_text.write_text( + text + + "40010.0 27.0 0.0 6.7 13.5 20.2 27.0 0.7 27.0 0.0 6.7 13.5 20.2 27.0 0.7\n" + ) dhdl = extract_dHdl(new_text, 310, filter=True) assert len(dhdl) == length diff --git a/src/alchemlyb/tests/parsing/test_gomc.py b/src/alchemlyb/tests/parsing/test_gomc.py index d61b3789..241add45 100644 --- a/src/alchemlyb/tests/parsing/test_gomc.py +++ b/src/alchemlyb/tests/parsing/test_gomc.py @@ -2,43 +2,40 @@ """ -from alchemlyb.parsing.gomc import extract_dHdl, extract_u_nk, extract from alchemtest.gomc import load_benzene +from alchemlyb.parsing.gomc import extract_dHdl, extract_u_nk, extract + def test_dHdl(): - """Test that dHdl has the correct form when extracted from files. - - """ + """Test that dHdl has the correct form when extracted from files.""" dataset = load_benzene() - for filename in dataset['data']: + for filename in dataset["data"]: dHdl = extract_dHdl(filename, T=298) - assert dHdl.index.names == ['time', 'Coulomb-lambda', 'VDW-lambda'] + assert dHdl.index.names == ["time", "Coulomb-lambda", "VDW-lambda"] assert dHdl.shape == (1000, 2) -def test_u_nk(): - """Test that u_nk has the correct form when extracted from files. - """ +def test_u_nk(): + """Test that u_nk has the correct form when extracted from files.""" dataset = load_benzene() - for filename in dataset['data']: + for filename in dataset["data"]: u_nk = extract_u_nk(filename, T=298) - assert u_nk.index.names == ['time', 'Coulomb-lambda', 'VDW-lambda'] + assert u_nk.index.names == ["time", "Coulomb-lambda", "VDW-lambda"] assert u_nk.shape == (1000, 23) -def test_extract(): - """Test that u_nk and dHdl have the correct form when extracted from files. - """ +def test_extract(): + """Test that u_nk and dHdl have the correct form when extracted from files.""" dataset = load_benzene() - df_dict = extract(dataset['data'][0], T=298) + df_dict = extract(dataset["data"][0], T=298) - assert df_dict['u_nk'].index.names == ['time', 'Coulomb-lambda', 'VDW-lambda'] - assert df_dict['u_nk'].shape == (1000, 23) - assert df_dict['dHdl'].index.names == ['time', 'Coulomb-lambda', 'VDW-lambda'] - assert df_dict['dHdl'].shape == (1000, 2) + assert df_dict["u_nk"].index.names == ["time", "Coulomb-lambda", "VDW-lambda"] + assert df_dict["u_nk"].shape == (1000, 23) + assert df_dict["dHdl"].index.names == ["time", "Coulomb-lambda", "VDW-lambda"] + assert df_dict["dHdl"].shape == (1000, 2) diff --git a/src/alchemlyb/tests/parsing/test_namd.py b/src/alchemlyb/tests/parsing/test_namd.py index 8c4c1858..f5168e8f 100644 --- a/src/alchemlyb/tests/parsing/test_namd.py +++ b/src/alchemlyb/tests/parsing/test_namd.py @@ -1,16 +1,17 @@ """NAMD parser tests. """ +import bz2 from os.path import basename from re import search -import bz2 -import pytest -from alchemlyb.parsing.namd import extract_u_nk, extract -from alchemtest.namd import load_tyr2ala +import pytest from alchemtest.namd import load_idws from alchemtest.namd import load_restarted from alchemtest.namd import load_restarted_reversed +from alchemtest.namd import load_tyr2ala + +from alchemlyb.parsing.namd import extract_u_nk, extract # Indices of lambda values in the following line in NAMD fepout files: # #NEW FEP WINDOW: LAMBDA SET TO 0.6 LAMBDA2 0.7 LAMBDA_IDWS 0.5 @@ -27,27 +28,30 @@ def dataset(): return load_tyr2ala() -@pytest.mark.parametrize("direction,shape", - [('forward', (21021, 21)), - ('backward', (21021, 21)), - ]) + +@pytest.mark.parametrize( + "direction,shape", + [ + ("forward", (21021, 21)), + ("backward", (21021, 21)), + ], +) def test_u_nk(dataset, direction, shape): - """Test that u_nk has the correct form when extracted from files. - """ - for filename in dataset['data'][direction]: + """Test that u_nk has the correct form when extracted from files.""" + for filename in dataset["data"][direction]: u_nk = extract_u_nk(filename, T=300) - assert u_nk.index.names == ['time', 'fep-lambda'] + assert u_nk.index.names == ["time", "fep-lambda"] assert u_nk.shape == shape + def test_u_nk_idws(): - """Test that u_nk has the correct form when extracted from files. - """ + """Test that u_nk has the correct form when extracted from files.""" - filenames = load_idws()['data']['forward'] + filenames = load_idws()["data"]["forward"] u_nk = extract_u_nk(filenames, T=300) - assert u_nk.index.names == ['time', 'fep-lambda'] + assert u_nk.index.names == ["time", "fep-lambda"] assert u_nk.shape == (29252, 11) @@ -64,7 +68,7 @@ def _corrupt_fepout(fepout_in, params, tmp_path): ---------- fepout_in: str Path to fepout file to be modified. This file will not be overwritten. - + params: list of tuples For each tuple, the first element must be a str that will be passed to startswith() to identify the line(s) to modify (e.g. "#NEW"). The @@ -82,13 +86,17 @@ def _corrupt_fepout(fepout_in, params, tmp_path): """ fepout_out = tmp_path / basename(fepout_in) - with bz2.open(fepout_out, 'wt') as f_out: - with bz2.open(fepout_in, 'rt') as f_in: + with bz2.open(fepout_out, "wt") as f_out: + with bz2.open(fepout_in, "rt") as f_in: for line in f_in: for prefix, func in params: if line.startswith(prefix): tokens_out = func(line.split()) - line = ' '.join(tokens_out) + '\n' if tokens_out is not None else None + line = ( + " ".join(tokens_out) + "\n" + if tokens_out is not None + else None + ) if line is not None: f_out.write(line) return str(fepout_out) @@ -99,9 +107,10 @@ def restarted_dataset_inconsistent(restarted_dataset, tmp_path): """Returns intentionally messed up dataset where lambda1 and lambda2 at start and end of a window are different.""" - filenames = sorted(restarted_dataset['data']['both']) + filenames = sorted(restarted_dataset["data"]["both"]) changed = False + def func_free_line(l): nonlocal changed if float(l[7]) >= 0.7 and float(l[7]) < 0.9: @@ -110,13 +119,15 @@ def func_free_line(l): return l for i in range(len(filenames)): - filenames[i] = _corrupt_fepout(filenames[i], [('#Free', func_free_line)], tmp_path) + filenames[i] = _corrupt_fepout( + filenames[i], [("#Free", func_free_line)], tmp_path + ) # Only actually modify one window so we don't trigger the wrong exception if changed is True: break # Don't directly modify the glob object - restarted_dataset['data']['both'] = filenames + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -129,26 +140,32 @@ def restarted_dataset_idws_without_lambda_idws(restarted_dataset, tmp_path): # First window won't have any IDWS data so we just drop all its files and fudge the lambdas # in the next window to include 0.0 or 1.0 (as appropriate) so we still have a nominally complete calculation - - filenames = [x for x in sorted(restarted_dataset['data']['both']) if search('000[a-z]?.fepout', x) is None] + + filenames = [ + x + for x in sorted(restarted_dataset["data"]["both"]) + if search("000[a-z]?.fepout", x) is None + ] def func_new_line(l): - if float(l[LAMBDA1_IDX_NEW]) > 0.5: # 1->0 (reversed) calculation - l[LAMBDA1_IDX_NEW] == '1.0' - else: # regular 0->1 calculation - l[LAMBDA1_IDX_NEW] = '0.0' + if float(l[LAMBDA1_IDX_NEW]) > 0.5: # 1->0 (reversed) calculation + l[LAMBDA1_IDX_NEW] == "1.0" + else: # regular 0->1 calculation + l[LAMBDA1_IDX_NEW] = "0.0" # Drop the lambda_idws return l[:9] - + def func_free_line(l): - if float(l[LAMBDA1_IDX_FREE]) > 0.5: # 1->0 (reversed) calculation - l[LAMBDA1_IDX_FREE] == '1.0' - else: # regular 0->1 calculation - l[LAMBDA1_IDX_FREE] = '0.0' + if float(l[LAMBDA1_IDX_FREE]) > 0.5: # 1->0 (reversed) calculation + l[LAMBDA1_IDX_FREE] == "1.0" + else: # regular 0->1 calculation + l[LAMBDA1_IDX_FREE] = "0.0" return l - - filenames[0] = _corrupt_fepout(filenames[0], [('#NEW', func_new_line), ('#Free', func_free_line)], tmp_path) - restarted_dataset['data']['both'] = filenames + + filenames[0] = _corrupt_fepout( + filenames[0], [("#NEW", func_new_line), ("#Free", func_free_line)], tmp_path + ) + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -157,7 +174,7 @@ def restarted_dataset_toomany_lambda2(restarted_dataset, tmp_path): """Returns intentionally messed up dataset, where there are too many lambda2 values for a given lambda1.""" - filenames = sorted(restarted_dataset['data']['both']) + filenames = sorted(restarted_dataset["data"]["both"]) # For the same l1 and lidws we retain old lambda2 values thus ensuring a collision # Also, don't make a window where lambda1 >= lambda2 because this will trigger the @@ -165,22 +182,23 @@ def restarted_dataset_toomany_lambda2(restarted_dataset, tmp_path): def func_new_line(l): if float(l[LAMBDA2_IDX_NEW]) <= 0.2: return l - l[LAMBDA1_IDX_NEW] = '0.2' - if len(l) > 9 and l[9] == 'LAMBDA_IDWS': - l[LAMBDA_IDWS_IDX_NEW] = '0.1' + l[LAMBDA1_IDX_NEW] = "0.2" + if len(l) > 9 and l[9] == "LAMBDA_IDWS": + l[LAMBDA_IDWS_IDX_NEW] = "0.1" return l def func_free_line(l): if float(l[LAMBDA2_IDX_FREE]) <= 0.2: return l - l[LAMBDA1_IDX_FREE] = '0.2' + l[LAMBDA1_IDX_FREE] = "0.2" return l for i in range(len(filenames)): - filenames[i] = \ - _corrupt_fepout(filenames[i], [('#NEW', func_new_line), ('#Free', func_free_line)], tmp_path) + filenames[i] = _corrupt_fepout( + filenames[i], [("#NEW", func_new_line), ("#Free", func_free_line)], tmp_path + ) - restarted_dataset['data']['both'] = filenames + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -189,7 +207,7 @@ def restarted_dataset_toomany_lambda_idws(restarted_dataset, tmp_path): """Returns intentionally messed up dataset, where there are too many lambda2 values for a given lambda1.""" - filenames = sorted(restarted_dataset['data']['both']) + filenames = sorted(restarted_dataset["data"]["both"]) # For the same lambda1 and lambda2 we retain the first set of lambda1/lambda2 values # and replicate them across all windows thus ensuring that there will be more than @@ -198,7 +216,7 @@ def restarted_dataset_toomany_lambda_idws(restarted_dataset, tmp_path): def func_new_line(l): nonlocal this_lambda1, this_lambda2 - + if this_lambda1 is None: this_lambda1, this_lambda2 = l[LAMBDA1_IDX_NEW], l[LAMBDA2_IDX_NEW] # Ensure that changing these lambda values won't cause a reversal in direction and trigger @@ -212,9 +230,11 @@ def func_free_line(l): return l for i in range(len(filenames)): - filenames[i] = _corrupt_fepout(filenames[i], [('#NEW', func_new_line)], tmp_path) + filenames[i] = _corrupt_fepout( + filenames[i], [("#NEW", func_new_line)], tmp_path + ) - restarted_dataset['data']['both'] = filenames + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -222,7 +242,7 @@ def func_free_line(l): def restarted_dataset_direction_changed(restarted_dataset, tmp_path): """Returns intentionally messed up dataset, with one window where the lambda values are reversed.""" - filenames = sorted(restarted_dataset['data']['both']) + filenames = sorted(restarted_dataset["data"]["both"]) def func_new_line(l): l[6], l[8], l[10] = l[10], l[8], l[6] @@ -231,12 +251,16 @@ def func_new_line(l): def func_free_line(l): l[7], l[8] = l[8], l[7] return l - + # Reverse the direction of lambdas for this window idx_to_corrupt = filenames.index(sorted(filenames)[-3]) - fname1 = _corrupt_fepout(filenames[idx_to_corrupt], [('#NEW', func_new_line), ('#Free', func_free_line)], tmp_path) + fname1 = _corrupt_fepout( + filenames[idx_to_corrupt], + [("#NEW", func_new_line), ("#Free", func_free_line)], + tmp_path, + ) filenames[idx_to_corrupt] = fname1 - restarted_dataset['data']['both'] = filenames + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -244,15 +268,17 @@ def func_free_line(l): def restarted_dataset_all_windows_truncated(restarted_dataset, tmp_path): """Returns dataset where all windows are truncated (no #Free... footer lines).""" - filenames = sorted(restarted_dataset['data']['both']) + filenames = sorted(restarted_dataset["data"]["both"]) def func_free_line(l): return None for i in range(len(filenames)): - filenames[i] = _corrupt_fepout(filenames[i], [('#Free', func_free_line)], tmp_path) - - restarted_dataset['data']['both'] = filenames + filenames[i] = _corrupt_fepout( + filenames[i], [("#Free", func_free_line)], tmp_path + ) + + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -260,13 +286,15 @@ def func_free_line(l): def restarted_dataset_last_window_truncated(restarted_dataset, tmp_path): """Returns dataset where the last window is truncated (no #Free... footer line).""" - filenames = sorted(restarted_dataset['data']['both']) + filenames = sorted(restarted_dataset["data"]["both"]) def func_free_line(l): return None - filenames[-1] = _corrupt_fepout(filenames[-1], [('#Free', func_free_line)], tmp_path) - restarted_dataset['data']['both'] = filenames + filenames[-1] = _corrupt_fepout( + filenames[-1], [("#Free", func_free_line)], tmp_path + ) + restarted_dataset["data"]["both"] = filenames return restarted_dataset @@ -274,72 +302,91 @@ def test_u_nk_restarted(): """Test that u_nk has the correct form when extracted from an IDWS FEP run that includes terminations and restarts. """ - filenames = load_restarted()['data']['both'] + filenames = load_restarted()["data"]["both"] u_nk = extract_u_nk(filenames, T=300) - - assert u_nk.index.names == ['time', 'fep-lambda'] + + assert u_nk.index.names == ["time", "fep-lambda"] assert u_nk.shape == (30061, 11) def test_u_nk_restarted_missing_window_header(tmp_path): """Test that u_nk has the correct form when a #NEW line is missing from the restarted dataset and the parser has to infer lambda_idws for that window.""" - filenames = sorted(load_restarted()['data']['both']) + filenames = sorted(load_restarted()["data"]["both"]) # Remove "#NEW" line - filenames[4] = _corrupt_fepout(filenames[4], [('#NEW', lambda l: None),], tmp_path) + filenames[4] = _corrupt_fepout( + filenames[4], + [ + ("#NEW", lambda l: None), + ], + tmp_path, + ) u_nk = extract_u_nk(filenames, T=300) - - assert u_nk.index.names == ['time', 'fep-lambda'] + + assert u_nk.index.names == ["time", "fep-lambda"] assert u_nk.shape == (30061, 11) def test_u_nk_restarted_reversed(): - filenames = load_restarted_reversed()['data']['both'] + filenames = load_restarted_reversed()["data"]["both"] u_nk = extract_u_nk(filenames, T=300) - - assert u_nk.index.names == ['time', 'fep-lambda'] + + assert u_nk.index.names == ["time", "fep-lambda"] assert u_nk.shape == (30170, 11) def test_extract(): - filenames = load_restarted_reversed()['data']['both'] + filenames = load_restarted_reversed()["data"]["both"] df_dict = extract(filenames, T=300) - assert df_dict['u_nk'].index.names == ['time', 'fep-lambda'] - assert df_dict['u_nk'].shape == (30170, 11) - assert 'dHdl' not in df_dict + assert df_dict["u_nk"].index.names == ["time", "fep-lambda"] + assert df_dict["u_nk"].shape == (30170, 11) + assert "dHdl" not in df_dict def test_u_nk_restarted_reversed_missing_window_header(tmp_path): """Test that u_nk has the correct form when a #NEW line is missing from the restarted_reversed dataset and the parser has to infer lambda_idws for that window.""" - filenames = sorted(load_restarted_reversed()['data']['both']) + filenames = sorted(load_restarted_reversed()["data"]["both"]) # Remove "#NEW" line - filenames[4] = _corrupt_fepout(filenames[4], [('#NEW', lambda l: None),], tmp_path) + filenames[4] = _corrupt_fepout( + filenames[4], + [ + ("#NEW", lambda l: None), + ], + tmp_path, + ) u_nk = extract_u_nk(filenames, T=300) - - assert u_nk.index.names == ['time', 'fep-lambda'] + + assert u_nk.index.names == ["time", "fep-lambda"] assert u_nk.shape == (30170, 11) def test_u_nk_restarted_direction_changed(restarted_dataset_direction_changed): """Test that when lambda values change direction within a dataset, parsing throws an error.""" - with pytest.raises(ValueError, match='Lambda values change direction'): - u_nk = extract_u_nk(restarted_dataset_direction_changed['data']['both'], T=300) + with pytest.raises(ValueError, match="Lambda values change direction"): + u_nk = extract_u_nk(restarted_dataset_direction_changed["data"]["both"], T=300) -def test_u_nk_restarted_idws_without_lambda_idws(restarted_dataset_idws_without_lambda_idws): +def test_u_nk_restarted_idws_without_lambda_idws( + restarted_dataset_idws_without_lambda_idws, +): """Test that when the first window has IDWS data but no lambda_idws, parsing throws an error. - + In this situation, the lambda_idws cannot be inferred, because there's no previous lambda value available. """ - with pytest.raises(ValueError, match='IDWS data present in first window but lambda_idws not included'): - u_nk = extract_u_nk(restarted_dataset_idws_without_lambda_idws['data']['both'], T=300) + with pytest.raises( + ValueError, + match="IDWS data present in first window but lambda_idws not included", + ): + u_nk = extract_u_nk( + restarted_dataset_idws_without_lambda_idws["data"]["both"], T=300 + ) def test_u_nk_restarted_inconsistent(restarted_dataset_inconsistent): @@ -347,33 +394,45 @@ def test_u_nk_restarted_inconsistent(restarted_dataset_inconsistent): parsing throws an error. """ - with pytest.raises(ValueError, match='Inconsistent lambda values within the same window'): - u_nk = extract_u_nk(restarted_dataset_inconsistent['data']['both'], T=300) + with pytest.raises( + ValueError, match="Inconsistent lambda values within the same window" + ): + u_nk = extract_u_nk(restarted_dataset_inconsistent["data"]["both"], T=300) def test_u_nk_restarted_toomany_lambda_idws(restarted_dataset_toomany_lambda_idws): """Test that when there is more than one lambda_idws for a given lambda1, parsing throws an error.""" - with pytest.raises(ValueError, match='More than one lambda_idws value for a particular lambda1'): - u_nk = extract_u_nk(restarted_dataset_toomany_lambda_idws['data']['both'], T=300) + with pytest.raises( + ValueError, match="More than one lambda_idws value for a particular lambda1" + ): + u_nk = extract_u_nk( + restarted_dataset_toomany_lambda_idws["data"]["both"], T=300 + ) def test_u_nk_restarted_toomany_lambda2(restarted_dataset_toomany_lambda2): """Test that when there is more than one lambda2 for a given lambda1, parsing throws an error.""" - with pytest.raises(ValueError, match='More than one lambda2 value for a particular lambda1'): - u_nk = extract_u_nk(restarted_dataset_toomany_lambda2['data']['both'], T=300) + with pytest.raises( + ValueError, match="More than one lambda2 value for a particular lambda1" + ): + u_nk = extract_u_nk(restarted_dataset_toomany_lambda2["data"]["both"], T=300) def test_u_nk_restarted_all_windows_truncated(restarted_dataset_all_windows_truncated): """Test that when there is more than one lambda2 for a given lambda1, parsing throws an error.""" - with pytest.raises(ValueError, match='New window begun after truncated window'): - u_nk = extract_u_nk(restarted_dataset_all_windows_truncated['data']['both'], T=300) + with pytest.raises(ValueError, match="New window begun after truncated window"): + u_nk = extract_u_nk( + restarted_dataset_all_windows_truncated["data"]["both"], T=300 + ) def test_u_nk_restarted_last_window_truncated(restarted_dataset_last_window_truncated): """Test that when there is more than one lambda2 for a given lambda1, parsing throws an error.""" - with pytest.raises(ValueError, match='Last window is truncated'): - u_nk = extract_u_nk(restarted_dataset_last_window_truncated['data']['both'], T=300) + with pytest.raises(ValueError, match="Last window is truncated"): + u_nk = extract_u_nk( + restarted_dataset_last_window_truncated["data"]["both"], T=300 + ) diff --git a/src/alchemlyb/tests/parsing/test_util.py b/src/alchemlyb/tests/parsing/test_util.py index 334ee0c2..85107d61 100644 --- a/src/alchemlyb/tests/parsing/test_util.py +++ b/src/alchemlyb/tests/parsing/test_util.py @@ -1,32 +1,29 @@ import io -import pytest +import pytest from alchemtest.gmx import load_expanded_ensemble_case_1 + from alchemlyb.parsing.util import anyopen def test_gzip(): - """Test that gzip reads .gz files in the correct (text) mode. - - """ + """Test that gzip reads .gz files in the correct (text) mode.""" dataset = load_expanded_ensemble_case_1() - for leg in dataset['data']: - for filename in dataset['data'][leg]: - with anyopen(filename, 'r') as f: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: + with anyopen(filename, "r") as f: assert type(f.readline()) is str def test_gzip_stream(): - """Test that `anyopen` reads streams with specified compression. - - """ + """Test that `anyopen` reads streams with specified compression.""" dataset = load_expanded_ensemble_case_1() - for leg in dataset['data']: - for filename in dataset['data'][leg]: - with open(filename, 'rb') as f: - with anyopen(f, mode='r', compression='gzip') as f_uc: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: + with open(filename, "rb") as f: + with anyopen(f, mode="r", compression="gzip") as f_uc: assert type(f_uc.readline()) is str @@ -37,11 +34,11 @@ def test_gzip_stream_wrong(): """ dataset = load_expanded_ensemble_case_1() - for leg in dataset['data']: - for filename in dataset['data'][leg]: - with open(filename, 'rb') as f: - with anyopen(f, mode='r', compression='bzip2') as f_uc: - with pytest.raises(OSError, match='Invalid data stream'): + for leg in dataset["data"]: + for filename in dataset["data"][leg]: + with open(filename, "rb") as f: + with anyopen(f, mode="r", compression="bzip2") as f_uc: + with pytest.raises(OSError, match="Invalid data stream"): assert type(f_uc.readline()) is str @@ -52,33 +49,30 @@ def test_gzip_stream_wrong_no_compression(): """ dataset = load_expanded_ensemble_case_1() - for leg in dataset['data']: - for filename in dataset['data'][leg]: - with open(filename, 'rb') as f: - with anyopen(f, mode='r') as f_uc: + for leg in dataset["data"]: + for filename in dataset["data"][leg]: + with open(filename, "rb") as f: + with anyopen(f, mode="r") as f_uc: assert type(f_uc.readline()) is bytes -@pytest.mark.parametrize('extension', ['bz2', 'gz']) +@pytest.mark.parametrize("extension", ["bz2", "gz"]) def test_file_roundtrip(extension, tmp_path): - """Test that roundtripping write/read to a file works with `anyopen`. - - """ + """Test that roundtripping write/read to a file works with `anyopen`.""" data = "my momma told me to pick the very best one and you are not it" - filepath = tmp_path / f'testfile.txt.{extension}' - with anyopen(filepath, mode='w') as f: + filepath = tmp_path / f"testfile.txt.{extension}" + with anyopen(filepath, mode="w") as f: f.write(data) - with anyopen(filepath, 'r') as f: + with anyopen(filepath, "r") as f: data_out = f.read() assert data_out == data -@pytest.mark.parametrize('extension,compression', - [('bz2', 'gzip'), ('gz', 'bzip2')]) +@pytest.mark.parametrize("extension,compression", [("bz2", "gzip"), ("gz", "bzip2")]) def test_file_roundtrip_force_compression(extension, compression, tmp_path): """Test that roundtripping write/read to a file works with `anyopen`, in which we force compression despite different extension. @@ -87,50 +81,45 @@ def test_file_roundtrip_force_compression(extension, compression, tmp_path): data = "my momma told me to pick the very best one and you are not it" - filepath = tmp_path / f'testfile.txt.{extension}' - with anyopen(filepath, mode='w', compression=compression) as f: + filepath = tmp_path / f"testfile.txt.{extension}" + with anyopen(filepath, mode="w", compression=compression) as f: f.write(data) - with anyopen(filepath, 'r', compression=compression) as f: + with anyopen(filepath, "r", compression=compression) as f: data_out = f.read() assert data_out == data -@pytest.mark.parametrize('compression', ['bzip2', 'gzip']) +@pytest.mark.parametrize("compression", ["bzip2", "gzip"]) def test_stream_roundtrip(compression): - """Test that roundtripping write/read to a stream works with `anyopen` - - """ + """Test that roundtripping write/read to a stream works with `anyopen`""" data = "my momma told me to pick the very best one and you are not it" with io.BytesIO() as stream: - # write to stream - with anyopen(stream, mode='w', compression=compression) as f: + with anyopen(stream, mode="w", compression=compression) as f: f.write(data) # start at the beginning stream.seek(0) # read from stream - with anyopen(stream, 'r', compression=compression) as f: + with anyopen(stream, "r", compression=compression) as f: data_out = f.read() assert data_out == data -def test_stream_unsupported_compression(): - """Test that we throw a ValueError when an unsupported compression is used. - """ +def test_stream_unsupported_compression(): + """Test that we throw a ValueError when an unsupported compression is used.""" - compression="fakez" + compression = "fakez" data = b"my momma told me to pick the very best one and you are not it" with io.BytesIO() as stream: - # write to stream stream.write(data) @@ -139,5 +128,5 @@ def test_stream_unsupported_compression(): # read from stream with pytest.raises(ValueError): - with anyopen(stream, 'r', compression=compression) as f: + with anyopen(stream, "r", compression=compression) as f: data_out = f.read() diff --git a/src/alchemlyb/tests/test_convergence.py b/src/alchemlyb/tests/test_convergence.py index d0ffb2c2..32fa1eb9 100644 --- a/src/alchemlyb/tests/test_convergence.py +++ b/src/alchemlyb/tests/test_convergence.py @@ -2,8 +2,7 @@ import pandas as pd import pytest -from alchemlyb.convergence import forward_backward_convergence, \ - fwdrev_cumavg_Rc, A_c +from alchemlyb.convergence import forward_backward_convergence, fwdrev_cumavg_Rc, A_c from alchemlyb.convergence.convergence import _cummean diff --git a/src/alchemlyb/tests/test_fep_estimators.py b/src/alchemlyb/tests/test_fep_estimators.py index 05a2dd24..9d041eb8 100644 --- a/src/alchemlyb/tests/test_fep_estimators.py +++ b/src/alchemlyb/tests/test_fep_estimators.py @@ -13,7 +13,6 @@ class FEPestimatorMixin: """Mixin for all FEP Estimator test classes.""" def compare_delta_f(self, X_delta_f): - est = self.cls().fit(X_delta_f[0]) delta_f, d_delta_f = self.get_delta_f(est) diff --git a/src/alchemlyb/tests/test_import.py b/src/alchemlyb/tests/test_import.py index 50a5d933..ef467ec8 100644 --- a/src/alchemlyb/tests/test_import.py +++ b/src/alchemlyb/tests/test_import.py @@ -1,4 +1,5 @@ import alchemlyb + def test_name(): - assert alchemlyb.__name__ == 'alchemlyb' + assert alchemlyb.__name__ == "alchemlyb" diff --git a/src/alchemlyb/tests/test_preprocessing.py b/src/alchemlyb/tests/test_preprocessing.py index 5e2cc611..10085828 100644 --- a/src/alchemlyb/tests/test_preprocessing.py +++ b/src/alchemlyb/tests/test_preprocessing.py @@ -69,9 +69,9 @@ def test_unchanged(self, namd_idws): # NAMD energy files only have dE for adjacent lambdas, this ensures # that the slicer will not drop these rows as they have NaN values. # Do the pre-processing as the u_nk are from all lambdas - groups = namd_idws.groupby('fep-lambda') + groups = namd_idws.groupby("fep-lambda") for key, group in groups: - group = group[~group.index.duplicated(keep='first')] + group = group[~group.index.duplicated(keep="first")] df = self.slicer(group, None, None, None) assert len(df) == len(group) @@ -252,7 +252,6 @@ def test_conservative(self, dataloader, size, conservative, request): ], ) def test_raise_ValueError_for_mismatched_data(self, dataloader, end, step, request): - data = request.getfixturevalue(dataloader) with pytest.raises(ValueError): self.slicer(data, series=data[:end:step]) diff --git a/src/alchemlyb/tests/test_version.py b/src/alchemlyb/tests/test_version.py index 4f2afc78..ddab2ab6 100644 --- a/src/alchemlyb/tests/test_version.py +++ b/src/alchemlyb/tests/test_version.py @@ -1,5 +1,6 @@ import alchemlyb + def test_version(): try: version = alchemlyb.__version__ @@ -8,9 +9,10 @@ def test_version(): assert len(version) > 0 + def test_version_get_versions(): import alchemlyb._version + version = alchemlyb._version.get_versions() assert alchemlyb.__version__ == version["version"] - diff --git a/src/alchemlyb/tests/test_workflow.py b/src/alchemlyb/tests/test_workflow.py index a4308145..cc31e611 100644 --- a/src/alchemlyb/tests/test_workflow.py +++ b/src/alchemlyb/tests/test_workflow.py @@ -1,11 +1,14 @@ +import os + +import pandas as pd import pytest + from alchemlyb.workflows import base -import pandas as pd -import os -class Test_automatic_base(): + +class Test_automatic_base: @staticmethod - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") workflow = base.WorkflowBase(out=str(outdir)) @@ -13,9 +16,9 @@ def workflow(tmp_path_factory): return workflow def test_write(self, workflow): - '''Patch the output directory to tmpdir''' - workflow.result.to_pickle(os.path.join(workflow.out, 'result.pkl')) - assert os.path.exists(os.path.join(workflow.out, 'result.pkl')) + """Patch the output directory to tmpdir""" + workflow.result.to_pickle(os.path.join(workflow.out, "result.pkl")) + assert os.path.exists(os.path.join(workflow.out, "result.pkl")) def test_read(self, workflow): assert len(workflow.u_nk_list) == 0 diff --git a/src/alchemlyb/tests/test_workflow_ABFE.py b/src/alchemlyb/tests/test_workflow_ABFE.py index f282041b..e1d6a4f2 100644 --- a/src/alchemlyb/tests/test_workflow_ABFE.py +++ b/src/alchemlyb/tests/test_workflow_ABFE.py @@ -1,204 +1,235 @@ +import os + import numpy as np import pytest -import os +from alchemtest.amber import load_bace_example +from alchemtest.gmx import load_ABFE, load_benzene from alchemlyb.workflows.abfe import ABFE -from alchemtest.gmx import load_ABFE, load_benzene -from alchemtest.amber import load_bace_example -class Test_automatic_ABFE(): - '''Test the full automatic workflow for load_ABFE from alchemtest.gmx for - three stage transformation.''' + +class Test_automatic_ABFE: + """Test the full automatic workflow for load_ABFE from alchemtest.gmx for + three stage transformation.""" @staticmethod - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") - dir = os.path.dirname(load_ABFE()['data']['complex'][0]) - workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir, - prefix='dhdl', suffix='xvg', T=310, outdirectory=str(outdir)) - workflow.run(skiptime=10, uncorr='dE', threshold=50, - estimators=('MBAR', 'BAR', 'TI'), overlap='O_MBAR.pdf', - breakdown=True, forwrev=10) + dir = os.path.dirname(load_ABFE()["data"]["complex"][0]) + workflow = ABFE( + units="kcal/mol", + software="GROMACS", + dir=dir, + prefix="dhdl", + suffix="xvg", + T=310, + outdirectory=str(outdir), + ) + workflow.run( + skiptime=10, + uncorr="dE", + threshold=50, + estimators=("MBAR", "BAR", "TI"), + overlap="O_MBAR.pdf", + breakdown=True, + forwrev=10, + ) return workflow def test_read(self, workflow): - '''test if the files has been loaded correctly.''' + """test if the files has been loaded correctly.""" assert len(workflow.u_nk_list) == 30 assert len(workflow.dHdl_list) == 30 assert all([len(u_nk) == 1001 for u_nk in workflow.u_nk_list]) assert all([len(dHdl) == 1001 for dHdl in workflow.dHdl_list]) def test_subsample(self, workflow): - '''Test if the data has been shrinked by subsampling.''' + """Test if the data has been shrinked by subsampling.""" assert len(workflow.u_nk_sample_list) == 30 assert len(workflow.dHdl_sample_list) == 30 assert all([len(u_nk) < 1001 for u_nk in workflow.u_nk_sample_list]) assert all([len(dHdl) < 1001 for dHdl in workflow.dHdl_sample_list]) def test_estimator(self, workflow): - '''Test if all three estimators have been used.''' + """Test if all three estimators have been used.""" assert len(workflow.estimator) == 3 - assert 'MBAR' in workflow.estimator - assert 'TI' in workflow.estimator - assert 'BAR' in workflow.estimator + assert "MBAR" in workflow.estimator + assert "TI" in workflow.estimator + assert "BAR" in workflow.estimator def test_summary(self, workflow): - '''Test if if the summary is right.''' + """Test if if the summary is right.""" summary = workflow.generate_result() - assert np.isclose(summary['MBAR']['Stages']['TOTAL'], 21.8, 0.1) - assert np.isclose(summary['TI']['Stages']['TOTAL'], 21.8, 0.1) - assert np.isclose(summary['BAR']['Stages']['TOTAL'], 21.8, 0.1) + assert np.isclose(summary["MBAR"]["Stages"]["TOTAL"], 21.8, 0.1) + assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 21.8, 0.1) + assert np.isclose(summary["BAR"]["Stages"]["TOTAL"], 21.8, 0.1) def test_plot_O_MBAR(self, workflow): - '''test if the O_MBAR.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'O_MBAR.pdf')) + """test if the O_MBAR.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "O_MBAR.pdf")) def test_plot_dhdl_TI(self, workflow): - '''test if the dhdl_TI.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dhdl_TI.pdf')) + """test if the dhdl_TI.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dhdl_TI.pdf")) def test_plot_dF_state(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_state.pdf')) - assert os.path.isfile(os.path.join(workflow.out, 'dF_state_long.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_state.pdf")) + assert os.path.isfile(os.path.join(workflow.out, "dF_state_long.pdf")) def test_check_convergence(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_t.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_t.pdf")) assert len(workflow.convergence) == 10 def test_estimator_method(self, workflow, monkeypatch): - '''Test if the method keyword could be passed to the AutoMBAR estimator.''' - monkeypatch.setattr(workflow, 'estimator', - dict()) - workflow.estimate(estimators='MBAR', method='adaptive') - assert 'MBAR' in workflow.estimator + """Test if the method keyword could be passed to the AutoMBAR estimator.""" + monkeypatch.setattr(workflow, "estimator", dict()) + workflow.estimate(estimators="MBAR", method="adaptive") + assert "MBAR" in workflow.estimator def test_convergence_method(self, workflow, monkeypatch): - '''Test if the method keyword could be passed to the AutoMBAR estimator from convergence.''' - monkeypatch.setattr(workflow, 'convergence', None) - workflow.check_convergence(2, estimator='MBAR', method='adaptive') + """Test if the method keyword could be passed to the AutoMBAR estimator from convergence.""" + monkeypatch.setattr(workflow, "convergence", None) + workflow.check_convergence(2, estimator="MBAR", method="adaptive") assert len(workflow.convergence) == 2 + class Test_manual_ABFE(Test_automatic_ABFE): - '''Test the manual workflow for load_ABFE from alchemtest.gmx for three - stage transformation.''' + """Test the manual workflow for load_ABFE from alchemtest.gmx for three + stage transformation.""" @staticmethod - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") - dir = os.path.dirname(load_ABFE()['data']['complex'][0]) - workflow = ABFE(software='GROMACS', dir=dir, prefix='dhdl', - suffix='xvg', T=310, outdirectory=str(outdir)) - workflow.update_units('kcal/mol') + dir = os.path.dirname(load_ABFE()["data"]["complex"][0]) + workflow = ABFE( + software="GROMACS", + dir=dir, + prefix="dhdl", + suffix="xvg", + T=310, + outdirectory=str(outdir), + ) + workflow.update_units("kcal/mol") workflow.read() - workflow.preprocess(skiptime=10, uncorr='dE', threshold=50) - workflow.estimate(estimators=('MBAR', 'BAR', 'TI')) - workflow.plot_overlap_matrix(overlap='O_MBAR.pdf') - workflow.plot_ti_dhdl(dhdl_TI='dhdl_TI.pdf') - workflow.plot_dF_state(dF_state='dF_state.pdf') - workflow.check_convergence(10, dF_t='dF_t.pdf') + workflow.preprocess(skiptime=10, uncorr="dE", threshold=50) + workflow.estimate(estimators=("MBAR", "BAR", "TI")) + workflow.plot_overlap_matrix(overlap="O_MBAR.pdf") + workflow.plot_ti_dhdl(dhdl_TI="dhdl_TI.pdf") + workflow.plot_dF_state(dF_state="dF_state.pdf") + workflow.check_convergence(10, dF_t="dF_t.pdf") return workflow def test_plot_dF_state(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_state.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_state.pdf")) def test_convergence_nosample_u_nk(self, workflow, monkeypatch): - '''test if the convergence routine would use the unsampled data - when the data has not been subsampled.''' - monkeypatch.setattr(workflow, 'u_nk_sample_list', - None) + """test if the convergence routine would use the unsampled data + when the data has not been subsampled.""" + monkeypatch.setattr(workflow, "u_nk_sample_list", None) workflow.check_convergence(10) assert len(workflow.convergence) == 10 def test_dhdl_TI_noTI(self, workflow, monkeypatch): - '''Test to plot the dhdl_TI when ti estimator is not there''' + """Test to plot the dhdl_TI when ti estimator is not there""" no_TI = workflow.estimator - no_TI.pop('TI') - monkeypatch.setattr(workflow, 'estimator', - no_TI) + no_TI.pop("TI") + monkeypatch.setattr(workflow, "estimator", no_TI) with pytest.raises(ValueError): - workflow.plot_ti_dhdl(dhdl_TI='dhdl_TI.pdf') + workflow.plot_ti_dhdl(dhdl_TI="dhdl_TI.pdf") def test_noMBAR_for_plot_overlap_matrix(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'estimator', {}) + monkeypatch.setattr(workflow, "estimator", {}) assert workflow.plot_overlap_matrix() is None def test_no_u_nk_for_check_convergence(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'u_nk_list', None) - monkeypatch.setattr(workflow, 'u_nk_sample_list', None) + monkeypatch.setattr(workflow, "u_nk_list", None) + monkeypatch.setattr(workflow, "u_nk_sample_list", None) with pytest.raises(ValueError): - workflow.check_convergence(10, estimator='MBAR') + workflow.check_convergence(10, estimator="MBAR") def test_no_dHdl_for_check_convergence(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'dHdl_list', None) - monkeypatch.setattr(workflow, 'dHdl_sample_list', None) + monkeypatch.setattr(workflow, "dHdl_list", None) + monkeypatch.setattr(workflow, "dHdl_sample_list", None) with pytest.raises(ValueError): - workflow.check_convergence(10, estimator='TI') + workflow.check_convergence(10, estimator="TI") def test_no_update_units(self, workflow): assert workflow.update_units() is None def test_no_name_estimate(self, workflow): with pytest.raises(ValueError): - workflow.estimate('aaa') + workflow.estimate("aaa") -class Test_automatic_benzene(): - '''Test the full automatic workflow for load_benzene from alchemtest.gmx for - single stage transformation.''' +class Test_automatic_benzene: + """Test the full automatic workflow for load_benzene from alchemtest.gmx for + single stage transformation.""" @staticmethod - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") - dir = os.path.dirname(os.path.dirname( - load_benzene()['data']['Coulomb'][0])) - dir = os.path.join(dir, '*') - workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir, - prefix='dhdl', suffix='bz2', T=310, - outdirectory=outdir) - workflow.run(skiptime=0, uncorr='dE', threshold=50, - estimators=('MBAR', 'BAR', 'TI'), overlap='O_MBAR.pdf', - breakdown=True, forwrev=10) + dir = os.path.dirname(os.path.dirname(load_benzene()["data"]["Coulomb"][0])) + dir = os.path.join(dir, "*") + workflow = ABFE( + units="kcal/mol", + software="GROMACS", + dir=dir, + prefix="dhdl", + suffix="bz2", + T=310, + outdirectory=outdir, + ) + workflow.run( + skiptime=0, + uncorr="dE", + threshold=50, + estimators=("MBAR", "BAR", "TI"), + overlap="O_MBAR.pdf", + breakdown=True, + forwrev=10, + ) return workflow def test_read(self, workflow): - '''test if the files has been loaded correctly.''' + """test if the files has been loaded correctly.""" assert len(workflow.u_nk_list) == 5 assert len(workflow.dHdl_list) == 5 assert all([len(u_nk) == 4001 for u_nk in workflow.u_nk_list]) assert all([len(dHdl) == 4001 for dHdl in workflow.dHdl_list]) def test_estimator(self, workflow): - '''Test if all three estimators have been used.''' + """Test if all three estimators have been used.""" assert len(workflow.estimator) == 3 - assert 'MBAR' in workflow.estimator - assert 'TI' in workflow.estimator - assert 'BAR' in workflow.estimator + assert "MBAR" in workflow.estimator + assert "TI" in workflow.estimator + assert "BAR" in workflow.estimator def test_O_MBAR(self, workflow): - '''test if the O_MBAR.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'O_MBAR.pdf')) + """test if the O_MBAR.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "O_MBAR.pdf")) def test_dhdl_TI(self, workflow): - '''test if the dhdl_TI.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dhdl_TI.pdf')) + """test if the dhdl_TI.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dhdl_TI.pdf")) def test_dF_state(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_state.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_state.pdf")) def test_convergence(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_t.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_t.pdf")) assert len(workflow.convergence) == 10 -class Test_unpertubed_lambda(): - '''Test the if two lamdas present and one of them is not pertubed. + +class Test_unpertubed_lambda: + """Test the if two lamdas present and one of them is not pertubed. fep bound time fep-lambda bound-lambda @@ -209,87 +240,118 @@ class Test_unpertubed_lambda(): 40.0 0.5 0 7.768072 0 Where only fep-lambda changes but the bonded-lambda is always 0. - ''' + """ @staticmethod - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") - dir = os.path.dirname(os.path.dirname( - load_benzene()['data']['Coulomb'][0])) - dir = os.path.join(dir, '*') - workflow = ABFE(software='GROMACS', dir=dir, prefix='dhdl', - suffix='bz2', T=310, outdirectory=outdir) + dir = os.path.dirname(os.path.dirname(load_benzene()["data"]["Coulomb"][0])) + dir = os.path.join(dir, "*") + workflow = ABFE( + software="GROMACS", + dir=dir, + prefix="dhdl", + suffix="bz2", + T=310, + outdirectory=outdir, + ) workflow.read() # Block the n_uk workflow.u_nk_list = [] # Add another lambda column for dHdl in workflow.dHdl_list: - dHdl.insert(1, 'bound-lambda', [1.0, ] * len(dHdl)) - dHdl.insert(1, 'bound', [1.0, ] * len(dHdl)) - dHdl.set_index('bound-lambda', append=True, inplace=True) - - workflow.estimate(estimators=('TI', )) - workflow.plot_ti_dhdl(dhdl_TI='dhdl_TI.pdf') - workflow.plot_dF_state(dF_state='dF_state.pdf') - workflow.check_convergence(10, dF_t='dF_t.pdf', estimator='TI') + dHdl.insert( + 1, + "bound-lambda", + [ + 1.0, + ] + * len(dHdl), + ) + dHdl.insert( + 1, + "bound", + [ + 1.0, + ] + * len(dHdl), + ) + dHdl.set_index("bound-lambda", append=True, inplace=True) + + workflow.estimate(estimators=("TI",)) + workflow.plot_ti_dhdl(dhdl_TI="dhdl_TI.pdf") + workflow.plot_dF_state(dF_state="dF_state.pdf") + workflow.check_convergence(10, dF_t="dF_t.pdf", estimator="TI") return workflow def test_dhdl_TI(self, workflow): - '''test if the dhdl_TI.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dhdl_TI.pdf')) + """test if the dhdl_TI.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dhdl_TI.pdf")) def test_dF_state(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_state.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_state.pdf")) def test_convergence(self, workflow): - '''test if the dF_state.pdf has been plotted.''' - assert os.path.isfile(os.path.join(workflow.out, 'dF_t.pdf')) + """test if the dF_state.pdf has been plotted.""" + assert os.path.isfile(os.path.join(workflow.out, "dF_t.pdf")) assert len(workflow.convergence) == 10 def test_single_estimator_ti(self, workflow): - workflow.estimate(estimators='TI') + workflow.estimate(estimators="TI") summary = workflow.generate_result() - assert np.isclose(summary['TI']['Stages']['TOTAL'], 2.946, 0.1) + assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 2.946, 0.1) + -class Test_methods(): - '''Test various methods.''' +class Test_methods: + """Test various methods.""" @staticmethod - @pytest.fixture(scope='class') + @pytest.fixture(scope="class") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") - dir = os.path.dirname(os.path.dirname( - load_benzene()['data']['Coulomb'][0])) - dir = os.path.join(dir, '*') - workflow = ABFE(software='GROMACS', dir=dir, prefix='dhdl', - suffix='bz2', T=310, outdirectory=outdir) + dir = os.path.dirname(os.path.dirname(load_benzene()["data"]["Coulomb"][0])) + dir = os.path.join(dir, "*") + workflow = ABFE( + software="GROMACS", + dir=dir, + prefix="dhdl", + suffix="bz2", + T=310, + outdirectory=outdir, + ) workflow.read() return workflow def test_run_none(self, workflow): - '''Don't run anything''' - workflow.run(uncorr=None, estimators=None, overlap=None, breakdown=None, - forwrev=None) + """Don't run anything""" + workflow.run( + uncorr=None, estimators=None, overlap=None, breakdown=None, forwrev=None + ) def test_run_single_estimator(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'u_nk_list', []) - monkeypatch.setattr(workflow, 'dHdl_list', []) - workflow.run(uncorr=None, estimators='MBAR', overlap=None, breakdown=True, - forwrev=None) + monkeypatch.setattr(workflow, "u_nk_list", []) + monkeypatch.setattr(workflow, "dHdl_list", []) + workflow.run( + uncorr=None, estimators="MBAR", overlap=None, breakdown=True, forwrev=None + ) def test_run_invalid_estimator(self, workflow): - with pytest.raises(ValueError, - match=r'Estimator aaa is not supported.'): - workflow.run(uncorr=None, estimators='aaa', overlap=None, breakdown=None, - forwrev=None) - - @pytest.mark.parametrize('read_u_nk', [True, False]) - @pytest.mark.parametrize('read_dHdl', [True, False]) + with pytest.raises(ValueError, match=r"Estimator aaa is not supported."): + workflow.run( + uncorr=None, + estimators="aaa", + overlap=None, + breakdown=None, + forwrev=None, + ) + + @pytest.mark.parametrize("read_u_nk", [True, False]) + @pytest.mark.parametrize("read_dHdl", [True, False]) def test_read_TI_FEP(self, workflow, monkeypatch, read_u_nk, read_dHdl): - monkeypatch.setattr(workflow, 'u_nk_list', []) - monkeypatch.setattr(workflow, 'dHdl_list', []) + monkeypatch.setattr(workflow, "u_nk_list", []) + monkeypatch.setattr(workflow, "dHdl_list", []) workflow.read(read_u_nk, read_dHdl) if read_u_nk: assert len(workflow.u_nk_list) == 5 @@ -303,104 +365,112 @@ def test_read_TI_FEP(self, workflow, monkeypatch, read_u_nk, read_dHdl): def test_read_invalid_u_nk(self, workflow, monkeypatch): def extract_u_nk(self, T): - raise IOError('Error read u_nk.') - monkeypatch.setattr(workflow, '_extract_u_nk', - extract_u_nk) - with pytest.raises(OSError, - match=r'Error reading u_nk .*dhdl\.xvg\.bz2'): + raise IOError("Error read u_nk.") + + monkeypatch.setattr(workflow, "_extract_u_nk", extract_u_nk) + with pytest.raises(OSError, match=r"Error reading u_nk .*dhdl\.xvg\.bz2"): workflow.read() def test_read_invalid_dHdl(self, workflow, monkeypatch): def extract_dHdl(self, T): - raise IOError('Error read dHdl.') - monkeypatch.setattr(workflow, '_extract_dHdl', - extract_dHdl) - with pytest.raises(OSError, - match=r'Error reading dHdl .*dhdl\.xvg\.bz2'): + raise IOError("Error read dHdl.") + + monkeypatch.setattr(workflow, "_extract_dHdl", extract_dHdl) + with pytest.raises(OSError, match=r"Error reading dHdl .*dhdl\.xvg\.bz2"): workflow.read() def test_uncorr_threshold(self, workflow, monkeypatch): - '''Test if the full data will be used when the number of data points - are less than the threshold.''' - monkeypatch.setattr(workflow, 'u_nk_list', - [u_nk[:40] for u_nk in workflow.u_nk_list]) - monkeypatch.setattr(workflow, 'dHdl_list', - [dHdl[:40] for dHdl in workflow.dHdl_list]) + """Test if the full data will be used when the number of data points + are less than the threshold.""" + monkeypatch.setattr( + workflow, "u_nk_list", [u_nk[:40] for u_nk in workflow.u_nk_list] + ) + monkeypatch.setattr( + workflow, "dHdl_list", [dHdl[:40] for dHdl in workflow.dHdl_list] + ) workflow.preprocess(threshold=50) assert all([len(u_nk) == 40 for u_nk in workflow.u_nk_sample_list]) assert all([len(dHdl) == 40 for dHdl in workflow.dHdl_sample_list]) def test_no_u_nk_preprocess(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'u_nk_list', []) + monkeypatch.setattr(workflow, "u_nk_list", []) workflow.preprocess(threshold=50) assert len(workflow.u_nk_list) == 0 def test_no_dHdl_preprocess(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'dHdl_list', []) + monkeypatch.setattr(workflow, "dHdl_list", []) workflow.preprocess(threshold=50) assert len(workflow.dHdl_list) == 0 def test_single_estimator_mbar(self, workflow): - workflow.estimate(estimators='MBAR') + workflow.estimate(estimators="MBAR") assert len(workflow.estimator) == 1 - assert 'MBAR' in workflow.estimator + assert "MBAR" in workflow.estimator summary = workflow.generate_result() - assert np.isclose(summary['MBAR']['Stages']['TOTAL'], 2.946, 0.1) + assert np.isclose(summary["MBAR"]["Stages"]["TOTAL"], 2.946, 0.1) def test_single_estimator_ti(self, workflow): - workflow.estimate(estimators='TI') + workflow.estimate(estimators="TI") summary = workflow.generate_result() - assert np.isclose(summary['TI']['Stages']['TOTAL'], 2.946, 0.1) + assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 2.946, 0.1) def test_bar_convergence(self, workflow): - workflow.check_convergence(10, estimator='BAR') + workflow.check_convergence(10, estimator="BAR") assert len(workflow.convergence) == 10 def test_convergence_invalid_estimator(self, workflow): with pytest.raises(ValueError): - workflow.check_convergence(10, estimator='aaa') + workflow.check_convergence(10, estimator="aaa") def test_ti_convergence(self, workflow): - workflow.check_convergence(10, estimator='TI') + workflow.check_convergence(10, estimator="TI") assert len(workflow.convergence) == 10 def test_unprocessed_n_uk(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'u_nk_sample_list', - None) + monkeypatch.setattr(workflow, "u_nk_sample_list", None) workflow.estimate() assert len(workflow.estimator) == 3 - assert 'MBAR' in workflow.estimator + assert "MBAR" in workflow.estimator def test_unprocessed_dhdl(self, workflow, monkeypatch): - monkeypatch.setattr(workflow, 'dHdl_sample_list', - None) - workflow.check_convergence(10, estimator='TI') + monkeypatch.setattr(workflow, "dHdl_sample_list", None) + workflow.check_convergence(10, estimator="TI") assert len(workflow.convergence) == 10 -class Test_automatic_amber(): - '''Test the full automatic workflow for load_ABFE from alchemtest.amber for - three stage transformation.''' + +class Test_automatic_amber: + """Test the full automatic workflow for load_ABFE from alchemtest.amber for + three stage transformation.""" @staticmethod - @pytest.fixture(scope='session') + @pytest.fixture(scope="session") def workflow(tmp_path_factory): outdir = tmp_path_factory.mktemp("out") dir, _ = os.path.split( - os.path.dirname(load_bace_example()['data']['complex']['vdw'][0])) - - workflow = ABFE(units='kcal/mol', software='AMBER', dir=dir, - prefix='ti', suffix='bz2', T=298.0, outdirectory=str( - outdir)) + os.path.dirname(load_bace_example()["data"]["complex"]["vdw"][0]) + ) + + workflow = ABFE( + units="kcal/mol", + software="AMBER", + dir=dir, + prefix="ti", + suffix="bz2", + T=298.0, + outdirectory=str(outdir), + ) workflow.read() - workflow.estimate(estimators='TI') + workflow.estimate(estimators="TI") return workflow def test_summary(self, workflow): - '''Test if if the summary is right.''' + """Test if if the summary is right.""" summary = workflow.generate_result() - assert np.isclose(summary['TI']['Stages']['TOTAL'], 1.40405980473, 0.1) + assert np.isclose(summary["TI"]["Stages"]["TOTAL"], 1.40405980473, 0.1) + def test_no_parser(): with pytest.raises(NotImplementedError): - workflow = ABFE(units='kcal/mol', software='aaa', - prefix='ti', suffix='bz2', T=298.0) + workflow = ABFE( + units="kcal/mol", software="aaa", prefix="ti", suffix="bz2", T=298.0 + ) diff --git a/src/alchemlyb/visualisation/__init__.py b/src/alchemlyb/visualisation/__init__.py index d58b367e..6955dcaf 100644 --- a/src/alchemlyb/visualisation/__init__.py +++ b/src/alchemlyb/visualisation/__init__.py @@ -1,4 +1,4 @@ +from .convergence import plot_convergence +from .dF_state import plot_dF_state from .mbar_matrix import plot_mbar_overlap_matrix from .ti_dhdl import plot_ti_dhdl -from .dF_state import plot_dF_state -from .convergence import plot_convergence \ No newline at end of file diff --git a/src/alchemlyb/visualisation/convergence.py b/src/alchemlyb/visualisation/convergence.py index fcef3e50..c1cc477a 100644 --- a/src/alchemlyb/visualisation/convergence.py +++ b/src/alchemlyb/visualisation/convergence.py @@ -1,92 +1,91 @@ import matplotlib.pyplot as plt -import pandas as pd -from matplotlib.font_manager import FontProperties as FP import numpy as np +from matplotlib.font_manager import FontProperties as FP from ..postprocessors.units import get_unit_converter + def plot_convergence(dataframe, units=None, final_error=None, ax=None): """Plot the forward and backward convergence. - The input could be the result from - :func:`~alchemlyb.convergence.forward_backward_convergence` or - :func:`~alchemlyb.convergence.fwdrev_cumavg_Rc`. The input should be a - :class:`pandas.DataFrame` which has column `Forward`, `Backward` and - :attr:`pandas.DataFrame.attrs` should compile with :ref:`note-on-units`. - The errorbar will be plotted if column `Forward_Error` and `Backward_Error` - is present. - - `Forward`: A column of free energy estimate from the first X% of data, - where optional `Forward_Error` column is the corresponding error. - - `Backward`: A column of free energy estimate from the last X% of data., - where optional `Backward_Error` column is the corresponding error. - - `final_error` is the error of the final value and is shown as the error band around the - final value. It can be provided in case an estimate is available that is more appropriate - than the default, which is the error of the last value in `Backward`. - - Parameters - ---------- - dataframe : Dataframe - Output Dataframe has column `Forward`, `Backward` or optionally - `Forward_Error`, `Backward_Error` see :ref:`plot_convergence `. - units : str - The unit of the estimate. The default is `None`, which is to use the - unit in the input. Setting this will change the output unit. - final_error : float - The error of the final value in ``units``. If not given, takes the last - error in `backward_error`. - ax : matplotlib.axes.Axes - Matplotlib axes object where the plot will be drawn on. If ``ax=None``, - a new axes will be generated. - - Returns - ------- - matplotlib.axes.Axes - An axes with the forward and backward convergence drawn. - - Note - ---- - The code is taken and modified from - `Alchemical Analysis `_. - - - .. versionchanged:: 1.0.0 - Keyword arg final_error for plotting a horizontal error bar. - The array input has been deprecated. - The units default to `None` which uses the units in the input. - - .. versionchanged:: 0.6.0 - data now takes in dataframe - - .. versionadded:: 0.4.0 + The input could be the result from + :func:`~alchemlyb.convergence.forward_backward_convergence` or + :func:`~alchemlyb.convergence.fwdrev_cumavg_Rc`. The input should be a + :class:`pandas.DataFrame` which has column `Forward`, `Backward` and + :attr:`pandas.DataFrame.attrs` should compile with :ref:`note-on-units`. + The errorbar will be plotted if column `Forward_Error` and `Backward_Error` + is present. + + `Forward`: A column of free energy estimate from the first X% of data, + where optional `Forward_Error` column is the corresponding error. + + `Backward`: A column of free energy estimate from the last X% of data., + where optional `Backward_Error` column is the corresponding error. + + `final_error` is the error of the final value and is shown as the error band around the + final value. It can be provided in case an estimate is available that is more appropriate + than the default, which is the error of the last value in `Backward`. + + Parameters + ---------- + dataframe : Dataframe + Output Dataframe has column `Forward`, `Backward` or optionally + `Forward_Error`, `Backward_Error` see :ref:`plot_convergence `. + units : str + The unit of the estimate. The default is `None`, which is to use the + unit in the input. Setting this will change the output unit. + final_error : float + The error of the final value in ``units``. If not given, takes the last + error in `backward_error`. + ax : matplotlib.axes.Axes + Matplotlib axes object where the plot will be drawn on. If ``ax=None``, + a new axes will be generated. + + Returns + ------- + matplotlib.axes.Axes + An axes with the forward and backward convergence drawn. + + Note + ---- + The code is taken and modified from + `Alchemical Analysis `_. + + + .. versionchanged:: 1.0.0 + Keyword arg final_error for plotting a horizontal error bar. + The array input has been deprecated. + The units default to `None` which uses the units in the input. + + .. versionchanged:: 0.6.0 + data now takes in dataframe + + .. versionadded:: 0.4.0 """ if units is not None: dataframe = get_unit_converter(units)(dataframe) - forward = dataframe['Forward'].to_numpy() - if 'Forward_Error' in dataframe: - forward_error = dataframe['Forward_Error'].to_numpy() + forward = dataframe["Forward"].to_numpy() + if "Forward_Error" in dataframe: + forward_error = dataframe["Forward_Error"].to_numpy() else: forward_error = np.zeros(len(forward)) - backward = dataframe['Backward'].to_numpy() - if 'Backward_Error' in dataframe: - backward_error = dataframe['Backward_Error'].to_numpy() + backward = dataframe["Backward"].to_numpy() + if "Backward_Error" in dataframe: + backward_error = dataframe["Backward_Error"].to_numpy() else: backward_error = np.zeros(len(backward)) - - if ax is None: # pragma: no cover + if ax is None: # pragma: no cover fig, ax = plt.subplots(figsize=(8, 6)) - plt.setp(ax.spines['bottom'], color='#D2B9D3', lw=3, zorder=-2) - plt.setp(ax.spines['left'], color='#D2B9D3', lw=3, zorder=-2) + plt.setp(ax.spines["bottom"], color="#D2B9D3", lw=3, zorder=-2) + plt.setp(ax.spines["left"], color="#D2B9D3", lw=3, zorder=-2) - for dire in ['top', 'right']: - ax.spines[dire].set_color('none') + for dire in ["top", "right"]: + ax.spines[dire].set_color("none") - ax.xaxis.set_ticks_position('bottom') - ax.yaxis.set_ticks_position('left') + ax.xaxis.set_ticks_position("bottom") + ax.yaxis.set_ticks_position("left") f_ts = np.linspace(0, 1, len(forward) + 1)[1:] r_ts = np.linspace(0, 1, len(backward) + 1)[1:] @@ -94,28 +93,54 @@ def plot_convergence(dataframe, units=None, final_error=None, ax=None): if final_error is None: final_error = backward_error[-1] - line0 = ax.fill_between([0, 1], backward[-1] - final_error, - backward[-1] + final_error, color='#D2B9D3', - zorder=1) - line1 = ax.errorbar(f_ts, forward, yerr=forward_error, color='#736AFF', - lw=3, zorder=2, marker='o', - mfc='w', mew=2.5, mec='#736AFF', ms=12,) - line2 = ax.errorbar(r_ts, backward, yerr=backward_error, color='#C11B17', - lw=3, zorder=3, marker='o', - mfc='w', mew=2.5, mec='#C11B17', ms=12, ) + line0 = ax.fill_between( + [0, 1], + backward[-1] - final_error, + backward[-1] + final_error, + color="#D2B9D3", + zorder=1, + ) + line1 = ax.errorbar( + f_ts, + forward, + yerr=forward_error, + color="#736AFF", + lw=3, + zorder=2, + marker="o", + mfc="w", + mew=2.5, + mec="#736AFF", + ms=12, + ) + line2 = ax.errorbar( + r_ts, + backward, + yerr=backward_error, + color="#C11B17", + lw=3, + zorder=3, + marker="o", + mfc="w", + mew=2.5, + mec="#C11B17", + ms=12, + ) xticks_spacing = len(r_ts) // 10 or 1 xticks = r_ts[::xticks_spacing] - plt.xticks(xticks, ['%.2f' % i for i in xticks], fontsize=10) + plt.xticks(xticks, ["%.2f" % i for i in xticks], fontsize=10) plt.yticks(fontsize=10) - ax.legend((line1[0], line2[0]), ('Forward', 'Reverse'), loc=9, - prop=FP(size=18), frameon=False) - ax.set_xlabel(r'Fraction of the simulation time', fontsize=16, - color='#151B54') - ax.set_ylabel(r'$\Delta G$ ({})'.format(units), fontsize=16, color='#151B54') - plt.tick_params(axis='x', color='#D2B9D3') - plt.tick_params(axis='y', color='#D2B9D3') + ax.legend( + (line1[0], line2[0]), + ("Forward", "Reverse"), + loc=9, + prop=FP(size=18), + frameon=False, + ) + ax.set_xlabel(r"Fraction of the simulation time", fontsize=16, color="#151B54") + ax.set_ylabel(r"$\Delta G$ ({})".format(units), fontsize=16, color="#151B54") + plt.tick_params(axis="x", color="#D2B9D3") + plt.tick_params(axis="y", color="#D2B9D3") return ax - - diff --git a/src/alchemlyb/visualisation/dF_state.py b/src/alchemlyb/visualisation/dF_state.py index 8f5a1409..e36fbc21 100644 --- a/src/alchemlyb/visualisation/dF_state.py +++ b/src/alchemlyb/visualisation/dF_state.py @@ -9,15 +9,17 @@ """ import matplotlib.pyplot as plt -from matplotlib.font_manager import FontProperties as FP import numpy as np +from matplotlib.font_manager import FontProperties as FP from ..estimators import TI, BAR, MBAR from ..postprocessors.units import get_unit_converter -def plot_dF_state(estimators, labels=None, colors=None, units=None, - orientation='portrait', nb=10): - '''Plot the dhdl of TI. + +def plot_dF_state( + estimators, labels=None, colors=None, units=None, orientation="portrait", nb=10 +): + """Plot the dhdl of TI. Parameters ---------- @@ -57,11 +59,13 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None, changing the figure legend. .. versionadded:: 0.4.0 - ''' + """ try: len(estimators) except TypeError: - estimators = [estimators, ] + estimators = [ + estimators, + ] formatted_data = [] for dhdl in estimators: @@ -69,10 +73,14 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None, len(dhdl) formatted_data.append(dhdl) except TypeError: - formatted_data.append([dhdl, ]) + formatted_data.append( + [ + dhdl, + ] + ) if units is None: - units = formatted_data[0][0].delta_f_.attrs['energy_unit'] + units = formatted_data[0][0].delta_f_.attrs["energy_unit"] estimators = formatted_data @@ -96,47 +104,69 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None, error_list.append(error) # Get the determine orientation - if orientation == 'landscape': + if orientation == "landscape": if max_length < 8: fig, ax = plt.subplots(figsize=(8, 6)) else: fig, ax = plt.subplots(figsize=(max_length, 6)) - axs = [ax, ] - xs = [np.arange(max_length), ] - elif orientation == 'portrait': + axs = [ + ax, + ] + xs = [ + np.arange(max_length), + ] + elif orientation == "portrait": if max_length < nb: - xs = [np.arange(max_length), ] + xs = [ + np.arange(max_length), + ] fig, ax = plt.subplots(figsize=(8, 6)) - axs = [ax, ] + axs = [ + ax, + ] else: xs = np.array_split(np.arange(max_length), max_length / nb + 1) fig, axs = plt.subplots(nrows=len(xs), figsize=(8, 6)) mnb = max([len(i) for i in xs]) else: - raise ValueError("Not recognising {}, only supports 'landscape' or 'portrait'.".format(orientation)) + raise ValueError( + "Not recognising {}, only supports 'landscape' or 'portrait'.".format( + orientation + ) + ) # Sort out the colors if colors is None: - colors_dict = {'TI': '#C45AEC', 'TI-CUBIC': '#33CC33', - 'DEXP': '#F87431', 'IEXP': '#FF3030', 'GINS': '#EAC117', - 'GDEL': '#347235', 'BAR': '#6698FF', 'UBAR': '#817339', - 'RBAR': '#C11B17', 'MBAR': '#F9B7FF'} + colors_dict = { + "TI": "#C45AEC", + "TI-CUBIC": "#33CC33", + "DEXP": "#F87431", + "IEXP": "#FF3030", + "GINS": "#EAC117", + "GDEL": "#347235", + "BAR": "#6698FF", + "UBAR": "#817339", + "RBAR": "#C11B17", + "MBAR": "#F9B7FF", + } colors = [] for dhdl in estimators: dhdl = dhdl[0] if isinstance(dhdl, TI): - colors.append(colors_dict['TI']) + colors.append(colors_dict["TI"]) elif isinstance(dhdl, BAR): - colors.append(colors_dict['BAR']) + colors.append(colors_dict["BAR"]) elif isinstance(dhdl, MBAR): - colors.append(colors_dict['MBAR']) + colors.append(colors_dict["MBAR"]) else: if len(colors) >= len(estimators): pass else: raise ValueError( - 'Number of colors ({}) should be larger than the number of data ({})'.format( - len(colors), len(estimators))) + "Number of colors ({}) should be larger than the number of data ({})".format( + len(colors), len(estimators) + ) + ) # Sort out the labels if labels is None: @@ -144,21 +174,23 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None, for dhdl in estimators: dhdl = dhdl[0] if isinstance(dhdl, TI): - labels.append('TI') + labels.append("TI") elif isinstance(dhdl, BAR): - labels.append('BAR') + labels.append("BAR") elif isinstance(dhdl, MBAR): - labels.append('MBAR') + labels.append("MBAR") else: if len(labels) == len(estimators): pass else: raise ValueError( - 'Length of labels ({}) should be the same as the number of data ({})'.format( - len(labels), len(estimators))) + "Length of labels ({}) should be the same as the number of data ({})".format( + len(labels), len(estimators) + ) + ) # Plot the figure - width = 1. / (len(estimators) + 1) + width = 1.0 / (len(estimators) + 1) elw = 30 * width ndx = 1 for x, ax in zip(xs, axs): @@ -166,35 +198,49 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None, for i, (dF, error) in enumerate(zip(dF_list, error_list)): y = [dF[j] for j in x] ye = [error[j] for j in x] - if orientation == 'landscape': + if orientation == "landscape": lw = 0.1 * elw - elif orientation == 'portrait': + elif orientation == "portrait": lw = 0.05 * elw - line = ax.bar(x + len(lines) * width, y, width, - color=colors[i], yerr=ye, lw=lw, - error_kw=dict(elinewidth=elw, ecolor='black', - capsize=0.5 * elw)) + line = ax.bar( + x + len(lines) * width, + y, + width, + color=colors[i], + yerr=ye, + lw=lw, + error_kw=dict(elinewidth=elw, ecolor="black", capsize=0.5 * elw), + ) lines += (line[0],) - for dir in ['left', 'right', 'top', 'bottom']: - if dir == 'left': + for dir in ["left", "right", "top", "bottom"]: + if dir == "left": ax.yaxis.set_ticks_position(dir) else: - ax.spines[dir].set_color('none') + ax.spines[dir].set_color("none") - if orientation == 'landscape': + if orientation == "landscape": plt.yticks(fontsize=8) - ax.set_xlim(x[0]-width, x[-1] + len(lines) * width) - plt.xticks(x + 0.5 * width * len(estimators), - tuple(['%d--%d' % (i, i + 1) for i in x]), fontsize=8) - elif orientation == 'portrait': + ax.set_xlim(x[0] - width, x[-1] + len(lines) * width) + plt.xticks( + x + 0.5 * width * len(estimators), + tuple(["%d--%d" % (i, i + 1) for i in x]), + fontsize=8, + ) + elif orientation == "portrait": plt.yticks(fontsize=10) ax.xaxis.set_ticks([]) for i in x + 0.5 * width * len(estimators): - ax.annotate(r'$\mathrm{%d-%d}$' % (i, i + 1), xy=(i, 0), - xycoords=('data', 'axes fraction'), xytext=(0, -2), - size=10, textcoords='offset points', va='top', - ha='center') - ax.set_xlim(x[0]-width, x[-1]+len(lines)*width + (mnb - len(x))) + ax.annotate( + r"$\mathrm{%d-%d}$" % (i, i + 1), + xy=(i, 0), + xycoords=("data", "axes fraction"), + xytext=(0, -2), + size=10, + textcoords="offset points", + va="top", + ha="center", + ) + ax.set_xlim(x[0] - width, x[-1] + len(lines) * width + (mnb - len(x))) ndx += 1 x = np.arange(max_length) @@ -202,18 +248,21 @@ def plot_dF_state(estimators, labels=None, colors=None, units=None, for tick in ax.get_xticklines(): tick.set_visible(False) - if orientation == 'landscape': - leg = plt.legend(lines, labels, loc=3, ncol=2, prop=FP(size=10), - fancybox=True) - plt.title('The free energy change breakdown', fontsize=12) - plt.xlabel('States', fontsize=12, color='#151B54') - plt.ylabel(r'$\Delta G$ ({})'.format(units), fontsize=12, color='#151B54') - elif orientation == 'portrait': - leg = ax.legend(lines, labels, loc=0, ncol=2, - prop=FP(size=8), - title=r'$\Delta G$ ({})'.format(units) + - r'$\mathit{vs.}$ lambda pair', - fancybox=True) + if orientation == "landscape": + leg = plt.legend(lines, labels, loc=3, ncol=2, prop=FP(size=10), fancybox=True) + plt.title("The free energy change breakdown", fontsize=12) + plt.xlabel("States", fontsize=12, color="#151B54") + plt.ylabel(r"$\Delta G$ ({})".format(units), fontsize=12, color="#151B54") + elif orientation == "portrait": + leg = ax.legend( + lines, + labels, + loc=0, + ncol=2, + prop=FP(size=8), + title=r"$\Delta G$ ({})".format(units) + r"$\mathit{vs.}$ lambda pair", + fancybox=True, + ) leg.get_frame().set_alpha(0.5) return fig diff --git a/src/alchemlyb/visualisation/mbar_matrix.py b/src/alchemlyb/visualisation/mbar_matrix.py index 4b2bd952..6bdc068e 100644 --- a/src/alchemlyb/visualisation/mbar_matrix.py +++ b/src/alchemlyb/visualisation/mbar_matrix.py @@ -13,8 +13,9 @@ import matplotlib.pyplot as plt import numpy as np + def plot_mbar_overlap_matrix(matrix, skip_lambda_index=[], ax=None): - '''Plot the MBAR overlap matrix. + """Plot the MBAR overlap matrix. Parameters ---------- @@ -41,7 +42,7 @@ def plot_mbar_overlap_matrix(matrix, skip_lambda_index=[], ax=None): .. versionadded:: 0.4.0 - ''' + """ # Compute the size of the figure, if ax is not given. max_prob = matrix.max() size = len(matrix) @@ -49,25 +50,36 @@ def plot_mbar_overlap_matrix(matrix, skip_lambda_index=[], ax=None): fig, ax = plt.subplots(figsize=(size / 2, size / 2)) ax.set_xticks([]) ax.set_yticks([]) - ax.axis('off') + ax.axis("off") for i in range(size): if i != 0: - ax.axvline(x=i, ls='-', lw=0.5, color='k', alpha=0.25) - ax.axhline(y=i, ls='-', lw=0.5, color='k', alpha=0.25) + ax.axvline(x=i, ls="-", lw=0.5, color="k", alpha=0.25) + ax.axhline(y=i, ls="-", lw=0.5, color="k", alpha=0.25) for j in range(size): if matrix[j, i] < 0.005: - ii = '' + ii = "" elif matrix[j, i] > 0.995: - ii = '1.00' + ii = "1.00" else: - ii = ("{:.2f}".format(matrix[j, i])[1:]) + ii = "{:.2f}".format(matrix[j, i])[1:] alf = matrix[j, i] / max_prob - ax.fill_between([i, i + 1], [size - j, size - j], - [size - (j + 1), size - (j + 1)], color='k', - alpha=alf) - ax.annotate(ii, xy=(i, j), xytext=(i + 0.5, size - (j + 0.5)), - size=8, textcoords='data', va='center', - ha='center', color=('k' if alf < 0.5 else 'w')) + ax.fill_between( + [i, i + 1], + [size - j, size - j], + [size - (j + 1), size - (j + 1)], + color="k", + alpha=alf, + ) + ax.annotate( + ii, + xy=(i, j), + xytext=(i + 0.5, size - (j + 0.5)), + size=8, + textcoords="data", + va="center", + ha="center", + color=("k" if alf < 0.5 else "w"), + ) if skip_lambda_index: ks = [int(l) for l in skip_lambda_index] @@ -75,31 +87,48 @@ def plot_mbar_overlap_matrix(matrix, skip_lambda_index=[], ax=None): else: ks = range(size) for i in range(size): - ax.annotate(ks[i], xy=(i + 0.5, 1), xytext=(i + 0.5, size + 0.5), - size=10, textcoords=('data', 'data'), - va='center', ha='center', color='k') - ax.annotate(ks[i], xy=(-0.5, size - (size - 0.5)), - xytext=(-0.5, size - (i + 0.5)), - size=10, textcoords=('data', 'data'), - va='center', ha='center', color='k') - ax.annotate(r'$\lambda$', xy=(-0.5, size - (size - 0.5)), - xytext=(-0.5, size + 0.5), - size=10, textcoords=('data', 'data'), - va='center', ha='center', color='k') - ax.plot([0, size], [0, 0], 'k-', lw=4.0, solid_capstyle='butt') - ax.plot([size, size], [0, size], 'k-', lw=4.0, solid_capstyle='butt') - ax.plot([0, 0], [0, size], 'k-', lw=2.0, solid_capstyle='butt') - ax.plot([0, size], [size, size], 'k-', lw=2.0, solid_capstyle='butt') + ax.annotate( + ks[i], + xy=(i + 0.5, 1), + xytext=(i + 0.5, size + 0.5), + size=10, + textcoords=("data", "data"), + va="center", + ha="center", + color="k", + ) + ax.annotate( + ks[i], + xy=(-0.5, size - (size - 0.5)), + xytext=(-0.5, size - (i + 0.5)), + size=10, + textcoords=("data", "data"), + va="center", + ha="center", + color="k", + ) + ax.annotate( + r"$\lambda$", + xy=(-0.5, size - (size - 0.5)), + xytext=(-0.5, size + 0.5), + size=10, + textcoords=("data", "data"), + va="center", + ha="center", + color="k", + ) + ax.plot([0, size], [0, 0], "k-", lw=4.0, solid_capstyle="butt") + ax.plot([size, size], [0, size], "k-", lw=4.0, solid_capstyle="butt") + ax.plot([0, 0], [0, size], "k-", lw=2.0, solid_capstyle="butt") + ax.plot([0, size], [size, size], "k-", lw=2.0, solid_capstyle="butt") cx = np.repeat(range(size + 1), 2) cy = sorted(np.repeat(range(size + 1), 2), reverse=True) - ax.plot(cx[2:-1], cy[1:-2], 'k-', lw=2.0) - ax.plot(np.array(cx[2:-3]) + 1, cy[1:-4], 'k-', lw=2.0) - ax.plot(cx[1:-2], np.array(cy[:-3]) - 1, 'k-', lw=2.0) - ax.plot(cx[1:-4], np.array(cy[:-5]) - 2, 'k-', lw=2.0) + ax.plot(cx[2:-1], cy[1:-2], "k-", lw=2.0) + ax.plot(np.array(cx[2:-3]) + 1, cy[1:-4], "k-", lw=2.0) + ax.plot(cx[1:-2], np.array(cy[:-3]) - 1, "k-", lw=2.0) + ax.plot(cx[1:-4], np.array(cy[:-5]) - 2, "k-", lw=2.0) ax.set_xlim(-1, size) ax.set_ylim(0, size + 1) return ax - - diff --git a/src/alchemlyb/visualisation/ti_dhdl.py b/src/alchemlyb/visualisation/ti_dhdl.py index c071a97d..6dacb6dc 100644 --- a/src/alchemlyb/visualisation/ti_dhdl.py +++ b/src/alchemlyb/visualisation/ti_dhdl.py @@ -10,14 +10,14 @@ """ import matplotlib.pyplot as plt -from matplotlib.font_manager import FontProperties as FP import numpy as np +from matplotlib.font_manager import FontProperties as FP from ..postprocessors.units import get_unit_converter -def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, - ax=None): - '''Plot the dhdl of TI. + +def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, ax=None): + """Plot the dhdl of TI. Parameters ---------- @@ -55,7 +55,7 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, changing the figure legend. .. versionadded:: 0.4.0 - ''' + """ # Make it into a list # separate_dhdl method is used so that the input for the actual plotting # Function are a uniformed list of series object which only contains one @@ -69,7 +69,7 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, # Convert unit if units is None: - units = dhdl_list[0].attrs['energy_unit'] + units = dhdl_list[0].attrs["energy_unit"] new_unit = [] convert = get_unit_converter(units) @@ -80,11 +80,11 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, if ax is None: fig, ax = plt.subplots(figsize=(8, 6)) - ax.spines['bottom'].set_position('zero') - ax.spines['top'].set_color('none') - ax.spines['right'].set_color('none') - ax.xaxis.set_ticks_position('bottom') - ax.yaxis.set_ticks_position('left') + ax.spines["bottom"].set_position("zero") + ax.spines["top"].set_color("none") + ax.spines["right"].set_color("none") + ax.xaxis.set_ticks_position("bottom") + ax.yaxis.set_ticks_position("left") for k, spine in ax.spines.items(): spine.set_zorder(12.2) @@ -98,20 +98,24 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, else: if len(labels) == len(dhdl_list): lv_names2 = labels - else: # pragma: no cover + else: # pragma: no cover raise ValueError( - 'Length of labels ({}) should be the same as the number of data ({})'.format( - len(labels), len(dhdl_list))) + "Length of labels ({}) should be the same as the number of data ({})".format( + len(labels), len(dhdl_list) + ) + ) if colors is None: - colors = ['r', 'g', '#7F38EC', '#9F000F', 'b', 'y'] + colors = ["r", "g", "#7F38EC", "#9F000F", "b", "y"] else: if len(colors) >= len(dhdl_list): pass - else: # pragma: no cover + else: # pragma: no cover raise ValueError( - 'Number of colors ({}) should be larger than the number of data ({})'.format( - len(labels), len(dhdl_list))) + "Number of colors ({}) should be larger than the number of data ({})".format( + len(labels), len(dhdl_list) + ) + ) # Get the real data out xs, ndx, dx = [0], 0, 0.001 @@ -125,16 +129,22 @@ def plot_ti_dhdl(dhdl_data, labels=None, colors=None, units=None, for i in range(len(x) - 1): if i % 2 == 0: - ax.fill_between(x[i:i + 2] + ndx, 0, y[i:i + 2], - color=colors[ndx], alpha=1.0) + ax.fill_between( + x[i : i + 2] + ndx, 0, y[i : i + 2], color=colors[ndx], alpha=1.0 + ) else: - ax.fill_between(x[i:i + 2] + ndx, 0, y[i:i + 2], - color=colors[ndx], alpha=0.5) + ax.fill_between( + x[i : i + 2] + ndx, 0, y[i : i + 2], color=colors[ndx], alpha=0.5 + ) xlegend = [-100 * wnum for wnum in range(len(lv_names2))] - ax.plot(xlegend, [0 * wnum for wnum in xlegend], ls='-', - color=colors[ndx], - label=lv_names2[ndx]) + ax.plot( + xlegend, + [0 * wnum for wnum in xlegend], + ls="-", + color=colors[ndx], + label=lv_names2[ndx], + ) xs += (x + ndx).tolist()[1:] ndx += 1 @@ -159,7 +169,7 @@ def getInd(r=ri, z=[0]): if i in getInd(): xt.append(i) else: - xt.append('') + xt.append("") plt.xticks(xs[1:], xt[1:], fontsize=10) ax.yaxis.label.set_size(10) @@ -172,31 +182,46 @@ def getInd(r=ri, z=[0]): max_y *= 1.01 # Modified so that the x label won't conflict with the lambda label - min_y -= (max_y-min_y)*0.1 + min_y -= (max_y - min_y) * 0.1 ax.set_ylim(min_y, max_y) for i, j in zip(xs[1:], xt[1:]): ax.annotate( - ('%.2f' % (i - 1.0 if i > 1.0 else i) if not j == '' else ''), - xy=(i, 0), size=10, rotation=90, va='bottom', ha='center', - color='#151B54') + ("%.2f" % (i - 1.0 if i > 1.0 else i) if not j == "" else ""), + xy=(i, 0), + size=10, + rotation=90, + va="bottom", + ha="center", + color="#151B54", + ) if ndx > 1: lenticks = len(ax.get_ymajorticklabels()) - 1 - if min_y < 0: lenticks -= 1 + if min_y < 0: + lenticks -= 1 if lenticks < 5: # pragma: no cover from matplotlib.ticker import AutoMinorLocator as AML + ax.yaxis.set_minor_locator(AML()) - ax.grid(which='both', color='w', lw=0.25, axis='y', zorder=12) + ax.grid(which="both", color="w", lw=0.25, axis="y", zorder=12) ax.set_ylabel( - r'$\langle{\frac{\partial U}{\partial\lambda}}\rangle_{\lambda}$' + - '({})'.format(units), - fontsize=20, color='#151B54') - ax.annotate(r'$\mathit{\lambda}$', xy=(0, 0), xytext=(0.5, -0.05), size=18, - textcoords='axes fraction', va='top', ha='center', - color='#151B54') + r"$\langle{\frac{\partial U}{\partial\lambda}}\rangle_{\lambda}$" + + "({})".format(units), + fontsize=20, + color="#151B54", + ) + ax.annotate( + r"$\mathit{\lambda}$", + xy=(0, 0), + xytext=(0.5, -0.05), + size=18, + textcoords="axes fraction", + va="top", + ha="center", + color="#151B54", + ) lege = ax.legend(prop=FP(size=14), frameon=False, loc=1) for l in lege.legendHandles: l.set_linewidth(10) return ax - diff --git a/src/alchemlyb/workflows/__init__.py b/src/alchemlyb/workflows/__init__.py index 6b35d460..a6a156cf 100644 --- a/src/alchemlyb/workflows/__init__.py +++ b/src/alchemlyb/workflows/__init__.py @@ -1,4 +1,5 @@ __all__ = [ - 'base', + "base", ] + from .abfe import ABFE diff --git a/src/alchemlyb/workflows/abfe.py b/src/alchemlyb/workflows/abfe.py index 32e51a7c..9fef25c1 100644 --- a/src/alchemlyb/workflows/abfe.py +++ b/src/alchemlyb/workflows/abfe.py @@ -1,26 +1,31 @@ +import logging import os -from os.path import join from glob import glob -import pandas as pd -import numpy as np -import logging +from os.path import join + import matplotlib.pyplot as plt +import numpy as np +import pandas as pd from .base import WorkflowBase -from ..parsing import gmx, amber -from ..preprocessing.subsampling import decorrelate_dhdl, decorrelate_u_nk -from ..estimators import BAR, TI, FEP_ESTIMATORS, TI_ESTIMATORS +from .. import __version__ +from .. import concat +from ..convergence import forward_backward_convergence from ..estimators import AutoMBAR as MBAR -from ..visualisation import (plot_mbar_overlap_matrix, plot_ti_dhdl, - plot_dF_state, plot_convergence) +from ..estimators import BAR, TI, FEP_ESTIMATORS, TI_ESTIMATORS +from ..parsing import gmx, amber from ..postprocessors.units import get_unit_converter -from ..convergence import forward_backward_convergence -from .. import concat -from .. import __version__ +from ..preprocessing.subsampling import decorrelate_dhdl, decorrelate_u_nk +from ..visualisation import ( + plot_mbar_overlap_matrix, + plot_ti_dhdl, + plot_dF_state, + plot_convergence, +) class ABFE(WorkflowBase): - '''Workflow for absolute and relative binding free energy calculations. + """Workflow for absolute and relative binding free energy calculations. This workflow provides functionality similar to the ``alchemical-analysis.py`` script. It loads multiple input files from alchemical free energy calculations and computes the @@ -58,42 +63,50 @@ class ABFE(WorkflowBase): .. versionadded:: 1.0.0 - ''' - def __init__(self, T, units='kT', software='GROMACS', dir=os.path.curdir, - prefix='dhdl', suffix='xvg', - outdirectory=os.path.curdir): + """ + + def __init__( + self, + T, + units="kT", + software="GROMACS", + dir=os.path.curdir, + prefix="dhdl", + suffix="xvg", + outdirectory=os.path.curdir, + ): super().__init__(units, software, T, outdirectory) - self.logger = logging.getLogger('alchemlyb.workflows.ABFE') - self.logger.info('Initialise Alchemlyb ABFE Workflow') - self.logger.info(f'Alchemlyb Version: f{__version__}') - self.logger.info(f'Set Temperature to {T} K.') - self.logger.info(f'Set Software to {software}.') + self.logger = logging.getLogger("alchemlyb.workflows.ABFE") + self.logger.info("Initialise Alchemlyb ABFE Workflow") + self.logger.info(f"Alchemlyb Version: f{__version__}") + self.logger.info(f"Set Temperature to {T} K.") + self.logger.info(f"Set Software to {software}.") self.update_units(units) - self.logger.info(f'Finding files with prefix: {prefix}, suffix: ' - f'{suffix} under directory {dir} produced by ' - f'{software}') - self.file_list = glob(dir + '/**/' + prefix + '*' + suffix, - recursive=True) + self.logger.info( + f"Finding files with prefix: {prefix}, suffix: " + f"{suffix} under directory {dir} produced by " + f"{software}" + ) + self.file_list = glob(dir + "/**/" + prefix + "*" + suffix, recursive=True) - self.logger.info(f'Found {len(self.file_list)} xvg files.') - self.logger.info("Unsorted file list: \n %s", '\n'.join( - self.file_list)) + self.logger.info(f"Found {len(self.file_list)} xvg files.") + self.logger.info("Unsorted file list: \n %s", "\n".join(self.file_list)) - if software == 'GROMACS': - self.logger.info(f'Using {software} parser to read the data.') + if software == "GROMACS": + self.logger.info(f"Using {software} parser to read the data.") self._extract_u_nk = gmx.extract_u_nk self._extract_dHdl = gmx.extract_dHdl - elif software == 'AMBER': + elif software == "AMBER": self._extract_u_nk = amber.extract_u_nk self._extract_dHdl = amber.extract_dHdl else: - raise NotImplementedError(f'{software} parser not found.') + raise NotImplementedError(f"{software} parser not found.") def read(self, read_u_nk=True, read_dHdl=True): - '''Read the u_nk and dHdL data from the + """Read the u_nk and dHdL data from the :attr:`~alchemlyb.workflows.ABFE.file_list` Parameters @@ -109,7 +122,7 @@ def read(self, read_u_nk=True, read_dHdl=True): A list of :class:`pandas.DataFrame` of u_nk. dHdl_list : list A list of :class:`pandas.DataFrame` of dHdl. - ''' + """ self.u_nk_sample_list = None self.dHdl_sample_list = None @@ -119,46 +132,46 @@ def read(self, read_u_nk=True, read_dHdl=True): if read_u_nk: try: u_nk = self._extract_u_nk(file, T=self.T) - self.logger.info( - f'Reading {len(u_nk)} lines of u_nk from {file}') + self.logger.info(f"Reading {len(u_nk)} lines of u_nk from {file}") u_nk_list.append(u_nk) except Exception as exc: - msg = f'Error reading u_nk from {file}.' + msg = f"Error reading u_nk from {file}." self.logger.error(msg) raise OSError(msg) from exc if read_dHdl: try: dhdl = self._extract_dHdl(file, T=self.T) - self.logger.info( - f'Reading {len(dhdl)} lines of dhdl from {file}') + self.logger.info(f"Reading {len(dhdl)} lines of dhdl from {file}") dHdl_list.append(dhdl) except Exception as exc: - msg = f'Error reading dHdl from {file}.' + msg = f"Error reading dHdl from {file}." self.logger.error(msg) raise OSError(msg) from exc # Sort the files according to the state if read_u_nk: - self.logger.info('Sort files according to the u_nk.') + self.logger.info("Sort files according to the u_nk.") column_names = u_nk_list[0].columns.values.tolist() - index_list = sorted(range(len(self.file_list)), - key=lambda x: column_names.index( - u_nk_list[x].reset_index( - 'time').index.values[0])) + index_list = sorted( + range(len(self.file_list)), + key=lambda x: column_names.index( + u_nk_list[x].reset_index("time").index.values[0] + ), + ) elif read_dHdl: - self.logger.info('Sort files according to the dHdl.') - index_list = sorted(range(len(self.file_list)), - key=lambda x: - dHdl_list[x].reset_index( - 'time').index.values[0]) + self.logger.info("Sort files according to the dHdl.") + index_list = sorted( + range(len(self.file_list)), + key=lambda x: dHdl_list[x].reset_index("time").index.values[0], + ) else: self.u_nk_list = [] self.dHdl_list = [] return self.file_list = [self.file_list[i] for i in index_list] - self.logger.info("Sorted file list: \n%s", '\n'.join(self.file_list)) + self.logger.info("Sorted file list: \n%s", "\n".join(self.file_list)) if read_u_nk: self.u_nk_list = [u_nk_list[i] for i in index_list] else: @@ -169,11 +182,19 @@ def read(self, read_u_nk=True, read_dHdl=True): else: self.dHdl_list = [] - - def run(self, skiptime=0, uncorr='dE', threshold=50, - estimators=('MBAR', 'BAR', 'TI'), overlap='O_MBAR.pdf', - breakdown=True, forwrev=None, *args, **kwargs): - ''' The method for running the automatic analysis. + def run( + self, + skiptime=0, + uncorr="dE", + threshold=50, + estimators=("MBAR", "BAR", "TI"), + overlap="O_MBAR.pdf", + breakdown=True, + forwrev=None, + *args, + **kwargs, + ): + """The method for running the automatic analysis. Parameters ---------- @@ -214,29 +235,32 @@ def run(self, skiptime=0, uncorr='dE', threshold=50, The summary of the convergence results. See :func:`~alchemlyb.convergence.forward_backward_convergence` for further explanation. - ''' + """ use_FEP = False use_TI = False if estimators is not None: if isinstance(estimators, str): - estimators = [estimators, ] + estimators = [ + estimators, + ] for estimator in estimators: if estimator in FEP_ESTIMATORS: use_FEP = True elif estimator in TI_ESTIMATORS: use_TI = True else: - msg = f"Estimator {estimator} is not supported. Choose one from " \ - f"{FEP_ESTIMATORS + TI_ESTIMATORS}." + msg = ( + f"Estimator {estimator} is not supported. Choose one from " + f"{FEP_ESTIMATORS + TI_ESTIMATORS}." + ) self.logger.error(msg) raise ValueError(msg) self.read(use_FEP, use_TI) if uncorr is not None: - self.preprocess(skiptime=skiptime, uncorr=uncorr, - threshold=threshold) + self.preprocess(skiptime=skiptime, uncorr=uncorr, threshold=threshold) if estimators is not None: self.estimate(estimators) self.generate_result() @@ -251,31 +275,30 @@ def run(self, skiptime=0, uncorr='dE', threshold=50, plt.close(ax.figure) fig = self.plot_dF_state() plt.close(fig) - fig = self.plot_dF_state(dF_state='dF_state_long.pdf', - orientation='landscape') + fig = self.plot_dF_state( + dF_state="dF_state_long.pdf", orientation="landscape" + ) plt.close(fig) if forwrev is not None: - ax = self.check_convergence(forwrev, estimator='MBAR', - dF_t='dF_t.pdf') + ax = self.check_convergence(forwrev, estimator="MBAR", dF_t="dF_t.pdf") plt.close(ax.figure) - def update_units(self, units=None): - '''Update the unit. + """Update the unit. Parameters ---------- units : {'kcal/mol', 'kJ/mol', 'kT'} The unit used for printing and plotting results. - ''' + """ if units is not None: - self.logger.info(f'Set unit to {units}.') + self.logger.info(f"Set unit to {units}.") self.units = units or None - def preprocess(self, skiptime=0, uncorr='dE', threshold=50): - '''Preprocess the data by removing the equilibration time and + def preprocess(self, skiptime=0, uncorr="dE", threshold=50): + """Preprocess the data by removing the equilibration time and decorrelate the date. Parameters @@ -296,54 +319,65 @@ def preprocess(self, skiptime=0, uncorr='dE', threshold=50): The list of u_nk after decorrelation. dHdl_sample_list : list The list of dHdl after decorrelation. - ''' - self.logger.info(f'Start preprocessing with skiptime of {skiptime} ' - f'uncorrelation method of {uncorr} and threshold of ' - f'{threshold}') + """ + self.logger.info( + f"Start preprocessing with skiptime of {skiptime} " + f"uncorrelation method of {uncorr} and threshold of " + f"{threshold}" + ) if len(self.u_nk_list) > 0: self.logger.info( - f'Processing the u_nk data set with skiptime of {skiptime}.') + f"Processing the u_nk data set with skiptime of {skiptime}." + ) self.u_nk_sample_list = [] for index, u_nk in enumerate(self.u_nk_list): # Find the starting frame - u_nk = u_nk[u_nk.index.get_level_values('time') >= skiptime] + u_nk = u_nk[u_nk.index.get_level_values("time") >= skiptime] subsample = decorrelate_u_nk(u_nk, uncorr, remove_burnin=True) if len(subsample) < threshold: - self.logger.warning(f'Number of u_nk {len(subsample)} ' - f'for state {index} is less than the ' - f'threshold {threshold}.') - self.logger.info(f'Take all the u_nk for state {index}.') + self.logger.warning( + f"Number of u_nk {len(subsample)} " + f"for state {index} is less than the " + f"threshold {threshold}." + ) + self.logger.info(f"Take all the u_nk for state {index}.") self.u_nk_sample_list.append(u_nk) else: - self.logger.info(f'Take {len(subsample)} uncorrelated ' - f'u_nk for state {index}.') + self.logger.info( + f"Take {len(subsample)} uncorrelated " + f"u_nk for state {index}." + ) self.u_nk_sample_list.append(subsample) else: - self.logger.info('No u_nk data being subsampled') + self.logger.info("No u_nk data being subsampled") if len(self.dHdl_list) > 0: self.dHdl_sample_list = [] for index, dHdl in enumerate(self.dHdl_list): - dHdl = dHdl[dHdl.index.get_level_values('time') >= skiptime] + dHdl = dHdl[dHdl.index.get_level_values("time") >= skiptime] subsample = decorrelate_dhdl(dHdl, remove_burnin=True) if len(subsample) < threshold: - self.logger.warning(f'Number of dHdl {len(subsample)} for ' - f'state {index} is less than the ' - f'threshold {threshold}.') - self.logger.info(f'Take all the dHdl for state {index}.') + self.logger.warning( + f"Number of dHdl {len(subsample)} for " + f"state {index} is less than the " + f"threshold {threshold}." + ) + self.logger.info(f"Take all the dHdl for state {index}.") self.dHdl_sample_list.append(dHdl) else: - self.logger.info(f'Take {len(subsample)} uncorrelated ' - f'dHdl for state {index}.') + self.logger.info( + f"Take {len(subsample)} uncorrelated " + f"dHdl for state {index}." + ) self.dHdl_sample_list.append(subsample) else: - self.logger.info('No dHdl data being subsampled') + self.logger.info("No dHdl data being subsampled") - def estimate(self, estimators=('MBAR', 'BAR', 'TI'), **kwargs): - '''Estimate the free energy using the selected estimator. + def estimate(self, estimators=("MBAR", "BAR", "TI"), **kwargs): + """Estimate the free energy using the selected estimator. Parameters ---------- @@ -368,10 +402,10 @@ def estimate(self, estimators=('MBAR', 'BAR', 'TI'), **kwargs): behavior of :class:`~alchemlyb.estimators.MBAR`. (:code:`estimate(estimators='MBAR', method='adaptive')`) - ''' + """ # Make estimators into a tuple if isinstance(estimators, str): - estimators = (estimators, ) + estimators = (estimators,) for estimator in estimators: if estimator not in (FEP_ESTIMATORS + TI_ESTIMATORS): @@ -379,42 +413,38 @@ def estimate(self, estimators=('MBAR', 'BAR', 'TI'), **kwargs): self.logger.error(msg) raise ValueError(msg) - self.logger.info( - f"Start running estimator: {','.join(estimators)}.") + self.logger.info(f"Start running estimator: {','.join(estimators)}.") self.estimator = {} # Use unprocessed data if preprocess is not performed. - if 'TI' in estimators: + if "TI" in estimators: if self.dHdl_sample_list is not None: dHdl = concat(self.dHdl_sample_list) else: dHdl = concat(self.dHdl_list) - self.logger.warning('dHdl has not been preprocessed.') - self.logger.info( - f'A total {len(dHdl)} lines of dHdl is used.') + self.logger.warning("dHdl has not been preprocessed.") + self.logger.info(f"A total {len(dHdl)} lines of dHdl is used.") - if 'BAR' in estimators or 'MBAR' in estimators: + if "BAR" in estimators or "MBAR" in estimators: if self.u_nk_sample_list is not None: u_nk = concat(self.u_nk_sample_list) else: u_nk = concat(self.u_nk_list) - self.logger.warning('u_nk has not been preprocessed.') - self.logger.info( - f'A total {len(u_nk)} lines of u_nk is used.') + self.logger.warning("u_nk has not been preprocessed.") + self.logger.info(f"A total {len(u_nk)} lines of u_nk is used.") for estimator in estimators: - if estimator == 'MBAR': - self.logger.info('Run MBAR estimator.') + if estimator == "MBAR": + self.logger.info("Run MBAR estimator.") self.estimator[estimator] = MBAR(**kwargs).fit(u_nk) - elif estimator == 'BAR': - self.logger.info('Run BAR estimator.') + elif estimator == "BAR": + self.logger.info("Run BAR estimator.") self.estimator[estimator] = BAR(**kwargs).fit(u_nk) - elif estimator == 'TI': - self.logger.info('Run TI estimator.') + elif estimator == "TI": + self.logger.info("Run TI estimator.") self.estimator[estimator] = TI(**kwargs).fit(dHdl) - def generate_result(self): - '''Summarise the result into a dataframe. + """Summarise the result into a dataframe. Returns ------- @@ -460,38 +490,37 @@ def generate_result(self): ---------- summary : Dataframe The summary of the free energy estimate. - ''' + """ # Write estimate - self.logger.info('Summarise the estimate into a dataframe.') + self.logger.info("Summarise the estimate into a dataframe.") # Make the header name - self.logger.info('Generate the row names.') + self.logger.info("Generate the row names.") estimator_names = list(self.estimator.keys()) num_states = len(self.estimator[estimator_names[0]].states_) - data_dict = {'name': [], - 'state': []} + data_dict = {"name": [], "state": []} for i in range(num_states - 1): - data_dict['name'].append(str(i) + ' -- ' + str(i+1)) - data_dict['state'].append('States') + data_dict["name"].append(str(i) + " -- " + str(i + 1)) + data_dict["state"].append("States") try: u_nk = self.u_nk_list[0] - stages = u_nk.reset_index('time').index.names - self.logger.info('use the stage name from u_nk') + stages = u_nk.reset_index("time").index.names + self.logger.info("use the stage name from u_nk") except: dHdl = self.dHdl_list[0] - stages = dHdl.reset_index('time').index.names - self.logger.info('use the stage name from dHdl') + stages = dHdl.reset_index("time").index.names + self.logger.info("use the stage name from dHdl") for stage in stages: - data_dict['name'].append(stage.split('-')[0]) - data_dict['state'].append('Stages') - data_dict['name'].append('TOTAL') - data_dict['state'].append('Stages') + data_dict["name"].append(stage.split("-")[0]) + data_dict["state"].append("Stages") + data_dict["name"].append("TOTAL") + data_dict["state"].append("Stages") col_names = [] for estimator_name, estimator in self.estimator.items(): - self.logger.info(f'Read the results from estimator {estimator_name}') + self.logger.info(f"Read the results from estimator {estimator_name}") # Do the unit conversion delta_f_ = estimator.delta_f_ @@ -499,26 +528,26 @@ def generate_result(self): # Write the estimator header col_names.append(estimator_name) - col_names.append(estimator_name + '_Error') + col_names.append(estimator_name + "_Error") data_dict[estimator_name] = [] - data_dict[estimator_name + '_Error'] = [] + data_dict[estimator_name + "_Error"] = [] for index in range(1, num_states): - data_dict[estimator_name].append( - delta_f_.iloc[index-1, index]) - data_dict[estimator_name + '_Error'].append( - d_delta_f_.iloc[index - 1, index]) + data_dict[estimator_name].append(delta_f_.iloc[index - 1, index]) + data_dict[estimator_name + "_Error"].append( + d_delta_f_.iloc[index - 1, index] + ) - self.logger.info(f'Generate the staged result from estimator {estimator_name}') + self.logger.info( + f"Generate the staged result from estimator {estimator_name}" + ) for index, stage in enumerate(stages): if len(stages) == 1: start = 0 end = len(estimator.states_) - 1 else: # Get the start and the end of the state - lambda_min = min( - [state[index] for state in estimator.states_]) - lambda_max = max( - [state[index] for state in estimator.states_]) + lambda_min = min([state[index] for state in estimator.states_]) + lambda_max = max([state[index] for state in estimator.states_]) if lambda_min == lambda_max: # Deal with the case where a certain lambda is used but # not perturbed @@ -529,35 +558,39 @@ def generate_result(self): start = list(reversed(states)).index(lambda_min) start = num_states - start - 1 end = states.index(lambda_max) - self.logger.info( - f'Stage {stage} is from state {start} to state {end}.') + self.logger.info(f"Stage {stage} is from state {start} to state {end}.") # This assumes that the indexes are sorted as the # preprocessing should sort the index of the df. result = delta_f_.iloc[start, end] - if estimator_name != 'BAR': + if estimator_name != "BAR": error = d_delta_f_.iloc[start, end] else: - error = np.sqrt(sum( - [d_delta_f_.iloc[start, start+1]**2 - for i in range(start, end + 1)])) + error = np.sqrt( + sum( + [ + d_delta_f_.iloc[start, start + 1] ** 2 + for i in range(start, end + 1) + ] + ) + ) data_dict[estimator_name].append(result) - data_dict[estimator_name + '_Error'].append(error) + data_dict[estimator_name + "_Error"].append(error) # Total result # This assumes that the indexes are sorted as the # preprocessing should sort the index of the df. result = delta_f_.iloc[0, -1] - if estimator_name != 'BAR': + if estimator_name != "BAR": error = d_delta_f_.iloc[0, -1] else: - error = np.sqrt(sum( - [d_delta_f_.iloc[i, i + 1] ** 2 - for i in range(num_states - 1)])) + error = np.sqrt( + sum([d_delta_f_.iloc[i, i + 1] ** 2 for i in range(num_states - 1)]) + ) data_dict[estimator_name].append(result) - data_dict[estimator_name + '_Error'].append(error) + data_dict[estimator_name + "_Error"].append(error) summary = pd.DataFrame.from_dict(data_dict) - summary = summary.set_index(['state', 'name']) + summary = summary.set_index(["state", "name"]) # Make sure that the columns are in the right order summary = summary[col_names] # Remove the name of the index column to make it prettier @@ -567,11 +600,11 @@ def generate_result(self): converter = get_unit_converter(self.units) summary = converter(summary) self.summary = summary - self.logger.info(f'Write results:\n{summary.to_string()}') + self.logger.info(f"Write results:\n{summary.to_string()}") return summary - def plot_overlap_matrix(self, overlap='O_MBAR.pdf', ax=None): - '''Plot the overlap matrix for MBAR estimator using + def plot_overlap_matrix(self, overlap="O_MBAR.pdf", ax=None): + """Plot the overlap matrix for MBAR estimator using :func:`~alchemlyb.visualisation.plot_mbar_overlap_matrix`. Parameters @@ -586,21 +619,20 @@ def plot_overlap_matrix(self, overlap='O_MBAR.pdf', ax=None): ------- matplotlib.axes.Axes An axes with the overlap matrix drawn. - ''' - self.logger.info('Plot overlap matrix.') - if 'MBAR' in self.estimator: - ax = plot_mbar_overlap_matrix(self.estimator['MBAR'].overlap_matrix, - ax=ax) + """ + self.logger.info("Plot overlap matrix.") + if "MBAR" in self.estimator: + ax = plot_mbar_overlap_matrix(self.estimator["MBAR"].overlap_matrix, ax=ax) ax.figure.savefig(join(self.out, overlap)) - self.logger.info(f'Plot overlap matrix to {self.out} under {overlap}.') + self.logger.info(f"Plot overlap matrix to {self.out} under {overlap}.") return ax else: - self.logger.warning('MBAR estimator not found. ' - 'Overlap matrix not plotted.') + self.logger.warning( + "MBAR estimator not found. " "Overlap matrix not plotted." + ) - def plot_ti_dhdl(self, dhdl_TI='dhdl_TI.pdf', labels=None, colors=None, - ax=None): - '''Plot the dHdl for TI estimator using + def plot_ti_dhdl(self, dhdl_TI="dhdl_TI.pdf", labels=None, colors=None, ax=None): + """Plot the dHdl for TI estimator using :func:`~alchemlyb.visualisation.plot_ti_dhdl`. Parameters @@ -620,20 +652,31 @@ def plot_ti_dhdl(self, dhdl_TI='dhdl_TI.pdf', labels=None, colors=None, ------- matplotlib.axes.Axes An axes with the TI dhdl drawn. - ''' - self.logger.info('Plot TI dHdl.') - if 'TI' in self.estimator: - ax = plot_ti_dhdl(self.estimator['TI'], units=self.units, - labels=labels, colors=colors, ax=ax) + """ + self.logger.info("Plot TI dHdl.") + if "TI" in self.estimator: + ax = plot_ti_dhdl( + self.estimator["TI"], + units=self.units, + labels=labels, + colors=colors, + ax=ax, + ) ax.figure.savefig(join(self.out, dhdl_TI)) - self.logger.info(f'Plot TI dHdl to {dhdl_TI} under {self.out}.') + self.logger.info(f"Plot TI dHdl to {dhdl_TI} under {self.out}.") return ax else: - raise ValueError('No TI data available in estimators.') - - def plot_dF_state(self, dF_state='dF_state.pdf', labels=None, colors=None, - orientation='portrait', nb=10): - '''Plot the dF states using + raise ValueError("No TI data available in estimators.") + + def plot_dF_state( + self, + dF_state="dF_state.pdf", + labels=None, + colors=None, + orientation="portrait", + nb=10, + ): + """Plot the dF states using :func:`~alchemlyb.visualisation.plot_dF_state`. Parameters @@ -653,18 +696,24 @@ def plot_dF_state(self, dF_state='dF_state.pdf', labels=None, colors=None, ------- matplotlib.figure.Figure An Figure with the dF states drawn. - ''' - self.logger.info('Plot dF states.') - fig = plot_dF_state(self.estimator.values(), labels=labels, colors=colors, - units=self.units, - orientation=orientation, nb=nb) + """ + self.logger.info("Plot dF states.") + fig = plot_dF_state( + self.estimator.values(), + labels=labels, + colors=colors, + units=self.units, + orientation=orientation, + nb=nb, + ) fig.savefig(join(self.out, dF_state)) - self.logger.info(f'Plot dF state to {dF_state} under {self.out}.') + self.logger.info(f"Plot dF state to {dF_state} under {self.out}.") return fig - def check_convergence(self, forwrev, estimator='MBAR', dF_t='dF_t.pdf', - ax=None, **kwargs): - '''Compute the forward and backward convergence using + def check_convergence( + self, forwrev, estimator="MBAR", dF_t="dF_t.pdf", ax=None, **kwargs + ): + """Compute the forward and backward convergence using :func:`~alchemlyb.convergence.forward_backward_convergence`and plot with :func:`~alchemlyb.visualisation.plot_convergence`. @@ -701,59 +750,63 @@ def check_convergence(self, forwrev, estimator='MBAR', dF_t='dF_t.pdf', of :class:`~alchemlyb.estimators.MBAR`. (:code:`check_convergence(10, estimator='MBAR', method='adaptive')`) - ''' - self.logger.info('Start convergence analysis.') - self.logger.info('Checking data availability.') + """ + self.logger.info("Start convergence analysis.") + self.logger.info("Checking data availability.") if estimator in FEP_ESTIMATORS: if self.u_nk_sample_list is not None: u_nk_list = self.u_nk_sample_list - self.logger.info('Subsampled u_nk is available.') + self.logger.info("Subsampled u_nk is available.") else: if self.u_nk_list is not None: u_nk_list = self.u_nk_list - self.logger.info('Subsampled u_nk not available, ' - 'use original data instead.') + self.logger.info( + "Subsampled u_nk not available, " "use original data instead." + ) else: - msg = f"u_nk is needed for the f{estimator} estimator. " \ - f"If the dataset only has dHdl, " \ - f"run ABFE.check_convergence(estimator='TI') to " \ - f"use a TI estimator." + msg = ( + f"u_nk is needed for the f{estimator} estimator. " + f"If the dataset only has dHdl, " + f"run ABFE.check_convergence(estimator='TI') to " + f"use a TI estimator." + ) self.logger.error(msg) raise ValueError(msg) - convergence = forward_backward_convergence(u_nk_list, - estimator=estimator, - num=forwrev, **kwargs) + convergence = forward_backward_convergence( + u_nk_list, estimator=estimator, num=forwrev, **kwargs + ) elif estimator in TI_ESTIMATORS: - self.logger.warning('No valid FEP estimator or dataset found. ' - 'Fallback to TI.') + self.logger.warning( + "No valid FEP estimator or dataset found. " "Fallback to TI." + ) if self.dHdl_sample_list is not None: dHdl_list = self.dHdl_sample_list - self.logger.info('Subsampled dHdl is available.') + self.logger.info("Subsampled dHdl is available.") else: if self.dHdl_list is not None: dHdl_list = self.dHdl_list - self.logger.info('Subsampled dHdl not available, ' - 'use original data instead.') + self.logger.info( + "Subsampled dHdl not available, " "use original data instead." + ) else: - self.logger.error( - f'dHdl is needed for the f{estimator} estimator.') - raise ValueError( - f'dHdl is needed for the f{estimator} estimator.') - convergence = forward_backward_convergence(dHdl_list, - estimator=estimator, - num=forwrev, **kwargs) + self.logger.error(f"dHdl is needed for the f{estimator} estimator.") + raise ValueError(f"dHdl is needed for the f{estimator} estimator.") + convergence = forward_backward_convergence( + dHdl_list, estimator=estimator, num=forwrev, **kwargs + ) else: - msg = f"Estimator {estimator} is not supported. Choose one from " \ - f"{FEP_ESTIMATORS+TI_ESTIMATORS}." + msg = ( + f"Estimator {estimator} is not supported. Choose one from " + f"{FEP_ESTIMATORS + TI_ESTIMATORS}." + ) self.logger.error(msg) raise ValueError(msg) self.convergence = get_unit_converter(self.units)(convergence) - self.logger.info(f'Plot convergence analysis to {dF_t} under {self.out}.') + self.logger.info(f"Plot convergence analysis to {dF_t} under {self.out}.") - ax = plot_convergence(self.convergence, - units=self.units, ax=ax) + ax = plot_convergence(self.convergence, units=self.units, ax=ax) ax.figure.savefig(join(self.out, dF_t)) return ax diff --git a/src/alchemlyb/workflows/base.py b/src/alchemlyb/workflows/base.py index 1728f7f0..1b4cbe41 100644 --- a/src/alchemlyb/workflows/base.py +++ b/src/alchemlyb/workflows/base.py @@ -2,7 +2,8 @@ import pandas as pd -class WorkflowBase(): + +class WorkflowBase: """The base class for the Workflow. This is the base class for the creation of new Workflow. The @@ -37,9 +38,10 @@ class WorkflowBase(): .. versionadded:: 0.7.0 """ - def __init__(self, units='kT', software='Gromacs', T=298, out='./', *args, - **kwargs): + def __init__( + self, units="kT", software="Gromacs", T=298, out="./", *args, **kwargs + ): self.T = T self.software = software self.unit = units @@ -47,7 +49,7 @@ def __init__(self, units='kT', software='Gromacs', T=298, out='./', *args, self.out = out def run(self, *args, **kwargs): - """ Run the workflow in an automatic fashion. + """Run the workflow in an automatic fashion. This method would execute the :func:`~alchemlyb.workflows.WorkflowBase.read`, @@ -88,7 +90,7 @@ def run(self, *args, **kwargs): self.plot(*args, **kwargs) def read(self, *args, **kwargs): - """ The function that reads the files in `file_list` and parse them + """The function that reads the files in `file_list` and parse them into u_nk and dHdl files. Attributes @@ -104,7 +106,7 @@ def read(self, *args, **kwargs): self.dHdl_list = [] def preprocess(self, *args, **kwargs): - """ The function that subsample the u_nk and dHdl in `u_nk_list` and + """The function that subsample the u_nk and dHdl in `u_nk_list` and `dHdl_list`. Attributes @@ -120,7 +122,7 @@ def preprocess(self, *args, **kwargs): self.u_nk_sample_list = [] def estimate(self, *args, **kwargs): - """ The function that runs the estimator based on `u_nk_sample_list` + """The function that runs the estimator based on `u_nk_sample_list` and `dHdl_sample_list`. Attributes @@ -133,7 +135,7 @@ def estimate(self, *args, **kwargs): self.result = pd.DataFrame() def check_convergence(self, *args, **kwargs): - """ The function for doing convergence analysis. + """The function for doing convergence analysis. Attributes ---------- @@ -145,7 +147,5 @@ def check_convergence(self, *args, **kwargs): self.convergence = pd.DataFrame() def plot(self, *args, **kwargs): - """ The function for producing any plots. - - """ + """The function for producing any plots.""" pass From 26a48984a4cf2d6eb88afdd350758b011e17cb30 Mon Sep 17 00:00:00 2001 From: "William (Zhiyi) Wu" Date: Tue, 6 Dec 2022 10:51:23 +0000 Subject: [PATCH 21/21] fix --- src/alchemlyb/tests/test_preprocessing.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/alchemlyb/tests/test_preprocessing.py b/src/alchemlyb/tests/test_preprocessing.py index da54cf5d..ea8231ab 100644 --- a/src/alchemlyb/tests/test_preprocessing.py +++ b/src/alchemlyb/tests/test_preprocessing.py @@ -28,13 +28,11 @@ def _check_data_is_outside_bounds(data, lower, upper): assert any(data.reset_index()["time"] > upper) - @pytest.fixture() def dHdl(gmx_benzene_Coulomb_dHdl): return gmx_benzene_Coulomb_dHdl[0] - @pytest.fixture() def u_nk(gmx_benzene_Coulomb_u_nk): return gmx_benzene_Coulomb_u_nk[0] @@ -256,7 +254,7 @@ def test_conservative(self, dataloader, size, conservative, request): def test_raise_ValueError_for_mismatched_data(self, dataloader, end, step, request): data = request.getfixturevalue(dataloader) with pytest.raises(ValueError): - self.slicer(data, series=data[:end:step]) + self.slicer(data, series=data["fep"][:end:step]) @pytest.mark.parametrize( ("dataloader", "lower", "upper"),