New TransformationBase class (#2950)
Fixes #2996 

## Work done in this PR
- Adds a TransformationBase class
- `TransformationBase` uses threadpoolctl to limit the number of threads a transformation may use, which can improve performance.
- TransformationBase also includes a check for whether a transformation is parallelizable.
- Refactors existing transformation classes to use TransformationBase.
yuxuanzhuang authored Apr 10, 2021
1 parent b1d8bcc commit 13a5df0
Showing 17 changed files with 407 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/gh-ci.yaml
@@ -14,7 +14,7 @@ defaults:
shell: bash -l {0}

env:
MDA_CONDA_MIN_DEPS: "pip pytest==6.1.2 mmtf-python biopython networkx cython matplotlib-base scipy griddataformats hypothesis gsd codecov"
MDA_CONDA_MIN_DEPS: "pip pytest==6.1.2 mmtf-python biopython networkx cython matplotlib-base scipy griddataformats hypothesis gsd codecov threadpoolctl"
MDA_CONDA_EXTRA_DEPS: "seaborn>=0.7.0 clustalw=2.1 netcdf4 scikit-learn joblib>=0.12 chemfiles tqdm>=4.43.0 tidynamics>=1.0.0 rdkit>=2020.03.1 h5py"
MDA_PIP_MIN_DEPS: 'coveralls coverage<5 pytest-cov pytest-xdist'
MDA_PIP_EXTRA_DEPS: 'duecredit parmed'
4 changes: 2 additions & 2 deletions .travis.yml
@@ -25,7 +25,7 @@ matrix:
- conda activate base
- conda update --yes conda
- conda install --yes pip pytest==6.1.2 mmtf-python biopython networkx cython matplotlib-base "scipy>=1.6" griddataformats hypothesis gsd codecov -c conda-forge -c biobuilds
- pip install pytest-xdist
- pip install pytest-xdist threadpoolctl
- conda info
install:
- (cd package/ && python setup.py develop) && (cd testsuite/ && python setup.py install)
@@ -45,7 +45,7 @@ matrix:
if: type = cron
before_install:
- python -m pip install cython numpy scipy
- python -m pip install --no-build-isolation hypothesis matplotlib pytest pytest-cov pytest-xdist tqdm
- python -m pip install --no-build-isolation hypothesis matplotlib pytest pytest-cov pytest-xdist tqdm threadpoolctl
install:
- cd package
- python setup.py install
1 change: 1 addition & 0 deletions azure-pipelines.yml
@@ -59,6 +59,7 @@ jobs:
scipy
h5py
tqdm
threadpoolctl
displayName: 'Install dependencies'
# TODO: recent rdkit is not on PyPI
- script: >-
3 changes: 3 additions & 0 deletions package/CHANGELOG
@@ -156,6 +156,8 @@ Enhancements
'protein' selection (#2751 PR #2755)
* Added an RDKit converter that works for any input with all hydrogens
explicit in the topology (Issue #2468, PR #2775)
* Added `TransformationBase`, which implements a thread limit and adds an
attribute for checking whether a transformation can be used in parallel
analysis. (Issue #2996, PR #2950)

Changes
* Fixed inaccurate docstring inside the RMSD class (Issue #2796, PR #3134)
@@ -198,6 +200,7 @@ Changes
* deprecated ``analysis.helanal`` module has been removed in favour of
``analysis.helix_analysis`` (PR #2929)
* Move water bridge analysis from hbonds to hydrogenbonds (Issue #2739 PR #2913)
* `threadpoolctl` is required for installation (Issue #2860)

Deprecations

1 change: 1 addition & 0 deletions package/MDAnalysis/transformations/__init__.py
@@ -125,6 +125,7 @@ def wrapped(ts):
method instead of being written as a function/closure.
"""

from .base import TransformationBase
from .translate import translate, center_in_box
from .rotate import rotateby
from .positionaveraging import PositionAverager
124 changes: 124 additions & 0 deletions package/MDAnalysis/transformations/base.py
@@ -0,0 +1,124 @@
# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4
#
# MDAnalysis --- https://www.mdanalysis.org
# Copyright (c) 2006-2017 The MDAnalysis Development Team and contributors
# (see the file AUTHORS for the full list of names)
#
# Released under the GNU Public Licence, v2 or any higher version
#
# Please cite your use of MDAnalysis in published work:
#
# R. J. Gowers, M. Linke, J. Barnoud, T. J. E. Reddy, M. N. Melo, S. L. Seyler,
# D. L. Dotson, J. Domanski, S. Buchoux, I. M. Kenney, and O. Beckstein.
# MDAnalysis: A Python package for the rapid analysis of molecular dynamics
# simulations. In S. Benthall and S. Rostrup editors, Proceedings of the 15th
# Python in Science Conference, pages 102-109, Austin, TX, 2016. SciPy.
#
# N. Michaud-Agrawal, E. J. Denning, T. B. Woolf, and O. Beckstein.
# MDAnalysis: A Toolkit for the Analysis of Molecular Dynamics Simulations.
# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787
#

"""\
Transformations Base Class --- :mod:`MDAnalysis.transformations.base`
=====================================================================
.. autoclass:: TransformationBase
"""
from threadpoolctl import threadpool_limits


class TransformationBase(object):
    """Base class for defining on-the-fly transformations

    The class is designed as a template for creating on-the-fly
    Transformation classes. This class will

    1) set up a context manager that limits the number of threads used
    per call, for example by the multi-threaded OpenBLAS backend of NumPy.
    Such a backend may degrade performance when it uses hyperthreads
    or oversubscribes threads while running together with other parallel
    engines, e.g. Dask.
    (PR `#2950 <https://github.com/MDAnalysis/mdanalysis/pull/2950>`_)

    Define ``max_threads=1`` when that is the case.

    2) set up a boolean attribute `parallelizable` for checking if the
    transformation can be applied in a **split-apply-combine** parallel scheme.
    For example, the :class:`~MDAnalysis.transformations.positionaveraging.PositionAverager`
    is history-dependent and cannot be used in parallel analysis natively.
    (Issue `#2996 <https://github.com/MDAnalysis/mdanalysis/issues/2996>`_)

    To define a new Transformation, :class:`TransformationBase`
    has to be subclassed.
    ``max_threads`` will be set to ``None`` by default,
    i.e. no limit is imposed and any settings in the environment, such as
    the environment variable :envvar:`OMP_NUM_THREADS`
    (see the `OpenMP specification for OMP_NUM_THREADS <https://www.openmp.org/spec-html/5.0/openmpse50.html>`_),
    are used.
    ``parallelizable`` will be set to ``True`` by default.
    You may need to double check if the transformation can be used in
    parallel analysis; if not, override the value to ``False``.
    Note that this attribute is not checked anywhere in MDAnalysis yet.
    Developers of parallel analysis code have to check it themselves.

    .. code-block:: python

       class NewTransformation(TransformationBase):
           def __init__(self, ag, parameter,
                        max_threads=1, parallelizable=True):
               super().__init__(max_threads=max_threads,
                                parallelizable=parallelizable)
               self.ag = ag
               self._param = parameter

           def _transform(self, ts):
               # REQUIRED
               ts.positions = some_function(ts, self.ag, self._param)
               return ts

    Afterwards the new transformation can be run like this.

    .. code-block:: python

       new_transformation = NewTransformation(ag, param)
       u.trajectory.add_transformations(new_transformation)

    .. versionadded:: 2.0.0
       Add the base class for all transformations to limit threads and
       check if it can be used in parallel analysis.
    """

    def __init__(self, **kwargs):
        """
        Parameters
        ----------
        max_threads: int, optional
            The maximum number of threads that can be used.
            Default is ``None``, which means the default or the external
            setting is used.
        parallelizable: bool, optional
            Whether this transformation can be used in a split-apply-combine
            parallel analysis approach.
            Default is ``True``.
        """
        self.max_threads = kwargs.pop('max_threads', None)
        self.parallelizable = kwargs.pop('parallelizable', True)

    def __call__(self, ts):
        """Make the transformation callable like a function

        The thread limit works as a context manager with the given
        `max_threads`, wrapping the real :func:`_transform` function.
        """
        with threadpool_limits(self.max_threads):
            return self._transform(ts)

    def _transform(self, ts):
        """Transform the given `Timestep`

        It deals with the transformation of a single `Timestep`.
        """
        raise NotImplementedError("Only implemented in child classes")
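
The docstring above notes that `parallelizable` is not checked anywhere in MDAnalysis yet and that developers of parallel analysis code have to check it themselves. One way such a check might look is sketched below. This is an assumption for illustration only: `can_split_apply_combine` and `_Base` are hypothetical names, and treating objects without the attribute (such as plain functions) as not parallelizable is a conservative choice made for this sketch, not behaviour defined by this commit.

```python
class _Base:
    """Hypothetical stand-in exposing only the `parallelizable` attribute."""

    def __init__(self, parallelizable=True):
        self.parallelizable = parallelizable


def can_split_apply_combine(transformations):
    """Return True only if every transformation declares itself parallelizable.

    Objects lacking the attribute are treated as not parallelizable
    (an assumption of this sketch).
    """
    return all(getattr(t, "parallelizable", False) for t in transformations)


stateless = _Base(parallelizable=True)
# History-dependent, in the spirit of PositionAverager in this commit.
history_dependent = _Base(parallelizable=False)

ok = can_split_apply_combine([stateless])
blocked = can_split_apply_combine([stateless, history_dependent])
```

A parallel driver could run such a check once before splitting a trajectory and fall back to serial execution when it returns `False`.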
13 changes: 10 additions & 3 deletions package/MDAnalysis/transformations/boxdimensions.py
@@ -32,8 +32,10 @@
"""
import numpy as np

from .base import TransformationBase

class set_dimensions:

class set_dimensions(TransformationBase):
"""
Set simulation box dimensions.
@@ -63,7 +65,12 @@ class set_dimensions:
:class:`~MDAnalysis.coordinates.base.Timestep` object
"""

def __init__(self, dimensions):
def __init__(self,
dimensions,
max_threads=None,
parallelizable=True):
super().__init__(max_threads=max_threads,
parallelizable=parallelizable)
self.dimensions = dimensions

try:
@@ -82,6 +89,6 @@ def __init__(self, dimensions):
'[a, b, c, alpha, beta, gamma]')
raise ValueError(errmsg)

def __call__(self, ts):
def _transform(self, ts):
ts.dimensions = self.dimensions
return ts
42 changes: 34 additions & 8 deletions package/MDAnalysis/transformations/fit.py
@@ -37,8 +37,10 @@
from ..analysis import align
from ..lib.transformations import euler_from_matrix, euler_matrix

from .base import TransformationBase

class fit_translation(object):

class fit_translation(TransformationBase):
"""Translates a given AtomGroup so that its center of geometry/mass matches
the respective center of the given reference. A plane can be given by the
user using the option `plane`, and will result in the removal of
@@ -82,10 +84,17 @@ class fit_translation(object):
.. versionchanged:: 2.0.0
The transformation was changed from a function/closure to a class
with ``__call__``.
The transformation was changed from a function/closure to a class
with ``__call__``.
.. versionchanged:: 2.0.0
The transformation was changed to inherit from the base class for
limiting threads and checking if it can be used in parallel analysis.
"""
def __init__(self, ag, reference, plane=None, weights=None):
def __init__(self, ag, reference, plane=None, weights=None,
max_threads=None, parallelizable=True):
super().__init__(max_threads=max_threads,
parallelizable=parallelizable)

self.ag = ag
self.reference = reference
self.plane = plane
@@ -117,7 +126,7 @@ def __init__(self, ag, reference, plane=None, weights=None):
self.weights = align.get_weights(self.ref.atoms, weights=self.weights)
self.ref_com = self.ref.center(self.weights)

def __call__(self, ts):
def _transform(self, ts):
mobile_com = np.asarray(self.mobile.atoms.center(self.weights),
np.float32)
vector = self.ref_com - mobile_com
@@ -128,7 +137,7 @@ def __call__(self, ts):
return ts


class fit_rot_trans(object):
class fit_rot_trans(TransformationBase):
"""Perform a spatial superposition by minimizing the RMSD.
Spatially align the group of atoms `ag` to `reference` by doing a RMSD
@@ -140,6 +149,11 @@ class fit_rot_trans(object):
removed. This is useful for protein-membrane systems to where the membrane
must remain in the same orientation.
Note
----
``max_threads`` defaults to 1 for this transformation,
which performs better with a single thread.
Example
-------
Removing the translations and rotations of a given AtomGroup `ag` on the XY plane
@@ -174,8 +188,20 @@ class fit_rot_trans(object):
Returns
-------
MDAnalysis.coordinates.base.Timestep
.. versionchanged:: 2.0.0
The transformation was changed from a function/closure to a class
with ``__call__``.
.. versionchanged:: 2.0.0
The transformation was changed to inherit from the base class for
limiting threads and checking if it can be used in parallel analysis.
"""
def __init__(self, ag, reference, plane=None, weights=None):
def __init__(self, ag, reference, plane=None, weights=None,
max_threads=1, parallelizable=True):
super().__init__(max_threads=max_threads,
parallelizable=parallelizable)

self.ag = ag
self.reference = reference
self.plane = plane
@@ -207,7 +233,7 @@ def __init__(self, ag, reference, plane=None, weights=None):
self.ref_com = self.ref.center(self.weights)
self.ref_coordinates = self.ref.atoms.positions - self.ref_com

def __call__(self, ts):
def _transform(self, ts):
mobile_com = self.mobile.atoms.center(self.weights)
mobile_coordinates = self.mobile.atoms.positions - mobile_com
rotation, dump = align.rotation_matrix(mobile_coordinates,
19 changes: 14 additions & 5 deletions package/MDAnalysis/transformations/positionaveraging.py
@@ -35,8 +35,10 @@
import numpy as np
import warnings

from .base import TransformationBase

class PositionAverager(object):

class PositionAverager(TransformationBase):
"""
Averages the coordinates of a given timestep so that the coordinates
of the AtomGroup correspond to the average positions of the N previous
@@ -132,11 +134,18 @@ class PositionAverager(object):
Returns
-------
MDAnalysis.coordinates.base.Timestep
.. versionchanged:: 2.0.0
The transformation was changed to inherit from the base class for
limiting threads and checking if it can be used in parallel analysis.
"""

def __init__(self, avg_frames, check_reset=True):
def __init__(self, avg_frames, check_reset=True,
max_threads=None,
parallelizable=False):
super().__init__(max_threads=max_threads,
parallelizable=parallelizable)
self.avg_frames = avg_frames
self.check_reset = check_reset
self.current_avg = 0
@@ -162,7 +171,7 @@ def rollposx(self, ts):
self.coord_array = np.roll(self.coord_array, 1, axis=2)
self.coord_array[..., 0] = ts.positions.copy()

def __call__(self, ts):
def _transform(self, ts):
# calling the same timestep will not add new data to coord_array
# This prevents getting different values when
# `u.trajectory[i]` is called multiple times.