Skip to content

Commit

Permalink
Merge pull request #1934 from MDAnalysis/issue1894-iterator-len
Browse files Browse the repository at this point in the history
Proof of concept trajectory class iterators
  • Loading branch information
richardjgowers authored Jul 12, 2018
2 parents 4be325c + 06e386f commit 675252a
Show file tree
Hide file tree
Showing 6 changed files with 492 additions and 47 deletions.
3 changes: 3 additions & 0 deletions package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Enhancements
* ChainReader can correctly handle continuous trajectories split into multiple files,
generated with gromacs -noappend (PR #1728)
* MDAnalysis.lib.mdamath now supports triclinic boxes and was rewritten in Cython (PR #1965)
* AtomGroup.write can write a trajectory of selected frames (Issue #1037)

Fixes
* rewind in the SingleFrameReader now reads the frame from the file (Issue #1929)
Expand Down Expand Up @@ -74,6 +75,8 @@ Changes
*Group.unique will always return the same object unless the group is
updated or modified. (PR #1922)
* The TPR parser reads SETTLE constraints as bonds. (Issue #1949)
* Indexing a trajectory with a slice or an array now returns an iterable
(Issue #1894)

Deprecations
* almost all "save()", "save_results()", "save_table()" methods in
Expand Down
9 changes: 8 additions & 1 deletion package/MDAnalysis/coordinates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ class can choose an appropriate reader automatically.
``__getitem__(arg)``
advance to time step `arg` = `frame` and return :class:`Timestep`; or if `arg` is a
slice, then return an iterator over that part of the trajectory.
slice, then return an iterable over that part of the trajectory.
The first functionality allows one to randomly access frames in the
trajectory::
Expand All @@ -524,6 +524,9 @@ class can choose an appropriate reader automatically.
The last example starts reading the trajectory at frame 1000 and
reads every 100th frame until the end.
A sequence of indices or a mask of booleans can also be provided to index
a trajectory.
The performance of the ``__getitem__()`` method depends on the underlying
trajectory reader and if it can implement random access to frames. In many
cases this is not easily (or reliably) implementable and thus one is
Expand All @@ -537,6 +540,10 @@ class can choose an appropriate reader automatically.
:class:`MDAnalysis.coordinates.base.ProtoReader.__iter__` (which is always
implemented) and other slices raise :exc:`TypeError`.
When indexed with a slice, a sequence of indices, or a mask of booleans,
the return value is an instance of :class:`FrameIteratorSliced` or
:class:`FrameIteratorIndices`.
``parse_n_atoms(filename, **kwargs)``
Provide the number of atoms in the trajectory file, allowing the Reader
to be used to provide an extremely minimal Topology.
Expand Down
240 changes: 218 additions & 22 deletions package/MDAnalysis/coordinates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,205 @@ def time(self):
del self.data['time']


class FrameIteratorBase(object):
    """
    Base iterable over the frames of a trajectory.

    A frame iterable has a length that can be accessed with the :func:`len`
    function, and can be indexed similarly to a full trajectory. When indexed,
    indices are resolved relative to the iterable and not relative to the
    trajectory.

    Subclasses are expected to implement ``__len__``; the base implementation
    raises :exc:`NotImplementedError`.

    Parameters
    ----------
    trajectory: ProtoReader
        The trajectory over which to iterate.


    .. versionadded:: 0.19.0

    """
    def __init__(self, trajectory):
        self._trajectory = trajectory

    def __len__(self):
        raise NotImplementedError()

    @staticmethod
    def _avoid_bool_list(frames):
        """Convert a non-empty list of booleans into a boolean array.

        Indexing with a list of booleans does not behave the same in all
        versions of numpy, so such lists are converted to an explicit boolean
        mask. Both builtin ``bool`` and ``np.bool_`` items are recognised.
        Only the first item is inspected. Any other input (including an
        empty list) is returned unchanged.
        """
        if (isinstance(frames, list) and frames
                and isinstance(frames[0], (bool, np.bool_))):
            return np.array(frames, dtype=bool)
        return frames

    @property
    def trajectory(self):
        """The trajectory this iterable reads its frames from."""
        return self._trajectory


class FrameIteratorSliced(FrameIteratorBase):
    """
    Iterable over the frames of a trajectory on the basis of a slice.

    Parameters
    ----------
    trajectory: ProtoReader
        The trajectory over which to iterate.
    frames: slice
        A slice to select the frames of interest.

    See Also
    --------
    FrameIteratorBase


    .. versionadded:: 0.19.0

    """
    def __init__(self, trajectory, frames):
        # It would be easier to store directly a range object, as it would
        # store its parameters in a single place, calculate its length, and
        # take care of most the indexing. Though, doing so is not compatible
        # with python 2 where xrange (or range with six) is only an iterator.
        super(FrameIteratorSliced, self).__init__(trajectory)
        # The trajectory resolves the raw slice bounds into the concrete
        # start, stop and step stored on this iterable.
        self._start, self._stop, self._step = trajectory.check_slice_indices(
            frames.start, frames.stop, frames.step,
        )

    def __len__(self):
        start, stop, step = self.start, self.stop, self.step
        if (step > 0 and start < stop):
            # We go from a lesser number to a larger one.
            return int(1 + (stop - 1 - start) // step)
        elif (step < 0 and start > stop):
            # We count backward from a larger number to a lesser one.
            return int(1 + (start - 1 - stop) // (-step))
        else:
            # The range is empty.
            return 0

    def __iter__(self):
        for i in range(self.start, self.stop, self.step):
            yield self.trajectory[i]
        # Only reached when the iteration ran to completion.
        self.trajectory.rewind()

    def __getitem__(self, frame):
        """Index the iterable; indices are relative to the iterable.

        An integer returns the result of reading the corresponding frame, a
        slice returns a new :class:`FrameIteratorSliced`, and anything else
        is used to index the array of selected frame numbers and returns a
        :class:`FrameIteratorIndices`.

        Raises
        ------
        IndexError
            If an integer index lies outside ``[-len(self), len(self))``.
        """
        if isinstance(frame, numbers.Integral):
            length = len(self)
            # As for any Python sequence, -length is the lowest valid index
            # and designates the first frame of the iterable.
            if not -length <= frame < length:
                raise IndexError('Index {} is out of range of the range of length {}.'
                                 .format(frame, length))
            if frame < 0:
                frame = length + frame
            # Translate the position in the iterable to a trajectory frame.
            frame = self.start + frame * self.step
            return self.trajectory._read_frame_with_aux(frame)
        elif isinstance(frame, slice):
            # The bounds of the slice are used as plain offsets from the
            # start of the iterable; negative bounds are not resolved
            # relative to its end.
            start = self.start + (frame.start or 0) * self.step
            if frame.stop is None:
                stop = self.stop
            else:
                stop = self.start + (frame.stop or 0) * self.step
            step = (frame.step or 1) * self.step

            # Clamp the lower end of the new range at frame 0.
            if step > 0:
                start = max(0, start)
            else:
                stop = max(0, stop)

            new_slice = slice(start, stop, step)
            return FrameIteratorSliced(self.trajectory, new_slice)
        else:
            # Indexing with a lists of bools does not behave the same in all
            # version of numpy.
            frame = self._avoid_bool_list(frame)
            frames = np.array(list(range(self.start, self.stop, self.step)))[frame]
            return FrameIteratorIndices(self.trajectory, frames)

    @property
    def start(self):
        """First frame index, as resolved by the trajectory."""
        return self._start

    @property
    def stop(self):
        """Exclusive end frame index, as resolved by the trajectory."""
        return self._stop

    @property
    def step(self):
        """Increment between two frames, as resolved by the trajectory."""
        return self._step


class FrameIteratorAll(FrameIteratorBase):
    """
    Iterable over all the frames of a trajectory.

    Length, iteration and indexing are all delegated to the wrapped
    trajectory.

    Parameters
    ----------
    trajectory: ProtoReader
        The trajectory over which to iterate.

    See Also
    --------
    FrameIteratorBase


    .. versionadded:: 0.19.0

    """
    def __init__(self, trajectory):
        super(FrameIteratorAll, self).__init__(trajectory)

    def __len__(self):
        # Every frame is selected: the length is the trajectory's own.
        reader = self.trajectory
        return reader.n_frames

    def __iter__(self):
        # Hand out the trajectory's own iterator.
        reader = self.trajectory
        return iter(reader)

    def __getitem__(self, frame):
        # Positions in this iterable coincide with trajectory frames.
        reader = self.trajectory
        return reader[frame]


class FrameIteratorIndices(FrameIteratorBase):
    """
    Iterable over the frames of a trajectory listed in a sequence of indices.

    Parameters
    ----------
    trajectory: ProtoReader
        The trajectory over which to iterate.
    frames: sequence
        A sequence of indices.

    Raises
    ------
    TypeError
        If one of the indices is not an integer.

    See Also
    --------
    FrameIteratorBase


    .. versionadded:: 0.19.0

    """
    def __init__(self, trajectory, frames):
        super(FrameIteratorIndices, self).__init__(trajectory)
        resolved = []
        for frame in frames:
            if not isinstance(frame, numbers.Integral):
                raise TypeError("Frames indices must be integers.")
            # Each index is passed through the trajectory's limit check
            # once, at construction time.
            resolved.append(trajectory._apply_limits(frame))
        # Stored as a tuple so the selection cannot be modified afterwards.
        self._frames = tuple(resolved)

    def __len__(self):
        return len(self.frames)

    def __iter__(self):
        for frame in self.frames:
            yield self.trajectory._read_frame_with_aux(frame)
        # Rewind once the iteration ran to completion, as
        # FrameIteratorSliced does.
        self.trajectory.rewind()

    def __getitem__(self, frame):
        """Index the iterable; indices are relative to the iterable.

        An integer returns the result of reading the corresponding frame;
        anything else is used to index the array of stored frame numbers and
        returns a new :class:`FrameIteratorIndices`.
        """
        if isinstance(frame, numbers.Integral):
            frame = self.frames[frame]
            return self.trajectory._read_frame_with_aux(frame)
        else:
            # Indexing with a list of bools does not behave the same in all
            # versions of numpy.
            frame = self._avoid_bool_list(frame)
            frames = np.array(self.frames)[frame]
            return FrameIteratorIndices(self.trajectory, frames)

    @property
    def frames(self):
        """Tuple of the selected trajectory frame indices."""
        return self._frames


class IOBase(object):
"""Base class bundling common functionality for trajectory I/O.
Expand Down Expand Up @@ -1251,6 +1450,11 @@ def time(self):
"""
return self.ts.time

@property
def trajectory(self):
    """The reader itself."""
    # Makes a reader effectively compatible with a FrameIteratorBase,
    # which exposes the trajectory it iterates over under the same name.
    return self

def Writer(self, filename, **kwargs):
"""A trajectory writer with the same properties as this trajectory."""
raise NotImplementedError(
Expand Down Expand Up @@ -1299,6 +1503,14 @@ def _reopen(self):
"""
pass

def _apply_limits(self, frame):
    """Resolve *frame* to a non-negative index within the trajectory.

    Negative indices count from the end of the trajectory.

    Parameters
    ----------
    frame : int
        0-based frame index; may be negative.

    Returns
    -------
    int
        The equivalent index in ``range(len(self))``.

    Raises
    ------
    IndexError
        If *frame* does not designate a frame of the trajectory. The
        message reports the index as it was requested.
    """
    n_frames = len(self)
    index = frame + n_frames if frame < 0 else frame
    if not 0 <= index < n_frames:
        # Report the index the caller asked for, not the shifted value.
        raise IndexError("Index {} exceeds length of trajectory ({})."
                         "".format(frame, n_frames))
    return index

def __getitem__(self, frame):
"""Return the Timestep corresponding to *frame*.
Expand All @@ -1312,39 +1524,23 @@ def __getitem__(self, frame):
----
*frame* is a 0-based frame index.
"""

def apply_limits(frame):
if frame < 0:
frame += len(self)
if frame < 0 or frame >= len(self):
raise IndexError("Index {} exceeds length of trajectory ({})."
"".format(frame, len(self)))
return frame

if isinstance(frame, numbers.Integral):
frame = apply_limits(frame)
frame = self._apply_limits(frame)
return self._read_frame_with_aux(frame)
elif isinstance(frame, (list, np.ndarray)):
if isinstance(frame[0], (bool, np.bool_)):
if len(frame) != 0 and isinstance(frame[0], (bool, np.bool_)):
# Avoid having list of bools
frame = np.asarray(frame, dtype=np.bool)
# Convert bool array to int array
frame = np.arange(len(self))[frame]

def listiter(frames):
for f in frames:
if not isinstance(f, numbers.Integral):
raise TypeError("Frames indices must be integers")
yield self._read_frame_with_aux(apply_limits(f))

return listiter(frame)
return FrameIteratorIndices(self, frame)
elif isinstance(frame, slice):
start, stop, step = self.check_slice_indices(
frame.start, frame.stop, frame.step)
if start == 0 and stop == len(self) and step == 1:
return self.__iter__()
return FrameIteratorAll(self)
else:
return self._sliced_iter(start, stop, step)
return FrameIteratorSliced(self, frame)
else:
raise TypeError("Trajectories must be an indexed using an integer,"
" slice or list of indices")
Expand Down Expand Up @@ -1464,7 +1660,7 @@ def check_slice_indices(self, start, stop, step):
if start < 0:
start = 0

if step < 0 and start > nframes:
if step < 0 and start >= nframes:
start = nframes - 1

if stop is None:
Expand Down
Loading

0 comments on commit 675252a

Please sign in to comment.