Parallel read and preprocess the data in ABFE workflow (#371)
- Use joblib to parallelise the read and preprocess steps in the ABFE workflow
- add joblib as dependency
- new kwarg n_jobs for run() (and read()/preprocess()); the default n_jobs=1
  maintains the previous serial behavior
- add tests
- update CHANGES (start 2.5.0)
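
A minimal usage sketch of the new kwarg (the data directory and file prefix
below are illustrative):

    from alchemlyb.workflows import ABFE

    # n_jobs=-1 lets joblib use all available cores for the read and
    # decorrelation stages; the default n_jobs=1 keeps the previous
    # serial behavior.
    workflow = ABFE(units="kcal/mol", software="GROMACS", dir="./abfe_data",
                    prefix="dhdl", suffix="xvg", T=298)
    workflow.run(n_jobs=-1)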
xiki-tempula authored Oct 6, 2024
1 parent 0093905 commit 30ecce0
Showing 7 changed files with 172 additions and 27 deletions.
8 changes: 8 additions & 0 deletions CHANGES
@@ -12,6 +12,14 @@ The rules for this file:
* accompany each entry with github issue/PR number (Issue #xyz)
* release numbers follow "Semantic Versioning" https://semver.org

**/**/**** xiki-tempula

* 2.5.0

Enhancements
- Parallelise read and preprocess for ABFE workflow. (PR #371)


09/19/2024 orbeckst, jaclark5

* 2.4.1
1 change: 1 addition & 0 deletions devtools/conda-envs/test_env.yaml
@@ -11,6 +11,7 @@ dependencies:
- matplotlib>=3.7
- loguru
- pyarrow
- joblib

# Testing
- pytest
37 changes: 37 additions & 0 deletions docs/workflows/alchemlyb.workflows.ABFE.rst
@@ -153,6 +153,43 @@ to the data generated at each stage of the analysis. ::
>>> # Convergence analysis
>>> workflow.check_convergence(10, dF_t='dF_t.pdf')

Parallelisation of Data Reading and Decorrelation
-------------------------------------------------

The estimation step of the workflow is parallelised using JAX. The reading
and decorrelation stages, however, are parallelised with `joblib`, by passing
the number of jobs to run in parallel via the `n_jobs` parameter to the
following methods:

- :meth:`~alchemlyb.workflows.ABFE.read`
- :meth:`~alchemlyb.workflows.ABFE.preprocess`

To enable parallel execution, specify the `n_jobs` parameter. Setting
`n_jobs=-1` uses all available CPU cores. ::

>>> workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir,
...                 prefix='dhdl', suffix='xvg', T=298, outdirectory='./')
>>> workflow.read(n_jobs=-1)
>>> workflow.preprocess(n_jobs=-1)

In fully automated mode, `n_jobs=-1` can be passed directly to the
:meth:`~alchemlyb.workflows.ABFE.run` method, which implicitly parallelises
the reading and decorrelation stages. ::

>>> workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir,
...                 prefix='dhdl', suffix='xvg', T=298, outdirectory='./')
>>> workflow.run(n_jobs=-1)

While the default `joblib` settings are suitable for most environments, the
parallelisation backend can be customised to match your infrastructure; the
default backend is "loky". For example, the threading backend can be selected
as follows. ::

>>> import joblib
>>> workflow = ABFE(units='kcal/mol', software='GROMACS', dir=dir,
...                 prefix='dhdl', suffix='xvg', T=298, outdirectory='./')
>>> with joblib.parallel_config(backend="threading"):
...     workflow.run(n_jobs=-1)

API Reference
-------------
.. autoclass:: alchemlyb.workflows.ABFE
1 change: 1 addition & 0 deletions environment.yml
@@ -11,3 +11,4 @@ dependencies:
- pyarrow
- matplotlib>=3.7
- loguru
- joblib
1 change: 1 addition & 0 deletions pyproject.toml
@@ -50,6 +50,7 @@ dependencies = [
"matplotlib>=3.7",
"loguru",
"pyarrow",
"joblib",
]


69 changes: 67 additions & 2 deletions src/alchemlyb/tests/test_workflow_ABFE.py
@@ -1,9 +1,11 @@
import os

import numpy as np
import pandas as pd
import pytest
from alchemtest.amber import load_bace_example
from alchemtest.gmx import load_ABFE
import joblib

import alchemlyb.parsing.amber
from alchemlyb.workflows.abfe import ABFE
@@ -88,15 +90,23 @@ def test_single_estimator(self, workflow, monkeypatch):
        monkeypatch.setattr(workflow, "dHdl_sample_list", [])
        monkeypatch.setattr(workflow, "estimator", dict())
        workflow.run(
-            uncorr=None, estimators="MBAR", overlap=None, breakdown=True, forwrev=None
+            uncorr=None,
+            estimators="MBAR",
+            overlap=None,
+            breakdown=True,
+            forwrev=None,
        )
        assert "MBAR" in workflow.estimator

    @pytest.mark.parametrize("forwrev", [None, False, 0])
    def test_no_forwrev(self, workflow, monkeypatch, forwrev):
        monkeypatch.setattr(workflow, "convergence", None)
        workflow.run(
-            uncorr=None, estimators=None, overlap=None, breakdown=None, forwrev=forwrev
+            uncorr=None,
+            estimators=None,
+            overlap=None,
+            breakdown=None,
+            forwrev=forwrev,
        )
        assert workflow.convergence is None

@@ -445,3 +455,58 @@ def test_summary(self, workflow):
"""Test if if the summary is right."""
summary = workflow.generate_result()
assert np.isclose(summary["BAR"]["Stages"]["TOTAL"], 1.40405980473, 0.1)


class TestParallel:
    @pytest.fixture(scope="class")
    def workflow(self, tmp_path_factory):
        outdir = tmp_path_factory.mktemp("out")
        (outdir / "dhdl_00.xvg").symlink_to(load_ABFE()["data"]["complex"][0])
        (outdir / "dhdl_01.xvg").symlink_to(load_ABFE()["data"]["complex"][1])
        workflow = ABFE(
            units="kcal/mol",
            software="GROMACS",
            dir=str(outdir),
            prefix="dhdl",
            suffix="xvg",
            T=310,
        )
        workflow.read()
        workflow.preprocess()
        return workflow

    @pytest.fixture(scope="class")
    def parallel_workflow(self, tmp_path_factory):
        outdir = tmp_path_factory.mktemp("out")
        (outdir / "dhdl_00.xvg").symlink_to(load_ABFE()["data"]["complex"][0])
        (outdir / "dhdl_01.xvg").symlink_to(load_ABFE()["data"]["complex"][1])
        workflow = ABFE(
            units="kcal/mol",
            software="GROMACS",
            dir=str(outdir),
            prefix="dhdl",
            suffix="xvg",
            T=310,
        )
        with joblib.parallel_config(backend="threading"):
            # The default backend is "loky", which is more robust but does not
            # play well with pytest; "loky" is perfectly fine outside pytest.
            workflow.read(n_jobs=2)
            workflow.preprocess(n_jobs=2)
        return workflow

    def test_read(self, workflow, parallel_workflow):
        pd.testing.assert_frame_equal(
            workflow.u_nk_list[0], parallel_workflow.u_nk_list[0]
        )
        pd.testing.assert_frame_equal(
            workflow.u_nk_list[1], parallel_workflow.u_nk_list[1]
        )

    def test_preprocess(self, workflow, parallel_workflow):
        pd.testing.assert_frame_equal(
            workflow.u_nk_sample_list[0], parallel_workflow.u_nk_sample_list[0]
        )
        pd.testing.assert_frame_equal(
            workflow.u_nk_sample_list[1], parallel_workflow.u_nk_sample_list[1]
        )
82 changes: 57 additions & 25 deletions src/alchemlyb/workflows/abfe.py
@@ -6,6 +6,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import joblib
from loguru import logger

from .base import WorkflowBase
@@ -115,7 +116,7 @@ def __init__(
        else:
            raise NotImplementedError(f"{software} parser not found.")

-    def read(self, read_u_nk=True, read_dHdl=True):
+    def read(self, read_u_nk=True, read_dHdl=True, n_jobs=1):
"""Read the u_nk and dHdL data from the
:attr:`~alchemlyb.workflows.ABFE.file_list`
@@ -125,6 +126,9 @@ def read(self, read_u_nk=True, read_dHdl=True):
            Whether to read the u_nk.
        read_dHdl : bool
            Whether to read the dHdl.
        n_jobs : int
            Number of parallel workers to use for reading the data
            (-1 means using all available CPUs).

        Attributes
        ----------
@@ -136,29 +140,43 @@ def read(self, read_u_nk=True, read_dHdl=True):
        self.u_nk_sample_list = None
        self.dHdl_sample_list = None

-        u_nk_list = []
-        dHdl_list = []
-        for file in self.file_list:
-            if read_u_nk:
+        if read_u_nk:
+
+            def extract_u_nk(_extract_u_nk, file, T):
                try:
-                    u_nk = self._extract_u_nk(file, T=self.T)
+                    u_nk = _extract_u_nk(file, T)
                    logger.info(f"Reading {len(u_nk)} lines of u_nk from {file}")
-                    u_nk_list.append(u_nk)
+                    return u_nk
                except Exception as exc:
                    msg = f"Error reading u_nk from {file}."
                    logger.error(msg)
                    raise OSError(msg) from exc

-            if read_dHdl:
+            u_nk_list = joblib.Parallel(n_jobs=n_jobs)(
+                joblib.delayed(extract_u_nk)(self._extract_u_nk, file, self.T)
+                for file in self.file_list
+            )
+        else:
+            u_nk_list = []
+
+        if read_dHdl:
+
+            def extract_dHdl(_extract_dHdl, file, T):
                try:
-                    dhdl = self._extract_dHdl(file, T=self.T)
+                    dhdl = _extract_dHdl(file, T)
                    logger.info(f"Reading {len(dhdl)} lines of dhdl from {file}")
-                    dHdl_list.append(dhdl)
+                    return dhdl
                except Exception as exc:
                    msg = f"Error reading dHdl from {file}."
                    logger.error(msg)
                    raise OSError(msg) from exc

+            dHdl_list = joblib.Parallel(n_jobs=n_jobs)(
+                joblib.delayed(extract_dHdl)(self._extract_dHdl, file, self.T)
+                for file in self.file_list
+            )
+        else:
+            dHdl_list = []
+
        # Sort the files according to the state
        if read_u_nk:
            logger.info("Sort files according to the u_nk.")
@@ -201,6 +219,7 @@ def run(
overlap="O_MBAR.pdf",
breakdown=True,
forwrev=None,
n_jobs=1,
*args,
**kwargs,
):
@@ -236,6 +255,9 @@
            contain u_nk, please run
            :meth:`~alchemlyb.workflows.ABFE.check_convergence` manually
            with estimator='TI'.
        n_jobs : int
            Number of parallel workers to use for reading and decorrelating
            the data (-1 means using all available CPUs).

        Attributes
        ----------
@@ -266,11 +288,12 @@
            )
            logger.error(msg)
            raise ValueError(msg)

-        self.read(use_FEP, use_TI)
+        self.read(read_u_nk=use_FEP, read_dHdl=use_TI, n_jobs=n_jobs)

        if uncorr is not None:
-            self.preprocess(skiptime=skiptime, uncorr=uncorr, threshold=threshold)
+            self.preprocess(
+                skiptime=skiptime, uncorr=uncorr, threshold=threshold, n_jobs=n_jobs
+            )
        if estimators is not None:
            self.estimate(estimators)
            self.generate_result()
@@ -307,7 +330,7 @@ def update_units(self, units=None):
        logger.info(f"Set unit to {units}.")
        self.units = units or None

-    def preprocess(self, skiptime=0, uncorr="dE", threshold=50):
+    def preprocess(self, skiptime=0, uncorr="dE", threshold=50, n_jobs=1):
"""Preprocess the data by removing the equilibration time and
decorrelate the date.
@@ -322,6 +345,9 @@ def preprocess(self, skiptime=0, uncorr="dE", threshold=50):
            Proceed with correlated samples if the number of uncorrelated
            samples is found to be less than this number. If 0 is given, the
            time series analysis will not be performed at all. Default: 50.
        n_jobs : int
            Number of parallel workers to use for decorrelating the data
            (-1 means using all available CPUs).

        Attributes
        ----------
@@ -338,33 +364,34 @@ def preprocess(self, skiptime=0, uncorr="dE", threshold=50):
        if len(self.u_nk_list) > 0:
            logger.info(f"Processing the u_nk data set with skiptime of {skiptime}.")

-            self.u_nk_sample_list = []
-            for index, u_nk in enumerate(self.u_nk_list):
-                # Find the starting frame
+
+            def _decorrelate_u_nk(u_nk, skiptime, threshold, index):
                u_nk = u_nk[u_nk.index.get_level_values("time") >= skiptime]
                subsample = decorrelate_u_nk(u_nk, uncorr, remove_burnin=True)

                if len(subsample) < threshold:
                    logger.warning(
                        f"Number of u_nk {len(subsample)} "
                        f"for state {index} is less than the "
                        f"threshold {threshold}."
                    )
                    logger.info(f"Take all the u_nk for state {index}.")
-                    self.u_nk_sample_list.append(u_nk)
+                    subsample = u_nk
                else:
                    logger.info(
                        f"Take {len(subsample)} uncorrelated "
                        f"u_nk for state {index}."
                    )
-                    self.u_nk_sample_list.append(subsample)
+                return subsample
+
+            self.u_nk_sample_list = joblib.Parallel(n_jobs=n_jobs)(
+                joblib.delayed(_decorrelate_u_nk)(u_nk, skiptime, threshold, index)
+                for index, u_nk in enumerate(self.u_nk_list)
+            )
        else:
            logger.info("No u_nk data being subsampled")

        if len(self.dHdl_list) > 0:
-            self.dHdl_sample_list = []
-            for index, dHdl in enumerate(self.dHdl_list):
+
+            def _decorrelate_dhdl(dHdl, skiptime, threshold, index):
                dHdl = dHdl[dHdl.index.get_level_values("time") >= skiptime]
                subsample = decorrelate_dhdl(dHdl, remove_burnin=True)
                if len(subsample) < threshold:
@@ -374,13 +401,18 @@ def preprocess(self, skiptime=0, uncorr="dE", threshold=50):
                        f"threshold {threshold}."
                    )
                    logger.info(f"Take all the dHdl for state {index}.")
-                    self.dHdl_sample_list.append(dHdl)
+                    subsample = dHdl
                else:
                    logger.info(
                        f"Take {len(subsample)} uncorrelated "
                        f"dHdl for state {index}."
                    )
-                    self.dHdl_sample_list.append(subsample)
+                return subsample
+
+            self.dHdl_sample_list = joblib.Parallel(n_jobs=n_jobs)(
+                joblib.delayed(_decorrelate_dhdl)(dHdl, skiptime, threshold, index)
+                for index, dHdl in enumerate(self.dHdl_list)
+            )
        else:
            logger.info("No dHdl data being subsampled")
