From 404cd45ec75d5530e54a2ea10d4c7390115b5de7 Mon Sep 17 00:00:00 2001 From: Patricio Date: Tue, 3 Jan 2023 09:06:14 +0100 Subject: [PATCH] Add pickling of Atom, Residue, Segment and groups. Add reduce() method to both ResidueGroup and SegmentGroup Change _unpickle() so it can create any kind of group from its indices. Add one test for each group in the parallelism section. --- package/AUTHORS | 1 + package/CHANGELOG | 5 +++- package/MDAnalysis/core/groups.py | 26 ++++++++++++++++--- testsuite/MDAnalysisTests/core/test_atom.py | 12 +++++++-- .../MDAnalysisTests/core/test_atomgroup.py | 15 +++++++++++ .../MDAnalysisTests/core/test_residue.py | 7 ++++- .../MDAnalysisTests/core/test_residuegroup.py | 9 ++++++- .../MDAnalysisTests/core/test_segment.py | 7 +++++ .../MDAnalysisTests/core/test_segmentgroup.py | 17 +++++++++--- 9 files changed, 88 insertions(+), 11 deletions(-) diff --git a/package/AUTHORS b/package/AUTHORS index 2eeb0f0d937..e28256e0ba1 100644 --- a/package/AUTHORS +++ b/package/AUTHORS @@ -200,6 +200,7 @@ Chronological list of authors - Jennifer A Clark - Jake Fennick - Utsav Khatu + - Patricio Barletta External code diff --git a/package/CHANGELOG b/package/CHANGELOG index 9cc41c4cdae..0dea541bdde 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -13,7 +13,7 @@ The rules for this file: * release numbers follow "Semantic Versioning" http://semver.org ------------------------------------------------------------------------------ -??/??/?? IAlibay +??/??/?? IAlibay, pgbarletta * 2.5.0 @@ -21,6 +21,9 @@ Fixes Enhancements +* Add pickling support for Atom, Residue, Segment, ResidueGroup + and SegmentGroup. (PR #3953) + Changes Deprecations diff --git a/package/MDAnalysis/core/groups.py b/package/MDAnalysis/core/groups.py index b0e56abf94b..d3600d1c95e 100644 --- a/package/MDAnalysis/core/groups.py +++ b/package/MDAnalysis/core/groups.py @@ -119,6 +119,11 @@ def _unpickle(u, ix): return u.atoms[ix] +# TODO 3.0: deprecate _unpickle in favor of _unpickle2. +def _unpickle2(u, ix, cls): + return cls(ix, u) + + def _unpickle_uag(basepickle, selections, selstrs): bfunc, bargs = basepickle[0], basepickle[1:][0] basegroup = bfunc(*bargs) @@ -2525,9 +2530,6 @@ class AtomGroup(GroupBase): Indexing an AtomGroup with ``None`` raises a ``TypeError``. """ - def __reduce__(self): - return (_unpickle, (self.universe, self.ix)) - def __getattr__(self, attr): # special-case timestep info if attr in ('velocities', 'forces'): @@ -2536,6 +2538,9 @@ def __getattr__(self, attr): raise NoDataError('This Universe has no coordinates') return super(AtomGroup, self).__getattr__(attr) + def __reduce__(self): + return (_unpickle, (self.universe, self.ix)) + @property def atoms(self): """The :class:`AtomGroup` itself. @@ -3655,6 +3660,9 @@ class ResidueGroup(GroupBase): Indexing an ResidueGroup with ``None`` raises a ``TypeError``. """ + def __reduce__(self): + return (_unpickle2, (self.universe, self.ix, ResidueGroup)) + @property def atoms(self): """An :class:`AtomGroup` of :class:`Atoms` present in this @@ -3848,6 +3856,9 @@ class SegmentGroup(GroupBase): Indexing an SegmentGroup with ``None`` raises a ``TypeError``. """ + def __reduce__(self): + return (_unpickle2, (self.universe, self.ix, SegmentGroup)) + @property def atoms(self): """An :class:`AtomGroup` of :class:`Atoms` present in this @@ -4140,6 +4151,9 @@ def __repr__(self): me += ' and altLoc {}'.format(self.altLoc) return me + '>' + def __reduce__(self): + return (_unpickle2, (self.universe, self.ix, Atom)) + def __getattr__(self, attr): # special-case timestep info ts = {'velocity': 'velocities', 'force': 'forces'} @@ -4262,6 +4276,9 @@ def __repr__(self): return me + '>' + def __reduce__(self): + return (_unpickle2, (self.universe, self.ix, Residue)) + @property def atoms(self): """An :class:`AtomGroup` of :class:`Atoms` present in this @@ -4312,6 +4329,9 @@ def __repr__(self): me += ' {}'.format(self.segid) return me + '>' + def __reduce__(self): + return (_unpickle2, (self.universe, self.ix, Segment)) + @property def atoms(self): """An :class:`AtomGroup` of :class:`Atoms` present in this diff --git a/testsuite/MDAnalysisTests/core/test_atom.py b/testsuite/MDAnalysisTests/core/test_atom.py index e1d2f58b2c0..35d5d67477a 100644 --- a/testsuite/MDAnalysisTests/core/test_atom.py +++ b/testsuite/MDAnalysisTests/core/test_atom.py @@ -20,9 +20,11 @@ # MDAnalysis: A Toolkit for the Analysis of Molecular Dynamics Simulations. # J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787 # -import MDAnalysis as mda -import numpy as np +import pickle import pytest +import numpy as np + +import MDAnalysis as mda from MDAnalysis import NoDataError from MDAnalysisTests.datafiles import ( PSF, DCD, @@ -113,6 +115,12 @@ def test_undefined_occupancy(self, universe): with pytest.raises(AttributeError): universe.atoms[0].occupancy + @pytest.mark.parametrize("ix", (1, -1)) + def test_atom_pickle(self, universe, ix): + atm_out = universe.atoms[ix] + atm_in = pickle.loads(pickle.dumps(atm_out)) + assert atm_in == atm_out + class TestAtomNoForceNoVel(object): @staticmethod diff --git a/testsuite/MDAnalysisTests/core/test_atomgroup.py b/testsuite/MDAnalysisTests/core/test_atomgroup.py index 2dcff29f7e4..dade39305c8 100644 --- a/testsuite/MDAnalysisTests/core/test_atomgroup.py +++ b/testsuite/MDAnalysisTests/core/test_atomgroup.py @@ -23,6 +23,7 @@ from glob import glob import itertools from os import path +import pickle import numpy as np @@ -1789,3 +1790,17 @@ def test_sort_position(self, ag): ref = [6, 5, 4, 3, 2, 1, 0] agsort = ag.sort("positions", keyfunc=lambda x: x[:, 1]) assert np.array_equal(ref, agsort.ix) + + +class TestAtomGroupPickle(object): + """Test AtomGroup pickling support.""" + + @pytest.fixture() + def universe(self): + return mda.Universe(PSF, DCD) + + @pytest.mark.parametrize("selection", ("name CA", "segid 4AKE")) + def test_atomgroup_pickle(self, universe, selection): + sel = universe.select_atoms(selection) + atm = pickle.loads(pickle.dumps(sel)) + assert_almost_equal(sel.positions, atm.positions) diff --git a/testsuite/MDAnalysisTests/core/test_residue.py b/testsuite/MDAnalysisTests/core/test_residue.py index 546eb80f4b6..68cd0f28268 100644 --- a/testsuite/MDAnalysisTests/core/test_residue.py +++ b/testsuite/MDAnalysisTests/core/test_residue.py @@ -24,7 +24,7 @@ assert_equal, ) import pytest - +import pickle import MDAnalysis as mda from MDAnalysisTests.datafiles import PSF, DCD @@ -54,3 +54,8 @@ def test_index(res): def test_atom_order(res): assert_equal(res.atoms.indices, sorted(res.atoms.indices)) + + +def test_residue_pickle(res): + res_in = pickle.loads(pickle.dumps(res)) + assert res_in == res diff --git a/testsuite/MDAnalysisTests/core/test_residuegroup.py b/testsuite/MDAnalysisTests/core/test_residuegroup.py index 835d4379806..1b661964b72 100644 --- a/testsuite/MDAnalysisTests/core/test_residuegroup.py +++ b/testsuite/MDAnalysisTests/core/test_residuegroup.py @@ -21,7 +21,8 @@ # J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787 # import numpy as np -from numpy.testing import assert_equal +import pickle +from numpy.testing import assert_equal, assert_almost_equal import pytest import MDAnalysis as mda @@ -281,3 +282,9 @@ def test_get_prev_residue(self, rg): prev_resids = [r.resid if r is not None else None for r in prev_res] assert_equal(len(prev_res), len(unsorted_rep_res)) assert_equal(prev_resids, resids) + + @pytest.mark.parametrize("selection", ("name CA", "segid 4AKE")) + def test_residuegroup_pickle(self, universe, selection): + seg_res = universe.select_atoms(selection).residues + seg = pickle.loads(pickle.dumps(seg_res)) + assert_almost_equal(seg_res.atoms.positions, seg.atoms.positions) diff --git a/testsuite/MDAnalysisTests/core/test_segment.py b/testsuite/MDAnalysisTests/core/test_segment.py index e83eacefe18..3e0c675d5a3 100644 --- a/testsuite/MDAnalysisTests/core/test_segment.py +++ b/testsuite/MDAnalysisTests/core/test_segment.py @@ -24,6 +24,7 @@ assert_equal, ) import pytest +import pickle import MDAnalysis as mda @@ -62,3 +63,9 @@ def test_advanced_slicing(self, sB): def test_atom_order(self, universe): assert_equal(universe.segments[0].atoms.indices, sorted(universe.segments[0].atoms.indices)) + + @pytest.mark.parametrize("ix", (1, -1)) + def test_residue_pickle(self, universe, ix): + seg_out = universe.segments[ix] + seg_in = pickle.loads(pickle.dumps(seg_out)) + assert seg_in == seg_out diff --git a/testsuite/MDAnalysisTests/core/test_segmentgroup.py b/testsuite/MDAnalysisTests/core/test_segmentgroup.py index 546c5ad44cf..c47f1f6367d 100644 --- a/testsuite/MDAnalysisTests/core/test_segmentgroup.py +++ b/testsuite/MDAnalysisTests/core/test_segmentgroup.py @@ -20,10 +20,9 @@ # MDAnalysis: A Toolkit for the Analysis of Molecular Dynamics Simulations. # J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787 # -from numpy.testing import ( - assert_equal, -) +from numpy.testing import assert_equal import pytest +import pickle import MDAnalysis as mda @@ -109,3 +108,15 @@ def test_set_segids_many(): def test_atom_order(universe): assert_equal(universe.segments.atoms.indices, sorted(universe.segments.atoms.indices)) + + +def test_segmentgroup_pickle(): + u = mda.Universe.empty(10) + u.add_Segment(segid="X") + u.add_Segment(segid="Y") + u.add_Segment(segid="Z") + segids = ["A", "X", "Y", "Z"] + u.add_TopologyAttr("segids", values=["A", "X", "Y", "Z"]) + seg_group = mda.SegmentGroup((1, 3), u) + seg = pickle.loads(pickle.dumps(seg_group)) + assert_equal(seg.universe.segments.segids, segids)