Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of Parallelization to MDAnalysis.analysis.contacts #4820

Open
wants to merge 19 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ Fixes
the function to prevent shared state. (Issue #4655)

Enhancements
* Enables parallelization for analysis.contacts.Contacts (Issue #4660)
* Enable parallelization for analysis.nucleicacids.NucPairDist (Issue #4670)
* Add check and warning for empty (all zero) coordinates in RDKit converter (PR #4824)
* Added `precision` for XYZWriter (Issue #4775, PR #4771)


Changes

Deprecations
Expand Down
48 changes: 41 additions & 7 deletions package/MDAnalysis/analysis/contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def is_any_closer(r, r0, dist=2.5):
from MDAnalysis.lib.util import openany
from MDAnalysis.analysis.distances import distance_array
from MDAnalysis.core.groups import AtomGroup, UpdatingAtomGroup
from .base import AnalysisBase
from .base import AnalysisBase, ResultsGroup

logger = logging.getLogger("MDAnalysis.analysis.contacts")

Expand Down Expand Up @@ -376,8 +376,22 @@ class Contacts(AnalysisBase):
:class:`MDAnalysis.analysis.base.Results` instance.
.. versionchanged:: 2.2.0
:class:`Contacts` accepts both AtomGroup and string for `select`
.. versionchanged:: 2.9.0
Introduced :meth:`get_supported_backends` allowing
for parallel execution on :mod:`multiprocessing`
and :mod:`dask` backends.
"""

_analysis_algorithm_is_parallelizable = True

@classmethod
def get_supported_backends(cls):
return (
"serial",
"multiprocessing",
"dask",
)

def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5,
pbc=True, kwargs=None, **basekwargs):
"""
Expand Down Expand Up @@ -444,11 +458,8 @@ def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5,
self.r0 = []
self.initial_contacts = []

#get dimension of box if pbc set to True
if self.pbc:
self._get_box = lambda ts: ts.dimensions
else:
self._get_box = lambda ts: None
# get dimensions via partial for parallelization compatibility
self._get_box = functools.partial(self._get_box_func, pbc=self.pbc)

if isinstance(refgroup[0], AtomGroup):
refA, refB = refgroup
Expand All @@ -464,7 +475,6 @@ def __init__(self, u, select, refgroup, method="hard_cut", radius=4.5,

self.n_initial_contacts = self.initial_contacts[0].sum()


@staticmethod
def _get_atomgroup(u, sel):
select_error_message = ("selection must be either string or a "
Expand All @@ -480,6 +490,28 @@ def _get_atomgroup(u, sel):
else:
raise TypeError(select_error_message)

@staticmethod
def _get_box_func(ts, pbc):
"""Retrieve the dimensions of the simulation box based on PBC.

Parameters
----------
ts : Timestep
The current timestep of the simulation, which contains the
box dimensions.
pbc : bool
A flag indicating whether periodic boundary conditions (PBC)
are enabled. If `True`, the box dimensions are returned,
else returns `None`.

Returns
-------
box_dimensions : ndarray or None
The dimensions of the simulation box as a NumPy array if PBC
is True, else returns `None`.
"""
return ts.dimensions if pbc else None

def _prepare(self):
self.results.timeseries = np.empty((self.n_frames, len(self.r0)+1))

Expand All @@ -506,6 +538,8 @@ def timeseries(self):
warnings.warn(wmsg, DeprecationWarning)
return self.results.timeseries

def _get_aggregator(self):
return ResultsGroup(lookup={'timeseries': ResultsGroup.ndarray_vstack})

def _new_selections(u_orig, selections, frame):
"""create stand alone AGs from selections at frame"""
Expand Down
8 changes: 8 additions & 0 deletions testsuite/MDAnalysisTests/analysis/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
HydrogenBondAnalysis,
)
from MDAnalysis.analysis.nucleicacids import NucPairDist
from MDAnalysis.analysis.contacts import Contacts
from MDAnalysis.lib.util import is_installed


Expand Down Expand Up @@ -149,3 +150,10 @@ def client_HydrogenBondAnalysis(request):
@pytest.fixture(scope="module", params=params_for_cls(NucPairDist))
def client_NucPairDist(request):
return request.param


# MDAnalysis.analysis.contacts

@pytest.fixture(scope="module", params=params_for_cls(Contacts))
def client_Contacts(request):
return request.param
87 changes: 54 additions & 33 deletions testsuite/MDAnalysisTests/analysis/test_contacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ def universe():
return mda.Universe(PSF, DCD)

def _run_Contacts(
self, universe,
start=None, stop=None, step=None, **kwargs
self, universe, client_Contacts, start=None,
stop=None, step=None, **kwargs
):
acidic = universe.select_atoms(self.sel_acidic)
basic = universe.select_atoms(self.sel_basic)
Expand All @@ -181,7 +181,8 @@ def _run_Contacts(
select=(self.sel_acidic, self.sel_basic),
refgroup=(acidic, basic),
radius=6.0,
**kwargs).run(start=start, stop=stop, step=step)
**kwargs
).run(**client_Contacts, start=start, stop=stop, step=step)

@pytest.mark.parametrize("seltxt", [sel_acidic, sel_basic])
def test_select_valid_types(self, universe, seltxt):
Expand All @@ -195,7 +196,7 @@ def test_select_valid_types(self, universe, seltxt):

assert ag_from_string == ag_from_ag

def test_contacts_selections(self, universe):
def test_contacts_selections(self, universe, client_Contacts):
"""Test if Contacts can take both string and AtomGroup as selections.
"""
aga = universe.select_atoms(self.sel_acidic)
Expand All @@ -210,8 +211,8 @@ def test_contacts_selections(self, universe):
refgroup=(aga, agb)
)

cag.run()
csel.run()
cag.run(**client_Contacts)
csel.run(**client_Contacts)

assert cag.grA == csel.grA
assert cag.grB == csel.grB
Expand All @@ -228,26 +229,31 @@ def test_select_wrong_types(self, universe, ag):
) as te:
contacts.Contacts._get_atomgroup(universe, ag)

def test_startframe(self, universe):
def test_startframe(self, universe, client_Contacts):
"""test_startframe: TestContactAnalysis1: start frame set to 0 (resolution of
Issue #624)

"""
CA1 = self._run_Contacts(universe)
CA1 = self._run_Contacts(universe, client_Contacts=client_Contacts)
assert len(CA1.results.timeseries) == universe.trajectory.n_frames

def test_end_zero(self, universe):
def test_end_zero(self, universe, client_Contacts):
"""test_end_zero: TestContactAnalysis1: stop frame 0 is not ignored"""
CA1 = self._run_Contacts(universe, stop=0)
CA1 = self._run_Contacts(
universe, client_Contacts=client_Contacts, stop=0
)
assert len(CA1.results.timeseries) == 0

def test_slicing(self, universe):
def test_slicing(self, universe, client_Contacts):
start, stop, step = 10, 30, 5
CA1 = self._run_Contacts(universe, start=start, stop=stop, step=step)
CA1 = self._run_Contacts(
universe, client_Contacts=client_Contacts,
start=start, stop=stop, step=step
)
frames = np.arange(universe.trajectory.n_frames)[start:stop:step]
assert len(CA1.results.timeseries) == len(frames)

def test_villin_folded(self):
def test_villin_folded(self, client_Contacts):
# one folded, one unfolded
f = mda.Universe(contacts_villin_folded)
u = mda.Universe(contacts_villin_unfolded)
Expand All @@ -259,12 +265,12 @@ def test_villin_folded(self):
select=(sel, sel),
refgroup=(grF, grF),
method="soft_cut")
q.run()
q.run(**client_Contacts)

results = soft_cut(f, u, sel, sel)
assert_allclose(q.results.timeseries[:, 1], results[:, 1], rtol=0, atol=1.5e-7)

def test_villin_unfolded(self):
def test_villin_unfolded(self, client_Contacts):
# both folded
f = mda.Universe(contacts_villin_folded)
u = mda.Universe(contacts_villin_folded)
Expand All @@ -276,13 +282,13 @@ def test_villin_unfolded(self):
select=(sel, sel),
refgroup=(grF, grF),
method="soft_cut")
q.run()
q.run(**client_Contacts)

results = soft_cut(f, u, sel, sel)
assert_allclose(q.results.timeseries[:, 1], results[:, 1], rtol=0, atol=1.5e-7)

def test_hard_cut_method(self, universe):
ca = self._run_Contacts(universe)
def test_hard_cut_method(self, universe, client_Contacts):
ca = self._run_Contacts(universe, client_Contacts=client_Contacts)
expected = [1., 0.58252427, 0.52427184, 0.55339806, 0.54368932,
0.54368932, 0.51456311, 0.46601942, 0.48543689, 0.52427184,
0.46601942, 0.58252427, 0.51456311, 0.48543689, 0.48543689,
Expand All @@ -306,7 +312,7 @@ def test_hard_cut_method(self, universe):
assert len(ca.results.timeseries) == len(expected)
assert_allclose(ca.results.timeseries[:, 1], expected, rtol=0, atol=1.5e-7)

def test_radius_cut_method(self, universe):
def test_radius_cut_method(self, universe, client_Contacts):
acidic = universe.select_atoms(self.sel_acidic)
basic = universe.select_atoms(self.sel_basic)
r = contacts.distance_array(acidic.positions, basic.positions)
Expand All @@ -316,15 +322,20 @@ def test_radius_cut_method(self, universe):
r = contacts.distance_array(acidic.positions, basic.positions)
expected.append(contacts.radius_cut_q(r[initial_contacts], None, radius=6.0))

ca = self._run_Contacts(universe, method='radius_cut')
ca = self._run_Contacts(
universe, client_Contacts=client_Contacts, method="radius_cut"
)
assert_array_equal(ca.results.timeseries[:, 1], expected)

@staticmethod
def _is_any_closer(r, r0, dist=2.5):
return np.any(r < dist)

def test_own_method(self, universe):
ca = self._run_Contacts(universe, method=self._is_any_closer)
def test_own_method(self, universe, client_Contacts):
ca = self._run_Contacts(
universe, client_Contacts=client_Contacts,
method=self._is_any_closer
)

bound_expected = [1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 1., 0., 0.,
1., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1.,
Expand All @@ -340,21 +351,28 @@ def test_own_method(self, universe):
def _weird_own_method(r, r0):
return 'aaa'

def test_own_method_no_array_cast(self, universe):
def test_own_method_no_array_cast(self, universe, client_Contacts):
with pytest.raises(ValueError):
self._run_Contacts(universe, method=self._weird_own_method, stop=2)

def test_non_callable_method(self, universe):
self._run_Contacts(
universe,
client_Contacts=client_Contacts,
method=self._weird_own_method,
stop=2,
)

def test_non_callable_method(self, universe, client_Contacts):
with pytest.raises(ValueError):
self._run_Contacts(universe, method=2, stop=2)
self._run_Contacts(
universe, client_Contacts=client_Contacts, method=2, stop=2
)

@pytest.mark.parametrize("pbc,expected", [
(True, [1., 0.43138152, 0.3989021, 0.43824337, 0.41948765,
0.42223239, 0.41354071, 0.43641354, 0.41216834, 0.38334858]),
(False, [1., 0.42327791, 0.39192399, 0.40950119, 0.40902613,
0.42470309, 0.41140143, 0.42897862, 0.41472684, 0.38574822])
])
def test_distance_box(self, pbc, expected):
def test_distance_box(self, pbc, expected, client_Contacts):
u = mda.Universe(TPR, XTC)
sel_basic = "(resname ARG LYS)"
sel_acidic = "(resname ASP GLU)"
Expand All @@ -363,13 +381,15 @@ def test_distance_box(self, pbc, expected):

r = contacts.Contacts(u, select=(sel_acidic, sel_basic),
refgroup=(acidic, basic), radius=6.0, pbc=pbc)
r.run()
r.run(**client_Contacts)
assert_allclose(r.results.timeseries[:, 1], expected,rtol=0, atol=1.5e-7)

def test_warn_deprecated_attr(self, universe):
def test_warn_deprecated_attr(self, universe, client_Contacts):
"""Test for warning message emitted on using deprecated `timeseries`
attribute"""
CA1 = self._run_Contacts(universe, stop=1)
CA1 = self._run_Contacts(
universe, client_Contacts=client_Contacts, stop=1
)
wmsg = "The `timeseries` attribute was deprecated in MDAnalysis"
with pytest.warns(DeprecationWarning, match=wmsg):
assert_equal(CA1.timeseries, CA1.results.timeseries)
Expand All @@ -385,10 +405,11 @@ def test_n_initial_contacts(self, datafiles, expected):
r = contacts.Contacts(u, select=select, refgroup=refgroup)
assert_equal(r.n_initial_contacts, expected)

def test_q1q2():

def test_q1q2(client_Contacts):
u = mda.Universe(PSF, DCD)
q1q2 = contacts.q1q2(u, 'name CA', radius=8)
q1q2.run()
q1q2.run(**client_Contacts)

q1_expected = [1., 0.98092643, 0.97366031, 0.97275204, 0.97002725,
0.97275204, 0.96276113, 0.96730245, 0.9582198, 0.96185286,
Expand Down
Loading