diff --git a/mdsynthesis/core/aggregators.py b/mdsynthesis/core/aggregators.py index c58f8a9..b042bc2 100644 --- a/mdsynthesis/core/aggregators.py +++ b/mdsynthesis/core/aggregators.py @@ -9,6 +9,7 @@ """ from MDAnalysis import Universe +from MDAnalysis.core.AtomGroup import AtomGroup import os import numpy as np from functools import wraps @@ -561,10 +562,9 @@ def __setitem__(self, handle, selection): """Selection for the given handle and the active universe. """ - if isinstance(selection, basestring): + if isinstance(selection, (basestring, AtomGroup)): selection = [selection] - self._backend.add_selection( - self._container._uname, handle, *selection) + self.add(handle, *selection) def __iter__(self): return self._backend.list_selections( @@ -593,12 +593,16 @@ def add(self, handle, *selection): *handle* name to use for the selection *selection* - selection string; multiple strings may be given and their - order will be preserved, which is useful for e.g. structural - alignments + selection string or AtomGroup; multiple selections may be given + and their order will be preserved, which is useful for e.g. + structural alignments """ + # Conversion function, leave strings alone, + # turn AtomGroups into their indices + conv = lambda x: x if isinstance(x, basestring) else x.indices() + self._backend.add_selection( - self._container._uname, handle, *selection) + self._container._uname, handle, *map(conv, selection)) def remove(self, *handle): """Remove an atom selection for the attached universe. @@ -643,7 +647,13 @@ def asAtomGroup(self, handle): raise KeyError( "No such selection '{}'; add it first.".format(handle)) - return self._container.universe.selectAtoms(*selstring) + # Selections might be either + # - a list of strings + # - a numpy array of indices + if isinstance(selstring[0], basestring): + return self._container.universe.selectAtoms(*selstring) + else: + return self._container.universe.atoms[selstring] def define(self, handle): """Get selection definition for given handle and the active universe. diff --git a/mdsynthesis/core/persistence.py b/mdsynthesis/core/persistence.py index d3f126b..aee1106 100644 --- a/mdsynthesis/core/persistence.py +++ b/mdsynthesis/core/persistence.py @@ -703,6 +703,12 @@ class _Selection(tables.IsDescription): """ selection = tables.StringCol(255) + class _SelectionAtoms(tables.IsDescription): + """Table definition for storing selections as indices. + + """ + selection = tables.UInt32Col() + class _Resnums(tables.IsDescription): """Table definition for storing resnums. @@ -1057,19 +1063,25 @@ def add_selection(self, universe, handle, *selection): *handle* name to use for the selection *selection* - selection string; multiple strings may be given and their - order will be preserved, which is useful for e.g. structural - alignments + selection string or numpy array of indices; multiple selections + may be given and their order will be preserved, which is + useful for e.g. structural alignments """ # TODO: add check for existence of selection table # TODO: add check for selections as strings; use finally statements # to delete table in case of failure # construct selection table + if isinstance(selection[0], basestring): + seltype = self._Selection + elif isinstance(selection[0], np.ndarray): + seltype = self._SelectionAtoms + selection = selection[0] + try: table = self.handle.create_table( '/universes/{}/selections'.format(universe), handle, - self._Selection, handle) + seltype, handle) except tables.NodeError: self.logger.info( "Replacing existing selection '{}'.".format(handle)) @@ -1077,7 +1089,7 @@ def add_selection(self, universe, handle, *selection): '/universes/{}/selections'.format(universe), handle) table = self.handle.create_table( '/universes/{}/selections'.format(universe), handle, - self._Selection, handle) + seltype, handle) # add selections to table for item in selection: diff --git a/tests/test_containers.py b/tests/test_containers.py index 01e46a3..60c8e22 100644 --- a/tests/test_containers.py +++ b/tests/test_containers.py @@ -440,6 +440,44 @@ def test_selection_get(self, container): with pytest.raises(KeyError): container.selections['CA'] + def test_add_selections_multiple_strings_via_add(self, container): + """Add a selection that has multiple selection strings""" + container.selections.add('funky town', 'name N', 'name CA') + assert 'funky town' in container.selections + + ref = container.universe.selectAtoms('name N', 'name CA') + sel = container.selections['funky town'] + assert (ref.indices() == sel.indices()).all() + + def test_add_selections_multiple_strings_via_setitem(self, container): + """Add a selection that has multiple selection strings""" + container.selections['funky town 2'] = 'name N', 'name CA' + assert 'funky town 2' in container.selections + + ref = container.universe.selectAtoms('name N', 'name CA') + sel = container.selections['funky town 2'] + assert (ref.indices() == sel.indices()).all() + + def test_add_selection_as_atomgroup_via_add(self, container): + """Make an arbitrary AtomGroup then save selection as AG""" + ag = container.universe.atoms[:10:2] + + container.selections.add('ag sel', ag) + assert 'ag sel' in container.selections + + ag2 = container.selections['ag sel'] + assert (ag.indices() == ag2.indices()).all() + + def test_add_selection_as_atomgroup_via_setitem(self, container): + """Make an arbitrary AtomGroup then save selection as AG""" + ag = container.universe.atoms[25:50:3] + + container.selections['ag sel 2'] = ag + assert 'ag sel 2' in container.selections + + ag2 = container.selections['ag sel 2'] + assert (ag.indices() == ag2.indices()).all() + class TestGroup(TestContainer): """Test Group-specific features.