From 5595681741c1b1b4855d44f7073a7ce339916086 Mon Sep 17 00:00:00 2001 From: Jonathan Barnoud Date: Tue, 12 Jun 2018 12:34:24 +0200 Subject: [PATCH 1/5] Frame iterables returned when indexing trajectories Replaces the iterators returned by the proto reader __getitem__ by an iterable class. That class has a __len__ method to adress #1894. In the current implementation, each type of input for the proto reader's __getitem__ returns a different class. This allow to test for the correct behavior only once. Ideally, the iterables should be imutable to avoid inconsistencies. Fixes #1894 --- package/MDAnalysis/coordinates/__init__.py | 9 +- package/MDAnalysis/coordinates/base.py | 237 ++++++++++++++++-- .../coordinates/test_reader_api.py | 91 ++++++- 3 files changed, 315 insertions(+), 22 deletions(-) diff --git a/package/MDAnalysis/coordinates/__init__.py b/package/MDAnalysis/coordinates/__init__.py index 6246672c589..2706f492151 100644 --- a/package/MDAnalysis/coordinates/__init__.py +++ b/package/MDAnalysis/coordinates/__init__.py @@ -502,7 +502,7 @@ class can choose an appropriate reader automatically. ``__getitem__(arg)`` advance to time step `arg` = `frame` and return :class:`Timestep`; or if `arg` is a - slice, then return an iterator over that part of the trajectory. + slice, then return an iterable over that part of the trajectory. The first functionality allows one to randomly access frames in the trajectory:: @@ -524,6 +524,9 @@ class can choose an appropriate reader automatically. The last example starts reading the trajectory at frame 1000 and reads every 100th frame until the end. + A sequence of indices or a mask of booleans can also be provided to index + a trajectory. + The performance of the ``__getitem__()`` method depends on the underlying trajectory reader and if it can implement random access to frames. In many cases this is not easily (or reliably) implementable and thus one is @@ -537,6 +540,10 @@ class can choose an appropriate reader automatically. :class:`MDAnalysis.coordinates.base.ProtoReader.__iter__` (which is always implemented) and other slices raise :exc:`TypeError`. + When indexed with a slice, a sequence of indices, or a mask of booleans, + the return value is an instance of :class:`FrameIteratorSliced` or + :class:`FrameIteratorIndices`. + ``parse_n_atoms(filename, **kwargs)`` Provide the number of atoms in the trajectory file, allowing the Reader to be used to provide an extremely minimal Topology. diff --git a/package/MDAnalysis/coordinates/base.py b/package/MDAnalysis/coordinates/base.py index 859f7a29c80..0add0954e3f 100644 --- a/package/MDAnalysis/coordinates/base.py +++ b/package/MDAnalysis/coordinates/base.py @@ -859,6 +859,205 @@ def time(self): del self.data['time'] +class FrameIteratorBase(object): + """ + Base iterable over the frames of a trajectory. + + A frame iterable has a length that can be accessed with the :func:`len` + function, and can be indexed similarly to a full trajectory. When indexed, + indices are resolved relative to the iterable and not relative to the + trajectory. + + Parameters + ---------- + trajectory: ProtoReader + The trajectory over which to iterate. + + .. versionadded:: 0.19.0 + + """ + def __init__(self, trajectory): + self._trajectory = trajectory + + def __len__(self): + raise NotImplementedError() + + @staticmethod + def _avoid_bool_list(frames): + if isinstance(frames, list) and frames and isinstance(frames[0], bool): + return np.array(frames, dtype=bool) + return frames + + @property + def trajectory(self): + return self._trajectory + + +class FrameIteratorSliced(FrameIteratorBase): + """ + Iterable over the frames of a trajectory on the basis of a slice. + + Parameters + ---------- + trajectory: ProtoReader + The trajectory over which to iterate. + frames: slice + A slice to select the frames of interest. + + See Also + -------- + FrameIteratorBase + + .. versionadded:: 0.19.0 + + """ + def __init__(self, trajectory, frames): + # It would be easier to store directly a range object, as it would + # store its parameters in a single place, calculate its length, and + # take care of most the indexing. Though, doing so is not compatible + # with python 2 where xrange (or range with six) is only an iterator. + super(FrameIteratorSliced, self).__init__(trajectory) + self._start, self._stop, self._step = trajectory.check_slice_indices( + frames.start, frames.stop, frames.step, + ) + + def __len__(self): + start, stop, step = self.start, self.stop, self.step + if (step > 0 and start < stop): + # We go from a lesser number to a larger one. + return int(1 + (stop - 1 - start) // step) + elif (step < 0 and start > stop): + # We count backward from a larger number to a lesser one. + return int(1 + (start - 1 - stop) // (-step)) + else: + # The range is empty. + return 0 + + def __iter__(self): + for i in range(self.start, self.stop, self.step): + yield self.trajectory._read_frame_with_aux(i) + self.trajectory.rewind() + + def __getitem__(self, frame): + if isinstance(frame, numbers.Integral): + length = len(self) + if not -length < frame < length: + raise IndexError('Index {} is out of range of the range of length {}.' + .format(frame, length)) + if frame < 0: + frame = len(self) + frame + frame = self.start + frame * self.step + return self.trajectory._read_frame_with_aux(frame) + elif isinstance(frame, slice): + start = self.start + (frame.start or 0) * self.step + if frame.stop is None: + stop = self.stop + else: + stop = self.start + (frame.stop or 0) * self.step + step = (frame.step or 1) * self.step + + if step > 0: + start = max(0, start) + else: + stop = max(0, stop) + + new_slice = slice(start, stop, step) + return FrameIteratorSliced(self.trajectory, new_slice) + else: + # Indexing with a lists of bools does not behave the same in all + # version of numpy. + frame = self._avoid_bool_list(frame) + frames = np.array(list(range(self.start, self.stop, self.step)))[frame] + return FrameIteratorIndices(self.trajectory, frames) + + @property + def start(self): + return self._start + + @property + def stop(self): + return self._stop + + @property + def step(self): + return self._step + + +class FrameIteratorAll(FrameIteratorBase): + """ + Iterable over all the frames of a trajectory. + + Parameters + ---------- + trajectory: ProtoReader + The trajectory over which to iterate. + + See Also + -------- + FrameIteratorBase + + .. versionadded:: 0.19.0 + + """ + def __init__(self, trajectory): + super(FrameIteratorAll, self).__init__(trajectory) + + def __len__(self): + return self.trajectory.n_frames + + def __iter__(self): + return iter(self.trajectory) + + def __getitem__(self, frame): + return self.trajectory[frame] + + +class FrameIteratorIndices(FrameIteratorBase): + """ + Iterable over the frames of a trajectory listed in a sequence of indices. + + Parameters + ---------- + trajectory: ProtoReader + The trajectory over which to iterate. + frames: sequence + A sequence of indices. + + See Also + -------- + FrameIteratorBase + """ + def __init__(self, trajectory, frames): + super(FrameIteratorIndices, self).__init__(trajectory) + self._frames = [] + for frame in frames: + if not isinstance(frame, numbers.Integral): + raise TypeError("Frames indices must be integers.") + frame = trajectory._apply_limits(frame) + self._frames.append(frame) + self._frames = tuple(self._frames) + + def __len__(self): + return len(self.frames) + + def __iter__(self): + for frame in self.frames: + yield self.trajectory._read_frame_with_aux(frame) + + def __getitem__(self, frame): + if isinstance(frame, numbers.Integral): + frame = self.frames[frame] + return self.trajectory._read_frame_with_aux(frame) + else: + frame = self._avoid_bool_list(frame) + frames = np.array(self.frames)[frame] + return FrameIteratorIndices(self.trajectory, frames) + + @property + def frames(self): + return self._frames + + class IOBase(object): """Base class bundling common functionality for trajectory I/O. @@ -1299,6 +1498,14 @@ def _reopen(self): """ pass + def _apply_limits(self, frame): + if frame < 0: + frame += len(self) + if frame < 0 or frame >= len(self): + raise IndexError("Index {} exceeds length of trajectory ({})." + "".format(frame, len(self))) + return frame + def __getitem__(self, frame): """Return the Timestep corresponding to *frame*. @@ -1312,17 +1519,8 @@ def __getitem__(self, frame): ---- *frame* is a 0-based frame index. """ - - def apply_limits(frame): - if frame < 0: - frame += len(self) - if frame < 0 or frame >= len(self): - raise IndexError("Index {} exceeds length of trajectory ({})." - "".format(frame, len(self))) - return frame - if isinstance(frame, numbers.Integral): - frame = apply_limits(frame) + frame = self._apply_limits(frame) return self._read_frame_with_aux(frame) elif isinstance(frame, (list, np.ndarray)): if isinstance(frame[0], (bool, np.bool_)): @@ -1331,20 +1529,21 @@ def apply_limits(frame): # Convert bool array to int array frame = np.arange(len(self))[frame] - def listiter(frames): - for f in frames: - if not isinstance(f, numbers.Integral): - raise TypeError("Frames indices must be integers") - yield self._read_frame_with_aux(apply_limits(f)) - - return listiter(frame) + #def listiter(frames): + # for f in frames: + # if not isinstance(f, numbers.Integral): + # raise TypeError("Frames indices must be integers") + # yield self._read_frame_with_aux(apply_limits(f)) + # + #return listiter(frame) + return FrameIteratorIndices(self, frame) elif isinstance(frame, slice): start, stop, step = self.check_slice_indices( frame.start, frame.stop, frame.step) if start == 0 and stop == len(self) and step == 1: - return self.__iter__() + return FrameIteratorAll(self) else: - return self._sliced_iter(start, stop, step) + return FrameIteratorSliced(self, frame) else: raise TypeError("Trajectories must be an indexed using an integer," " slice or list of indices") diff --git a/testsuite/MDAnalysisTests/coordinates/test_reader_api.py b/testsuite/MDAnalysisTests/coordinates/test_reader_api.py index 6f236b3e2ac..60e4f0889d9 100644 --- a/testsuite/MDAnalysisTests/coordinates/test_reader_api.py +++ b/testsuite/MDAnalysisTests/coordinates/test_reader_api.py @@ -59,6 +59,10 @@ def _read_next_timestep(self): return self.ts def _read_frame(self, frame): + if frame < 0: + frame = self.n_frames + frame + if not (0 <= frame < self.n_frames): + raise IOError self.ts.frame = frame return self.ts @@ -149,6 +153,8 @@ class TestMultiFrameReader(_Multi): (2, 5, None), # start & end (None, None, 2), # set skip (None, None, -1), # backwards skip + (None, -1, -1), + (10, 0, -1), (0, 10, 1), (0, 10, 2), (None, 20, None), # end beyond real end @@ -187,7 +193,6 @@ def sl(): with pytest.raises(TypeError): sl() - @pytest.mark.parametrize('slice_cls', [list, np.array]) @pytest.mark.parametrize('sl', [ [0, 1, 4, 5], @@ -207,6 +212,76 @@ def test_getitem(self, slice_cls, sl, reader): assert_equal(res, ref) + @pytest.mark.parametrize('sl', [ + [0, 1, 2, 3], # ordered list of indices without duplicates + [1, 3, 4, 2, 9], # disordered list of indices without duplicates + [0, 1, 1, 2, 2, 2], # ordered list with duplicates + [-1, -2, 3, -1, 0], # disordered list with duplicates + [True, ] * 10, + [False, ] * 10, + [True, False, ] * 5, + slice(None, None, None), + slice(0, 10, 1), + slice(None, None, -1), + slice(10, 0, -1), + slice(2, 7, 2), + slice(7, 2, -2), + ]) + def test_getitem_len(self, sl, reader): + traj_iterable = reader[sl] + if not isinstance(sl, slice): + sl = np.array(sl) + ref = self.reference[sl] + assert len(traj_iterable) == len(ref) + + # All the sl1 slice must be 5 frames long so that the sl2 can be a mask + @pytest.mark.parametrize('sl1', [ + [0, 1, 2, 3, 4], + [1, 1, 1, 1, 1], + [True, False, ] * 5, + slice(None, None, 2), + slice(None, None, -2), + ]) + @pytest.mark.parametrize('sl2', [ + [0, -1, 2], + [-1,-1, -1], + [True, False, True, True, False], + np.array([True, False, True, True, False]), + slice(None, None, None), + slice(None, 3, None), + slice(4, 0, -1), + ]) + def test_double_getitem(self, sl1, sl2, reader): + traj_iterable = reader[sl1][sl2] + # Old versions of numpy do not behave the same when indexing with a + # list or with an array. + if not isinstance(sl1, slice): + sl1 = np.asarray(sl1) + if not isinstance(sl2, slice): + sl2 = np.asarray(sl2) + print(sl1, sl2, type(sl1), type(sl2)) + ref = self.reference[sl1][sl2] + res = [ts.frame for ts in traj_iterable] + assert_equal(res, ref) + assert len(traj_iterable) == len(ref) + + @pytest.mark.parametrize('sl1', [ + [0, 1, 2, 3, 4], + [1, 1, 1, 1, 1], + [True, False, ] * 5, + slice(None, None, 2), + slice(None, None, -2), + ]) + @pytest.mark.parametrize('idx2', [0, 2, 4, -1, -2, -4]) + def test_double_getitem_int(self, sl1, idx2, reader): + ts = reader[sl1][idx2] + # Old versions of numpy do not behave the same when indexing with a + # list or with an array. + if not isinstance(sl1, slice): + sl1 = np.asarray(sl1) + ref = self.reference[sl1][idx2] + assert ts.frame == ref + def test_list_TE(self, reader): def sl(): return list(reader[[0, 'a', 5, 6]]) @@ -214,7 +289,6 @@ def sl(): with pytest.raises(TypeError): sl() - def test_array_TE(self, reader): def sl(): return list(reader[np.array([1.2, 3.4, 5.6])]) @@ -222,6 +296,19 @@ def sl(): with pytest.raises(TypeError): sl() + @pytest.mark.parametrize('sl1', [ + [0, 1, 2, 3, 4], + [1, 1, 1, 1, 1], + [True, False, ] * 5, + slice(None, None, 2), + slice(None, None, -2), + ]) + @pytest.mark.parametrize('idx2', [5, -6]) + def test_getitem_IE(self, sl1, idx2, reader): + partial_reader = reader[sl1] + with pytest.raises(IndexError): + partial_reader[idx2] + class _Single(_TestReader): n_frames = 1 From 15222e3588bfb25390b15edb4851e82a3d6920db Mon Sep 17 00:00:00 2001 From: Jonathan Barnoud Date: Sun, 17 Jun 2018 09:34:07 +0200 Subject: [PATCH 2/5] Fixes #1944 --- package/MDAnalysis/coordinates/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/package/MDAnalysis/coordinates/base.py b/package/MDAnalysis/coordinates/base.py index 0add0954e3f..28bd3ede310 100644 --- a/package/MDAnalysis/coordinates/base.py +++ b/package/MDAnalysis/coordinates/base.py @@ -1663,7 +1663,7 @@ def check_slice_indices(self, start, stop, step): if start < 0: start = 0 - if step < 0 and start > nframes: + if step < 0 and start >= nframes: start = nframes - 1 if stop is None: From 1ef80100577f8dc5457679ea17460e37413b56b5 Mon Sep 17 00:00:00 2001 From: Jonathan Barnoud Date: Sun, 8 Jul 2018 13:22:42 +0200 Subject: [PATCH 3/5] Allow ag.write to write trajectory and selected frames Fixes #1037 --- package/MDAnalysis/coordinates/base.py | 5 ++ package/MDAnalysis/core/groups.py | 55 +++++++++++++++---- .../MDAnalysisTests/core/test_atomgroup.py | 48 +++++++++++++++- 3 files changed, 96 insertions(+), 12 deletions(-) diff --git a/package/MDAnalysis/coordinates/base.py b/package/MDAnalysis/coordinates/base.py index 28bd3ede310..34e79e02135 100644 --- a/package/MDAnalysis/coordinates/base.py +++ b/package/MDAnalysis/coordinates/base.py @@ -1450,6 +1450,11 @@ def time(self): """ return self.ts.time + @property + def trajectory(self): + # Makes a reader effectively commpatible with a FrameIteratorBase + return self + def Writer(self, filename, **kwargs): """A trajectory writer with the same properties as this trajectory.""" raise NotImplementedError( diff --git a/package/MDAnalysis/core/groups.py b/package/MDAnalysis/core/groups.py index 68217167934..6bddd71f9f3 100644 --- a/package/MDAnalysis/core/groups.py +++ b/package/MDAnalysis/core/groups.py @@ -2531,7 +2531,7 @@ def improper(self): return topologyobjects.ImproperDihedral(self.ix, self.universe) def write(self, filename=None, file_format="PDB", - filenamefmt="{trjname}_{frame}", **kwargs): + filenamefmt="{trjname}_{frame}", frames=None, **kwargs): """Write `AtomGroup` to a file. The output can either be a coordinate file or a selection, depending on @@ -2554,19 +2554,46 @@ def write(self, filename=None, file_format="PDB", ``"conect"``: write only the CONECT records defined in the original file. ``"all"``: write out all bonds, both the original defined and those guessed by MDAnalysis. ``None``: do not write out bonds. - Default os ``"conect"``. + Default is ``"conect"``. + frames: + An ensemble of frames to write. The ensemble can be an list or + array of frame indices, a mask of booleans, an instance of + :class:`slice`, or an indexed trajectory. By default, 'frames' is + set to ``None`` and only the current frame is written. .. versionchanged:: 0.9.0 Merged with write_selection. This method can now write both selections out. + .. versionchanged:: 0.19.0 + Can write multiframe trajectories with the 'frames' argument. """ # check that AtomGroup actually has any atoms (Issue #434) if len(self.atoms) == 0: raise IndexError("Cannot write an AtomGroup with 0 atoms") trj = self.universe.trajectory # unified trajectory API - - if trj.n_frames == 1: + if frames is None: + trj_frames = trj[::] + else: + try: + test_trajectory = frames.trajectory + except AttributeError: + trj_frames = trj[frames] + else: + if test_trajectory is not trj: + raise ValueError( + 'The trajectory of {} provided to the frames keyword ' + 'attribute is different from the trajectory of the ' + 'AtomGroup.'.format(frames) + ) + trj_frames = frames + if len(trj_frames) > 1 and kwargs.get("multiframe") == False: + raise ValueError( + 'Cannot explicitely set "multiframe" to False and request ' + 'more than 1 frame with the "frames" keyword argument.' + ) + + if len(trj_frames) == 1: kwargs.setdefault("multiframe", False) if filename is None: @@ -2590,23 +2617,29 @@ def write(self, filename=None, file_format="PDB", writer = get_writer_for(filename, format=format, multiframe=multiframe) #MDAnalysis.coordinates.writer(filename, **kwargs) - coords = True except (ValueError, TypeError): - coords = False + pass + else: + with writer(filename, n_atoms=self.n_atoms, **kwargs) as w: + if frames is None: + w.write(self.atoms) + else: + for _ in trj_frames: + w.write(self.atoms) + return try: # here `file_format` is only used as default, # anything pulled off `filename` will be used preferentially writer = get_selection_writer_for(filename, file_format) - selection = True except (TypeError, NotImplementedError): - selection = False - - if not (coords or selection): - raise ValueError("No writer found for format: {}".format(filename)) + pass else: with writer(filename, n_atoms=self.n_atoms, **kwargs) as w: w.write(self.atoms) + return + + raise ValueError("No writer found for format: {}".format(filename)) class ResidueGroup(GroupBase): diff --git a/testsuite/MDAnalysisTests/core/test_atomgroup.py b/testsuite/MDAnalysisTests/core/test_atomgroup.py index c5e88fb4d97..4c98d4beeee 100644 --- a/testsuite/MDAnalysisTests/core/test_atomgroup.py +++ b/testsuite/MDAnalysisTests/core/test_atomgroup.py @@ -30,6 +30,7 @@ from numpy.testing import ( assert_almost_equal, assert_equal, + assert_array_almost_equal, ) import MDAnalysis as mda @@ -111,7 +112,7 @@ def test_write_no_args(self, u, tmpdir): name = path.splitext(path.basename(DCD))[0] assert_equal(files[0], "{}_0.pdb".format(name)) - def test_raises(self, u, tmpdir): + def test_raises_unknown_format(self, u, tmpdir): with tmpdir.as_cwd(): with pytest.raises(ValueError): u.atoms.write('useless.format123') @@ -120,6 +121,51 @@ def test_write_coordinates(self, u, tmpdir): with tmpdir.as_cwd(): u.atoms.write("test.xtc") + @pytest.mark.parametrize('frames', ( + [4], + [2, 3, 3, 1], + slice(2, 6, 1), + )) + def test_write_frames(self, u, tmpdir, frames): + destination = str(tmpdir / 'test.dcd') + selection = u.trajectory[frames] + ref_positions = np.stack([ts.positions for ts in selection]) + u.atoms.write(destination, frames=frames) + + u_new = mda.Universe(destination) + new_positions = np.stack([ts.positions for ts in u_new.trajectory]) + + assert_array_almost_equal(new_positions, ref_positions) + + @pytest.mark.parametrize('frames', ( + [4], + [2, 3, 3, 1], + slice(2, 6, 1), + )) + def test_write_frame_iterator(self, u, tmpdir, frames): + destination = str(tmpdir / 'test.dcd') + selection = u.trajectory[frames] + ref_positions = np.stack([ts.positions for ts in selection]) + u.atoms.write(destination, frames=selection) + + u_new = mda.Universe(destination) + new_positions = np.stack([ts.positions for ts in u_new.trajectory]) + + assert_array_almost_equal(new_positions, ref_positions) + + def test_incompatible_arguments(self, u, tmpdir): + destination = str(tmpdir / 'test.dcd') + with pytest.raises(ValueError): + u.atoms.write(destination, frames=[0, 1, 2], multiframe=False) + + def test_incompatible_trajectories(self, tmpdir): + destination = str(tmpdir / 'test.dcd') + u1 = make_Universe(trajectory=True) + u2 = make_Universe(trajectory=True) + destination = str(tmpdir / 'test.dcd') + with pytest.raises(ValueError): + u1.atoms.write(destination, frames=u2.trajectory) + def test_write_selection(self, u, tmpdir): with tmpdir.as_cwd(): u.atoms.write("test.vmd") From 65423742a59ac11b4da12558242f3675d80827f0 Mon Sep 17 00:00:00 2001 From: Jonathan Barnoud Date: Tue, 3 Jul 2018 16:27:17 +0200 Subject: [PATCH 4/5] Update changelog --- package/CHANGELOG | 3 +++ package/MDAnalysis/coordinates/base.py | 8 -------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/package/CHANGELOG b/package/CHANGELOG index d54408d699a..2010db12ca0 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -45,6 +45,7 @@ Enhancements * ChainReader can correctly handle continuous trajectories split into multiple files, generated with gromacs -noappend (PR #1728) * MDAnalysis.lib.mdamath now supports triclinic boxes and rewrote in Cython (PR #1965) + * AtomGroup.write can write a trajectory of selected frames (Issue #1037) Fixes * rewind in the SingleFrameReader now reads the frame from the file (Issue #1929) @@ -70,6 +71,8 @@ Changes *Group.unique will always return the same object unless the group is updated or modified. (PR #1922) * The TPR parser reads SETTLE constraints as bonds. (Issue #1949) + * Indexing a trajectory with a slice or an array now returns an iterable + (Issue #1894) 04/15/18 tylerjereddy, richardjgowers, palnabarun, bieniekmateusz, kain88-de, diff --git a/package/MDAnalysis/coordinates/base.py b/package/MDAnalysis/coordinates/base.py index 34e79e02135..ca64969d289 100644 --- a/package/MDAnalysis/coordinates/base.py +++ b/package/MDAnalysis/coordinates/base.py @@ -1533,14 +1533,6 @@ def __getitem__(self, frame): frame = np.asarray(frame, dtype=np.bool) # Convert bool array to int array frame = np.arange(len(self))[frame] - - #def listiter(frames): - # for f in frames: - # if not isinstance(f, numbers.Integral): - # raise TypeError("Frames indices must be integers") - # yield self._read_frame_with_aux(apply_limits(f)) - # - #return listiter(frame) return FrameIteratorIndices(self, frame) elif isinstance(frame, slice): start, stop, step = self.check_slice_indices( From 06e386fcd69bddfdaf6e49a2d32fda0425a86ac2 Mon Sep 17 00:00:00 2001 From: Jonathan Barnoud Date: Thu, 12 Jul 2018 02:04:45 +0200 Subject: [PATCH 5/5] Address comments from @orbeckst on #1934 --- package/MDAnalysis/coordinates/base.py | 4 +- package/MDAnalysis/core/groups.py | 74 +++++++++++++------ .../coordinates/test_reader_api.py | 13 +++- .../MDAnalysisTests/core/test_atomgroup.py | 32 ++++++++ 4 files changed, 98 insertions(+), 25 deletions(-) diff --git a/package/MDAnalysis/coordinates/base.py b/package/MDAnalysis/coordinates/base.py index ca64969d289..c22db65db92 100644 --- a/package/MDAnalysis/coordinates/base.py +++ b/package/MDAnalysis/coordinates/base.py @@ -935,7 +935,7 @@ def __len__(self): def __iter__(self): for i in range(self.start, self.stop, self.step): - yield self.trajectory._read_frame_with_aux(i) + yield self.trajectory[i] self.trajectory.rewind() def __getitem__(self, frame): @@ -1528,7 +1528,7 @@ def __getitem__(self, frame): frame = self._apply_limits(frame) return self._read_frame_with_aux(frame) elif isinstance(frame, (list, np.ndarray)): - if isinstance(frame[0], (bool, np.bool_)): + if len(frame) != 0 and isinstance(frame[0], (bool, np.bool_)): # Avoid having list of bools frame = np.asarray(frame, dtype=np.bool) # Convert bool array to int array diff --git a/package/MDAnalysis/core/groups.py b/package/MDAnalysis/core/groups.py index 6bddd71f9f3..f7af77f7b8b 100644 --- a/package/MDAnalysis/core/groups.py +++ b/package/MDAnalysis/core/groups.py @@ -2535,17 +2535,27 @@ def write(self, filename=None, file_format="PDB", """Write `AtomGroup` to a file. The output can either be a coordinate file or a selection, depending on - the `format`. Only single-frame coordinate files are supported. If you - need to write out a trajectory, see :mod:`MDAnalysis.coordinates`. + the format. + + Examples + -------- + + >>> ag = u.atoms + >>> ag.write('selection.ndx') # Write a gromacs index file + >>> ag.write('coordinates.pdb') # Write the current frame as PDB + >>> # Write the trajectory in XTC format + >>> ag.write('trajectory.xtc', frames='all') + >>> # Write every other frame of the trajectory in PBD format + >>> ag.write('trajectory.pdb', frames=u.trajectory[::2]) Parameters ---------- filename : str, optional ``None``: create TRJNAME_FRAME.FORMAT from filenamefmt [``None``] file_format : str, optional - PDB, CRD, GRO, VMD (tcl), PyMol (pml), Gromacs (ndx) CHARMM (str) - Jmol (spt); case-insensitive and can also be supplied as the - filename extension [PDB] + The name or extension of a coordinate, trajectory, or selection + file format such as PDB, CRD, GRO, VMD (tcl), PyMol (pml), Gromacs + (ndx) CHARMM (str) or Jmol (spt); case-insensitive [PDB] filenamefmt : str, optional format string for default filename; use substitution tokens 'trjname' and 'frame' ["%(trjname)s_%(frame)d"] @@ -2555,11 +2565,13 @@ def write(self, filename=None, file_format="PDB", file. ``"all"``: write out all bonds, both the original defined and those guessed by MDAnalysis. ``None``: do not write out bonds. Default is ``"conect"``. - frames: + frames: array-like or slice or FrameIteratorBase or str, optional An ensemble of frames to write. The ensemble can be an list or array of frame indices, a mask of booleans, an instance of - :class:`slice`, or an indexed trajectory. By default, 'frames' is - set to ``None`` and only the current frame is written. + :class:`slice`, or the value returned when a trajectory is indexed. + By default, `frames` is set to ``None`` and only the current frame + is written. If `frames` is set to "all", then all the frame from + trajectory are written. .. versionchanged:: 0.9.0 Merged with write_selection. This method can @@ -2567,13 +2579,19 @@ def write(self, filename=None, file_format="PDB", .. versionchanged:: 0.19.0 Can write multiframe trajectories with the 'frames' argument. """ + # TODO: Add a 'verbose' option alongside 'frames'. + # check that AtomGroup actually has any atoms (Issue #434) if len(self.atoms) == 0: raise IndexError("Cannot write an AtomGroup with 0 atoms") trj = self.universe.trajectory # unified trajectory API - if frames is None: + if frames is None or frames == 'all': trj_frames = trj[::] + elif isinstance(frames, numbers.Integral): + # We accept everything that indexes a trajectory and returns a + # subset of it. Though, numbers return a Timestep instead. + raise TypeError('The "frames" argument cannot be a number.') else: try: test_trajectory = frames.trajectory @@ -2587,20 +2605,31 @@ def write(self, filename=None, file_format="PDB", 'AtomGroup.'.format(frames) ) trj_frames = frames - if len(trj_frames) > 1 and kwargs.get("multiframe") == False: - raise ValueError( - 'Cannot explicitely set "multiframe" to False and request ' - 'more than 1 frame with the "frames" keyword argument.' - ) - - if len(trj_frames) == 1: - kwargs.setdefault("multiframe", False) if filename is None: trjname, ext = os.path.splitext(os.path.basename(trj.filename)) filename = filenamefmt.format(trjname=trjname, frame=trj.frame) filename = util.filename(filename, ext=file_format.lower(), keep=True) + # Some writer behave differently when they are given a "multiframe" + # argument. It is the case of the PDB writer tht writes models when + # "multiframe" is True. + # We want to honor what the user provided with the argument if + # provided explicitly. If not, then we need to figure out if we write + # multiple frames or not. + multiframe = kwargs.pop('multiframe', None) + if len(trj_frames) > 1 and multiframe == False: + raise ValueError( + 'Cannot explicitely set "multiframe" to False and request ' + 'more than 1 frame with the "frames" keyword argument.' + ) + elif multiframe is None: + if frames is None: + # By default we only write the current frame. + multiframe = False + else: + multiframe = len(trj_frames) > 1 + # From the following blocks, one must pass. # Both can't pass as the extensions don't overlap. # Try and select a Class using get_ methods (becomes `writer`) @@ -2613,10 +2642,7 @@ def write(self, filename=None, file_format="PDB", format = format or file_format format = format.strip().upper() - multiframe = kwargs.pop('multiframe', None) - writer = get_writer_for(filename, format=format, multiframe=multiframe) - #MDAnalysis.coordinates.writer(filename, **kwargs) except (ValueError, TypeError): pass else: @@ -2624,8 +2650,12 @@ def write(self, filename=None, file_format="PDB", if frames is None: w.write(self.atoms) else: - for _ in trj_frames: - w.write(self.atoms) + current_frame = trj.ts.frame + try: + for _ in trj_frames: + w.write(self.atoms) + finally: + trj[current_frame] return try: diff --git a/testsuite/MDAnalysisTests/coordinates/test_reader_api.py b/testsuite/MDAnalysisTests/coordinates/test_reader_api.py index 60e4f0889d9..8552a7283fa 100644 --- a/testsuite/MDAnalysisTests/coordinates/test_reader_api.py +++ b/testsuite/MDAnalysisTests/coordinates/test_reader_api.py @@ -170,7 +170,9 @@ class TestMultiFrameReader(_Multi): (1, 5, -1), # Stop less than start (-100, None, None), (100, None, None), # Outside of range of trajectory - (-2, 10, -2) + (-2, 10, -2), + (0, 0, 1), # empty + (10, 1, 2), # empty ]) def test_slice(self, start, stop, step, reader): """Compare the slice applied to trajectory, to slice of list""" @@ -226,6 +228,8 @@ def test_getitem(self, slice_cls, sl, reader): slice(10, 0, -1), slice(2, 7, 2), slice(7, 2, -2), + slice(7, 2, 1), # empty + slice(0, 0, 1), # empty ]) def test_getitem_len(self, sl, reader): traj_iterable = reader[sl] @@ -234,6 +238,12 @@ def test_getitem_len(self, sl, reader): ref = self.reference[sl] assert len(traj_iterable) == len(ref) + @pytest.mark.parametrize('iter_type', (list, np.array)) + def test_getitem_len_empty(self, reader, iter_type): + # Indexing a numpy array with an empty array tends to break. + traj_iterable = reader[iter_type([])] + assert len(traj_iterable) == 0 + # All the sl1 slice must be 5 frames long so that the sl2 can be a mask @pytest.mark.parametrize('sl1', [ [0, 1, 2, 3, 4], @@ -271,6 +281,7 @@ def test_double_getitem(self, sl1, sl2, reader): [True, False, ] * 5, slice(None, None, 2), slice(None, None, -2), + slice(None, None, None), ]) @pytest.mark.parametrize('idx2', [0, 2, 4, -1, -2, -4]) def test_double_getitem_int(self, sl1, idx2, reader): diff --git a/testsuite/MDAnalysisTests/core/test_atomgroup.py b/testsuite/MDAnalysisTests/core/test_atomgroup.py index 4c98d4beeee..e1c7e115cd1 100644 --- a/testsuite/MDAnalysisTests/core/test_atomgroup.py +++ b/testsuite/MDAnalysisTests/core/test_atomgroup.py @@ -153,6 +153,31 @@ def test_write_frame_iterator(self, u, tmpdir, frames): assert_array_almost_equal(new_positions, ref_positions) + @pytest.mark.parametrize('extension', ('xtc', 'dcd', 'pdb', 'xyz')) + def test_write_frame_none(self, u, tmpdir, extension): + destination = str(tmpdir / 'test.' + extension) + u.atoms.write(destination, frames=None) + u_new = mda.Universe(destination) + new_positions = np.stack([ts.positions for ts in u_new.trajectory]) + # Most format only save 3 decimals; XTC even has only 2. + assert_array_almost_equal( + u.atoms.positions[None, ...], new_positions, decimal=2 + ) + + def test_write_frames_all(self, u, tmpdir): + destination = str(tmpdir / 'test.dcd') + u.atoms.write(destination, frames='all') + u_new = mda.Universe(destination) + ref_positions = np.stack([ts.positions for ts in u.trajectory]) + new_positions = np.stack([ts.positions for ts in u_new.trajectory]) + assert_array_almost_equal(new_positions, ref_positions) + + @pytest.mark.parametrize('frames', ('invalid', 8, True, False, 3.2)) + def test_write_frames_invalid(self, u, tmpdir, frames): + destination = str(tmpdir / 'test.dcd') + with pytest.raises(TypeError): + u.atoms.write(destination, frames=frames) + def test_incompatible_arguments(self, u, tmpdir): destination = str(tmpdir / 'test.dcd') with pytest.raises(ValueError): @@ -166,6 +191,13 @@ def test_incompatible_trajectories(self, tmpdir): with pytest.raises(ValueError): u1.atoms.write(destination, frames=u2.trajectory) + def test_write_no_traj_move(self, u, tmpdir): + destination = str(tmpdir / 'test.dcd') + u.trajectory[10] + u.atoms.write(destination, frames=[1, 2, 3]) + assert u.trajectory.ts.frame == 10 + + def test_write_selection(self, u, tmpdir): with tmpdir.as_cwd(): u.atoms.write("test.vmd")