Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pickling of Atom, Residue, Segment and groups. #3953

Merged
merged 1 commit into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions package/AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ Chronological list of authors
- Jennifer A Clark
- Jake Fennick
- Utsav Khatu
- Patricio Barletta


External code
Expand Down
5 changes: 4 additions & 1 deletion package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@ The rules for this file:
* release numbers follow "Semantic Versioning" http://semver.org

------------------------------------------------------------------------------
??/??/?? IAlibay
??/??/?? IAlibay, pgbarletta

* 2.5.0

Fixes

Enhancements

* Add pickling support for Atom, Residue, Segment, ResidueGroup
pgbarletta marked this conversation as resolved.
Show resolved Hide resolved
and SegmentGroup. (PR #3953)

Changes

Deprecations
Expand Down
26 changes: 23 additions & 3 deletions package/MDAnalysis/core/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ def _unpickle(u, ix):
return u.atoms[ix]


# TODO 3.0: deprecate _unpickle in favor of _unpickle2.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a comment here, unless it changes user-facing behaviour there is no need to deprecate a private method like this. We can just throw it out right now and none would be the wiser.

Please raise a relevant issue regarding this though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

def _unpickle2(u, ix, cls):
hmacdope marked this conversation as resolved.
Show resolved Hide resolved
return cls(ix, u)


def _unpickle_uag(basepickle, selections, selstrs):
bfunc, bargs = basepickle[0], basepickle[1:][0]
basegroup = bfunc(*bargs)
Expand Down Expand Up @@ -2525,9 +2530,6 @@ class AtomGroup(GroupBase):
Indexing an AtomGroup with ``None`` raises a ``TypeError``.
"""

def __reduce__(self):
return (_unpickle, (self.universe, self.ix))

def __getattr__(self, attr):
# special-case timestep info
if attr in ('velocities', 'forces'):
Expand All @@ -2536,6 +2538,9 @@ def __getattr__(self, attr):
raise NoDataError('This Universe has no coordinates')
return super(AtomGroup, self).__getattr__(attr)

def __reduce__(self):
return (_unpickle, (self.universe, self.ix))

@property
def atoms(self):
"""The :class:`AtomGroup` itself.
Expand Down Expand Up @@ -3655,6 +3660,9 @@ class ResidueGroup(GroupBase):
Indexing an ResidueGroup with ``None`` raises a ``TypeError``.
"""

def __reduce__(self):
return (_unpickle2, (self.universe, self.ix, ResidueGroup))

@property
def atoms(self):
"""An :class:`AtomGroup` of :class:`Atoms<Atom>` present in this
Expand Down Expand Up @@ -3848,6 +3856,9 @@ class SegmentGroup(GroupBase):
Indexing an SegmentGroup with ``None`` raises a ``TypeError``.
"""

def __reduce__(self):
return (_unpickle2, (self.universe, self.ix, SegmentGroup))

@property
def atoms(self):
"""An :class:`AtomGroup` of :class:`Atoms<Atom>` present in this
Expand Down Expand Up @@ -4140,6 +4151,9 @@ def __repr__(self):
me += ' and altLoc {}'.format(self.altLoc)
return me + '>'

def __reduce__(self):
return (_unpickle2, (self.universe, self.ix, Atom))

def __getattr__(self, attr):
# special-case timestep info
ts = {'velocity': 'velocities', 'force': 'forces'}
Expand Down Expand Up @@ -4262,6 +4276,9 @@ def __repr__(self):

return me + '>'

def __reduce__(self):
return (_unpickle2, (self.universe, self.ix, Residue))

@property
def atoms(self):
"""An :class:`AtomGroup` of :class:`Atoms<Atom>` present in this
Expand Down Expand Up @@ -4312,6 +4329,9 @@ def __repr__(self):
me += ' {}'.format(self.segid)
return me + '>'

def __reduce__(self):
return (_unpickle2, (self.universe, self.ix, Segment))

@property
def atoms(self):
"""An :class:`AtomGroup` of :class:`Atoms<Atom>` present in this
Expand Down
12 changes: 10 additions & 2 deletions testsuite/MDAnalysisTests/core/test_atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
# MDAnalysis: A Toolkit for the Analysis of Molecular Dynamics Simulations.
# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787
#
import MDAnalysis as mda
import numpy as np
import pickle
import pytest
import numpy as np

import MDAnalysis as mda
from MDAnalysis import NoDataError
from MDAnalysisTests.datafiles import (
PSF, DCD,
Expand Down Expand Up @@ -113,6 +115,12 @@ def test_undefined_occupancy(self, universe):
with pytest.raises(AttributeError):
universe.atoms[0].occupancy

@pytest.mark.parametrize("ix", (1, -1))
def test_atom_pickle(self, universe, ix):
atm_out = universe.atoms[ix]
atm_in = pickle.loads(pickle.dumps(atm_out))
assert atm_in == atm_out


class TestAtomNoForceNoVel(object):
@staticmethod
Expand Down
15 changes: 15 additions & 0 deletions testsuite/MDAnalysisTests/core/test_atomgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from glob import glob
import itertools
from os import path
import pickle

import numpy as np

Expand Down Expand Up @@ -1789,3 +1790,17 @@ def test_sort_position(self, ag):
ref = [6, 5, 4, 3, 2, 1, 0]
agsort = ag.sort("positions", keyfunc=lambda x: x[:, 1])
assert np.array_equal(ref, agsort.ix)


class TestAtomGroupPickle(object):
"""Test AtomGroup pickling support."""

@pytest.fixture()
def universe(self):
return mda.Universe(PSF, DCD)

@pytest.mark.parametrize("selection", ("name CA", "segid 4AKE"))
def test_atomgroup_pickle(self, universe, selection):
sel = universe.select_atoms(selection)
atm = pickle.loads(pickle.dumps(sel))
assert_almost_equal(sel.positions, atm.positions)
7 changes: 6 additions & 1 deletion testsuite/MDAnalysisTests/core/test_residue.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
assert_equal,
)
import pytest

import pickle
import MDAnalysis as mda

from MDAnalysisTests.datafiles import PSF, DCD
Expand Down Expand Up @@ -54,3 +54,8 @@ def test_index(res):
def test_atom_order(res):
assert_equal(res.atoms.indices,
sorted(res.atoms.indices))


def test_residue_pickle(res):
res_in = pickle.loads(pickle.dumps(res))
assert res_in == res
9 changes: 8 additions & 1 deletion testsuite/MDAnalysisTests/core/test_residuegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787
#
import numpy as np
from numpy.testing import assert_equal
import pickle
from numpy.testing import assert_equal, assert_almost_equal
import pytest

import MDAnalysis as mda
Expand Down Expand Up @@ -281,3 +282,9 @@ def test_get_prev_residue(self, rg):
prev_resids = [r.resid if r is not None else None for r in prev_res]
assert_equal(len(prev_res), len(unsorted_rep_res))
assert_equal(prev_resids, resids)

@pytest.mark.parametrize("selection", ("name CA", "segid 4AKE"))
def test_residuegroup_pickle(self, universe, selection):
seg_res = universe.select_atoms(selection).residues
seg = pickle.loads(pickle.dumps(seg_res))
assert_almost_equal(seg_res.atoms.positions, seg.atoms.positions)
7 changes: 7 additions & 0 deletions testsuite/MDAnalysisTests/core/test_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
assert_equal,
)
import pytest
import pickle

