Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of Parallelization to MDAnalysis.analysis.contacts #4820

Open
wants to merge 19 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ The rules for this file:


-------------------------------------------------------------------------------
??/??/?? IAlibay, ChiahsinChu, RMeli
??/??/?? IAlibay, ChiahsinChu, RMeli, talagayev

* 2.9.0

Expand All @@ -23,6 +23,7 @@ Fixes
Enhancements
* Add check and warning for empty (all zero) coordinates in RDKit converter (PR #4824)
* Added `precision` for XYZWriter (Issue #4775, PR #4771)
* Enables parallelization for analysis.contacts.Contacts (Issue #4660)

Changes

Expand Down
46 changes: 41 additions & 5 deletions package/MDAnalysis/analysis/contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def is_any_closer(r, r0, dist=2.5):
from MDAnalysis.lib.util import openany
from MDAnalysis.analysis.distances import distance_array
from MDAnalysis.core.groups import AtomGroup, UpdatingAtomGroup
from .base import AnalysisBase
from .base import AnalysisBase, ResultsGroup

logger = logging.getLogger("MDAnalysis.analysis.contacts")

Expand Down Expand Up @@ -338,6 +338,28 @@ def contact_matrix(d, radius, out=None):
return out


def get_box(ts, pbc):
    """Retrieve the dimensions of the simulation box based on PBC.

    This helper is deliberately a plain module-level function rather than
    a lambda or a function defined inside ``Contacts.__init__``: local
    callables failed with ``AttributeError: Can't get local object ...``
    when the analysis was run through the parallel backends. It is meant
    to be bound with :func:`functools.partial` (``pbc`` fixed) and then
    called with a timestep only.

    Parameters
    ----------
    ts : Timestep
        The current timestep of the simulation, which contains the
        box dimensions.
    pbc : bool
        A flag indicating whether periodic boundary conditions (PBC)
        are enabled. If `True`, the box dimensions are returned,
        else returns `None`.

    Returns
    -------
    box_dimensions : ndarray or None
        ``ts.dimensions`` if `pbc` is truthy, else `None`.
    """
    return ts.dimensions if pbc else None


class Contacts(AnalysisBase):
"""Calculate contacts based observables.

Expand Down Expand Up @@ -376,8 +398,22 @@ class Contacts(AnalysisBase):
:class:`MDAnalysis.analysis.base.Results` instance.
.. versionchanged:: 2.2.0
:class:`Contacts` accepts both AtomGroup and string for `select`
.. versionchanged:: 2.9.0
Introduced :meth:`get_supported_backends` allowing
for parallel execution on :mod:`multiprocessing`
and :mod:`dask` backends.
"""

_analysis_algorithm_is_parallelizable = True

@classmethod
def get_supported_backends(cls):
    """Return the names of the execution backends this analysis supports.

    Returns
    -------
    tuple of str
        ``("serial", "multiprocessing", "dask")``
    """
    return ("serial", "multiprocessing", "dask")

def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5,
pbc=True, kwargs=None, **basekwargs):
"""
Expand Down Expand Up @@ -445,10 +481,8 @@ def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5,
self.initial_contacts = []

#get dimension of box if pbc set to True
if self.pbc:
self._get_box = lambda ts: ts.dimensions
else:
self._get_box = lambda ts: None
self.pbc = pbc
self._get_box = functools.partial(get_box, pbc=self.pbc)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self._get_box = functools.partial(get_box, pbc=self.pbc)
self._get_box = lambda ts: ts.dimensions if self.pbc else None

?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RMeli I tried running it with the suggestion locally and I get

AttributeError: Can't get local object 'Contacts.__init__.<locals>.<lambda>'

It seems to be caused by the parallelization having problems with lambdas. It also does not like this:

        self._get_box = functools.partial(
            lambda ts, pbc: ts.dimensions if pbc else None, pbc=self.pbc
        )

It leads to the same error. So yes, I agree that the function is quite general; would moving
def get_box() to be a function in the Contacts class be fine? That
works with the tests.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't put get_box() elsewhere because it's really a bit clunky and not really that useful except that it's needed here (or do we have other places where we need it??). I'd keep it as local and non-public as possible. Writing it as a lambda would be the nicest way but if that does not work (... good to know and a bit sad ...) then I'd make it a static private class method

@staticmethod
def _get_box_func(ts, pbc):
    """... keep the written docs ... """
    return ts.dimensions if pbc else None

which you can then use for the partial.

Add a comment to self._get_box = functools.partial(self._get_box_func, pbc=self.pbc) why all of this is necessary.


(I assume that the following will run into the same issue as lambdas:

def get_box(ts, pbc=self.pbc):
   return ts.dimensions if pbc else None
self._get_box = get_box

as it's also local, right?)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @talagayev for the explanation, I didn't know a lambda would not work in this case. I think defining a function instead of a lambda will run into the same issue, as you suspect @orbeckst, but I haven't tried. If we don't want to move the function elsewhere, I agree it should be private and not exposed by the module.

Copy link
Member Author

@talagayev talagayev Dec 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I assume that following will run into the same issue as lambdas:

def get_box(ts, pbc=self.pbc):
   return ts.dimensions if pbc else None
self._get_box = get_box

as it's also local, right?)

Yes, I just tried it locally; sadly, it leads to the error Can't get local object 'Contacts.__init__.<locals>.get_box'

Copy link
Member Author

@talagayev talagayev Dec 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But yes, making it a more local static method is a cleaner approach than the current one; I will adjust it then :)

as for it being used anywhere else, not really, it is only used here once

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@orbeckst @RMeli I have now moved it as you suggested and would like to re-request a review to see if it is correct now :)


if isinstance(refgroup[0], AtomGroup):
refA, refB = refgroup
Expand Down Expand Up @@ -506,6 +540,8 @@ def timeseries(self):
warnings.warn(wmsg, DeprecationWarning)
return self.results.timeseries

def _get_aggregator(self):
    """Return the :class:`ResultsGroup` describing how results are combined.

    The ``timeseries`` result is combined with
    :meth:`ResultsGroup.ndarray_vstack`.
    """
    return ResultsGroup(lookup={'timeseries': ResultsGroup.ndarray_vstack})

def _new_selections(u_orig, selections, frame):
"""create stand alone AGs from selections at frame"""
Expand Down
8 changes: 8 additions & 0 deletions testsuite/MDAnalysisTests/analysis/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from MDAnalysis.analysis.hydrogenbonds.hbond_analysis import (
HydrogenBondAnalysis,
)
from MDAnalysis.analysis.contacts import Contacts
from MDAnalysis.lib.util import is_installed


Expand Down Expand Up @@ -141,3 +142,10 @@ def client_DSSP(request):
@pytest.fixture(scope='module', params=params_for_cls(HydrogenBondAnalysis))
def client_HydrogenBondAnalysis(request):
return request.param


# MDAnalysis.analysis.contacts

@pytest.fixture(scope="module", params=params_for_cls(Contacts))
def client_Contacts(request):
return request.param
87 changes: 54 additions & 33 deletions testsuite/MDAnalysisTests/analysis/test_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ def universe():
return mda.Universe(PSF, DCD)

def _run_Contacts(
self, universe,
start=None, stop=None, step=None, **kwargs
self, universe, client_Contacts, start=None,
stop=None, step=None, **kwargs
):
acidic = universe.select_atoms(self.sel_acidic)
basic = universe.select_atoms(self.sel_basic)
Expand All @@ -181,7 +181,8 @@ def _run_Contacts(
select=(self.sel_acidic, self.sel_basic),
refgroup=(acidic, basic),
radius=6.0,
**kwargs).run(start=start, stop=stop, step=step)
**kwargs
).run(**client_Contacts, start=start, stop=stop, step=step)

@pytest.mark.parametrize("seltxt", [sel_acidic, sel_basic])
def test_select_valid_types(self, universe, seltxt):
Expand All @@ -195,7 +196,7 @@ def test_select_valid_types(self, universe, seltxt):

assert ag_from_string == ag_from_ag

def test_contacts_selections(self, universe):
def test_contacts_selections(self, universe, client_Contacts):
"""Test if Contacts can take both string and AtomGroup as selections.
"""
aga = universe.select_atoms(self.sel_acidic)
Expand All @@ -210,8 +211,8 @@ def test_contacts_selections(self, universe):
refgroup=(aga, agb)
)

cag.run()
csel.run()
cag.run(**client_Contacts)
csel.run(**client_Contacts)

assert cag.grA == csel.grA
assert cag.grB == csel.grB
Expand All @@ -228,26 +229,31 @@ def test_select_wrong_types(self, universe, ag):
) as te:
contacts.Contacts._get_atomgroup(universe, ag)

def test_startframe(self, universe):
def test_startframe(self, universe, client_Contacts):
"""test_startframe: TestContactAnalysis1: start frame set to 0 (resolution of
Issue #624)

"""
CA1 = self._run_Contacts(universe)
CA1 = self._run_Contacts(universe, client_Contacts=client_Contacts)
assert len(CA1.results.timeseries) == universe.trajectory.n_frames

def test_end_zero(self, universe):
def test_end_zero(self, universe, client_Contacts):
"""test_end_zero: TestContactAnalysis1: stop frame 0 is not ignored"""
CA1 = self._run_Contacts(universe, stop=0)
CA1 = self._run_Contacts(
universe, client_Contacts=client_Contacts, stop=0
)
assert len(CA1.results.timeseries) == 0

def test_slicing(self, universe):
def test_slicing(self, universe, client_Contacts):
start, stop, step = 10, 30, 5
CA1 = self._run_Contacts(universe, start=start, stop=stop, step=step)
CA1 = self._run_Contacts(
universe, client_Contacts=client_Contacts,
start=start, stop=stop, step=step
)
frames = np.arange(universe.trajectory.n_frames)[start:stop:step]
assert len(CA1.results.timeseries) == len(frames)

def test_villin_folded(self):
def test_villin_folded(self, client_Contacts):
# one folded, one unfolded
f = mda.Universe(contacts_villin_folded)
u = mda.Universe(contacts_villin_unfolded)
Expand All @@ -259,12 +265,12 @@ def test_villin_folded(self):
select=(sel, sel),
refgroup=(grF, grF),
method="soft_cut")
q.run()
q.run(**client_Contacts)

results = soft_cut(f, u, sel, sel)
assert_allclose(q.results.timeseries[:, 1], results[:, 1], rtol=0, atol=1.5e-7)

def test_villin_unfolded(self):
def test_villin_unfolded(self, client_Contacts):
# both folded
f = mda.Universe(contacts_villin_folded)
u = mda.Universe(contacts_villin_folded)
Expand All @@ -276,13 +282,13 @@ def test_villin_unfolded(self):
select=(sel, sel),
refgroup=(grF, grF),
method="soft_cut")
q.run()
q.run(**client_Contacts)

results = soft_cut(f, u, sel, sel)
assert_allclose(q.results.timeseries[:, 1], results[:, 1], rtol=0, atol=1.5e-7)

def test_hard_cut_method(self, universe):
ca = self._run_Contacts(universe)
def test_hard_cut_method(self, universe, client_Contacts):
ca = self._run_Contacts(universe, client_Contacts=client_Contacts)
expected = [1., 0.58252427, 0.52427184, 0.55339806, 0.54368932,
0.54368932, 0.51456311, 0.46601942, 0.48543689, 0.52427184,
0.46601942, 0.58252427, 0.51456311, 0.48543689, 0.48543689,
Expand All @@ -306,7 +312,7 @@ def test_hard_cut_method(self, universe):
assert len(ca.results.timeseries) == len(expected)
assert_allclose(ca.results.timeseries[:, 1], expected, rtol=0, atol=1.5e-7)

def test_radius_cut_method(self, universe):
def test_radius_cut_method(self, universe, client_Contacts):
acidic = universe.select_atoms(self.sel_acidic)
basic = universe.select_atoms(self.sel_basic)
r = contacts.distance_array(acidic.positions, basic.positions)
Expand All @@ -316,15 +322,20 @@ def test_radius_cut_method(self, universe):
r = contacts.distance_array(acidic.positions, basic.positions)
expected.append(contacts.radius_cut_q(r[initial_contacts], None, radius=6.0))

ca = self._run_Contacts(universe, method='radius_cut')
ca = self._run_Contacts(
universe, client_Contacts=client_Contacts, method="radius_cut"
)
assert_array_equal(ca.results.timeseries[:, 1], expected)

@staticmethod
def _is_any_closer(r, r0, dist=2.5):
return np.any(r < dist)

def test_own_method(self, universe):
ca = self._run_Contacts(universe, method=self._is_any_closer)
def test_own_method(self, universe, client_Contacts):
ca = self._run_Contacts(
universe, client_Contacts=client_Contacts,
method=self._is_any_closer
)

bound_expected = [1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 0., 0.,
1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1.,
Expand All @@ -340,21 +351,28 @@ def test_own_method(self, universe):
def _weird_own_method(r, r0):
return 'aaa'

def test_own_method_no_array_cast(self, universe):
def test_own_method_no_array_cast(self, universe, client_Contacts):
with pytest.raises(ValueError):
self._run_Contacts(universe, method=self._weird_own_method, stop=2)

def test_non_callable_method(self, universe):
self._run_Contacts(
universe,
client_Contacts=client_Contacts,
method=self._weird_own_method,
stop=2,
)

def test_non_callable_method(self, universe, client_Contacts):
with pytest.raises(ValueError):
self._run_Contacts(universe, method=2, stop=2)
self._run_Contacts(
universe, client_Contacts=client_Contacts, method=2, stop=2
)

@pytest.mark.parametrize("pbc,expected", [
(True, [1., 0.43138152, 0.3989021, 0.43824337, 0.41948765,
0.42223239, 0.41354071, 0.43641354, 0.41216834, 0.38334858]),
(False, [1., 0.42327791, 0.39192399, 0.40950119, 0.40902613,
0.42470309, 0.41140143, 0.42897862, 0.41472684, 0.38574822])
])
def test_distance_box(self, pbc, expected):
def test_distance_box(self, pbc, expected, client_Contacts):
u = mda.Universe(TPR, XTC)
sel_basic = "(resname ARG LYS)"
sel_acidic = "(resname ASP GLU)"
Expand All @@ -363,13 +381,15 @@ def test_distance_box(self, pbc, expected):

r = contacts.Contacts(u, select=(sel_acidic, sel_basic),
refgroup=(acidic, basic), radius=6.0, pbc=pbc)
r.run()
r.run(**client_Contacts)
assert_allclose(r.results.timeseries[:, 1], expected,rtol=0, atol=1.5e-7)

def test_warn_deprecated_attr(self, universe):
def test_warn_deprecated_attr(self, universe, client_Contacts):
"""Test for warning message emitted on using deprecated `timeseries`
attribute"""
CA1 = self._run_Contacts(universe, stop=1)
CA1 = self._run_Contacts(
universe, client_Contacts=client_Contacts, stop=1
)
wmsg = "The `timeseries` attribute was deprecated in MDAnalysis"
with pytest.warns(DeprecationWarning, match=wmsg):
assert_equal(CA1.timeseries, CA1.results.timeseries)
Expand All @@ -385,10 +405,11 @@ def test_n_initial_contacts(self, datafiles, expected):
r = contacts.Contacts(u, select=select, refgroup=refgroup)
assert_equal(r.n_initial_contacts, expected)

def test_q1q2():

def test_q1q2(client_Contacts):
u = mda.Universe(PSF, DCD)
q1q2 = contacts.q1q2(u, 'name CA', radius=8)
q1q2.run()
q1q2.run(**client_Contacts)

q1_expected = [1., 0.98092643, 0.97366031, 0.97275204, 0.97002725,
0.97275204, 0.96276113, 0.96730245, 0.9582198, 0.96185286,
Expand Down
Loading