import MDAnalysis as mda

Expand Down Expand Up @@ -62,3 +63,9 @@ def test_advanced_slicing(self, sB):
def test_atom_order(self, universe):
assert_equal(universe.segments[0].atoms.indices,
sorted(universe.segments[0].atoms.indices))

@pytest.mark.parametrize("ix", (1, -1))
def test_residue_pickle(self, universe, ix):
seg_out = universe.segments[ix]
seg_in = pickle.loads(pickle.dumps(seg_out))
assert seg_in == seg_out
17 changes: 14 additions & 3 deletions testsuite/MDAnalysisTests/core/test_segmentgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@
# MDAnalysis: A Toolkit for the Analysis of Molecular Dynamics Simulations.
# J. Comput. Chem. 32 (2011), 2319--2327, doi:10.1002/jcc.21787
#
from numpy.testing import (
assert_equal,
)
from numpy.testing import assert_equal
import pytest
import pickle

import MDAnalysis as mda

Expand Down Expand Up @@ -109,3 +108,15 @@ def test_set_segids_many():
def test_atom_order(universe):
assert_equal(universe.segments.atoms.indices,
sorted(universe.segments.atoms.indices))


def test_segmentgroup_pickle():
u = mda.Universe.empty(10)
u.add_Segment(segid="X")
u.add_Segment(segid="Y")
u.add_Segment(segid="Z")
segids = ["A", "X", "Y", "Z"]
u.add_TopologyAttr("segids", values=["A", "X", "Y", "Z"])
seg_group = mda.SegmentGroup((1, 3), u)
seg = pickle.loads(pickle.dumps(seg_group))
assert_equal(seg.universe.segments.segids, segids)