-
Notifications
You must be signed in to change notification settings - Fork 663
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Proof of concept trajectory class iterators #1934
Changes from 4 commits
5595681
15222e3
1ef8010
6542374
06e386f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -859,6 +859,205 @@ def time(self): | |
del self.data['time'] | ||
|
||
|
||
class FrameIteratorBase(object):
    """
    Base iterable over the frames of a trajectory.

    A frame iterable has a length that can be accessed with the :func:`len`
    function, and can be indexed similarly to a full trajectory. When indexed,
    indices are resolved relative to the iterable and not relative to the
    trajectory.

    This base class only stores the trajectory; :meth:`__len__` raises
    :exc:`NotImplementedError` and has to be provided by subclasses.

    Parameters
    ----------
    trajectory: ProtoReader
        The trajectory over which to iterate.


    .. versionadded:: 0.19.0

    """
    def __init__(self, trajectory):
        self._trajectory = trajectory

    def __len__(self):
        raise NotImplementedError()

    @staticmethod
    def _avoid_bool_list(frames):
        """Turn a non-empty list of booleans into a boolean numpy array.

        Only the first element is inspected. Both Python ``bool`` and
        ``numpy.bool_`` are recognised, so a mask built from numpy values
        (e.g. ``list(some_bool_array)``) is handled like a plain list of
        ``True``/``False``. Anything else is returned unchanged.
        """
        if (isinstance(frames, list) and frames
                and isinstance(frames[0], (bool, np.bool_))):
            return np.array(frames, dtype=bool)
        return frames

    @property
    def trajectory(self):
        """The trajectory given at construction."""
        return self._trajectory
|
||
|
||
class FrameIteratorSliced(FrameIteratorBase):
    """
    Iterable over the frames of a trajectory on the basis of a slice.

    The slice is normalised once, at construction, by the trajectory's
    ``check_slice_indices`` method; the resulting values are exposed as the
    read-only :attr:`start`, :attr:`stop` and :attr:`step` properties.

    Parameters
    ----------
    trajectory: ProtoReader
        The trajectory over which to iterate.
    frames: slice
        A slice to select the frames of interest.

    See Also
    --------
    FrameIteratorBase


    .. versionadded:: 0.19.0

    """
    def __init__(self, trajectory, frames):
        # It would be easier to store directly a range object, as it would
        # store its parameters in a single place, calculate its length, and
        # take care of most of the indexing. Though, doing so is not
        # compatible with python 2 where xrange (or range with six) is only
        # an iterator.
        super(FrameIteratorSliced, self).__init__(trajectory)
        self._start, self._stop, self._step = trajectory.check_slice_indices(
            frames.start, frames.stop, frames.step,
        )

    def __len__(self):
        """Number of frames selected by the slice."""
        start, stop, step = self.start, self.stop, self.step
        if step > 0 and start < stop:
            # We go from a lesser number to a larger one.
            return int(1 + (stop - 1 - start) // step)
        elif step < 0 and start > stop:
            # We count backward from a larger number to a lesser one.
            return int(1 + (start - 1 - stop) // (-step))
        else:
            # The range is empty.
            return 0

    def __iter__(self):
        """Yield each selected frame, then rewind the trajectory."""
        for i in range(self.start, self.stop, self.step):
            # Reads the frame directly; the indices come from the normalised
            # slice, so no further bound check is applied here.
            yield self.trajectory._read_frame_with_aux(i)
        # Only reached when the iteration runs to completion.
        self.trajectory.rewind()

    def __getitem__(self, frame):
        """Index the iterable relative to its own frames.

        Parameters
        ----------
        frame: int, slice, or sequence
            An integer returns the frame read from the trajectory, a slice
            returns a new :class:`FrameIteratorSliced`, and anything else is
            used to index the array of selected frame numbers and returns a
            :class:`FrameIteratorIndices`.

        Raises
        ------
        IndexError
            If an integer index falls outside ``[-len(self), len(self))``.
        """
        if isinstance(frame, numbers.Integral):
            length = len(self)
            # -length is a valid index: it designates the first element, as
            # for any python sequence.
            if not -length <= frame < length:
                raise IndexError('Index {} is out of range of the range of length {}.'
                                 .format(frame, length))
            if frame < 0:
                frame += length
            frame = self.start + frame * self.step
            return self.trajectory._read_frame_with_aux(frame)
        elif isinstance(frame, slice):
            # NOTE(review): negative ``frame.start``/``frame.stop`` are only
            # scaled and offset here, they are not resolved relative to the
            # end of this iterable, and ``stop`` is not bounded by
            # ``self.stop`` -- confirm the intended behaviour against
            # ``check_slice_indices``.
            start = self.start + (frame.start or 0) * self.step
            if frame.stop is None:
                stop = self.stop
            else:
                stop = self.start + (frame.stop or 0) * self.step
            step = (frame.step or 1) * self.step

            if step > 0:
                start = max(0, start)
            else:
                stop = max(0, stop)

            new_slice = slice(start, stop, step)
            return FrameIteratorSliced(self.trajectory, new_slice)
        else:
            # Indexing with a list of bools does not behave the same in all
            # versions of numpy, hence the conversion to a boolean array.
            frame = self._avoid_bool_list(frame)
            # The whole range is materialised before being indexed; a faster
            # alternative is discussed in issue #1982.
            frames = np.array(list(range(self.start, self.stop, self.step)))[frame]
            return FrameIteratorIndices(self.trajectory, frames)

    @property
    def start(self):
        """First frame index, as normalised at construction."""
        return self._start

    @property
    def stop(self):
        """Exclusive end index, as normalised at construction."""
        return self._stop

    @property
    def step(self):
        """Stride between frames, as normalised at construction."""
        return self._step
|
||
|
||
class FrameIteratorAll(FrameIteratorBase):
    """
    Iterable over all the frames of a trajectory.

    Length, iteration and indexing are all delegated to the wrapped
    trajectory.

    Parameters
    ----------
    trajectory: ProtoReader
        The trajectory over which to iterate.

    See Also
    --------
    FrameIteratorBase


    .. versionadded:: 0.19.0

    """
    def __init__(self, trajectory):
        super(FrameIteratorAll, self).__init__(trajectory)

    def __len__(self):
        """Number of frames reported by the trajectory (``n_frames``)."""
        reader = self.trajectory
        return reader.n_frames

    def __iter__(self):
        """Return the trajectory's own iterator."""
        reader = self.trajectory
        return iter(reader)

    def __getitem__(self, frame):
        """Forward the index to the trajectory unchanged."""
        reader = self.trajectory
        return reader[frame]
|
||
|
||
class FrameIteratorIndices(FrameIteratorBase):
    """
    Iterable over the frames of a trajectory listed in a sequence of indices.

    Every index is validated when the iterable is built: it must be an
    integer and it is passed through the trajectory's ``_apply_limits``
    method. The resulting indices are stored as a tuple.

    Parameters
    ----------
    trajectory: ProtoReader
        The trajectory over which to iterate.
    frames: sequence
        A sequence of indices.

    Raises
    ------
    TypeError
        If one of the indices is not an integer.

    See Also
    --------
    FrameIteratorBase


    .. versionadded:: 0.19.0

    """
    def __init__(self, trajectory, frames):
        super(FrameIteratorIndices, self).__init__(trajectory)
        checked = []
        for index in frames:
            if not isinstance(index, numbers.Integral):
                raise TypeError("Frames indices must be integers.")
            checked.append(trajectory._apply_limits(index))
        self._frames = tuple(checked)

    def __len__(self):
        """Number of stored frame indices."""
        return len(self.frames)

    def __iter__(self):
        """Yield the frame read from the trajectory for each stored index."""
        read = self.trajectory._read_frame_with_aux
        for index in self.frames:
            yield read(index)

    def __getitem__(self, frame):
        """Index relative to the stored indices.

        An integer returns the corresponding frame; anything else is used
        to index the array of stored indices and returns a new
        :class:`FrameIteratorIndices`.
        """
        if not isinstance(frame, numbers.Integral):
            selector = self._avoid_bool_list(frame)
            selected = np.array(self.frames)[selector]
            return FrameIteratorIndices(self.trajectory, selected)
        index = self.frames[frame]
        return self.trajectory._read_frame_with_aux(index)

    @property
    def frames(self):
        """Tuple of the validated frame indices."""
        return self._frames
|
||
|
||
class IOBase(object): | ||
"""Base class bundling common functionality for trajectory I/O. | ||
|
||
|
@@ -1251,6 +1450,11 @@ def time(self): | |
""" | ||
return self.ts.time | ||
|
||
@property
def trajectory(self):
    """The reader itself.

    Makes a reader effectively compatible with a FrameIteratorBase, which
    exposes the trajectory it iterates over under the same name.
    """
    return self
|
||
def Writer(self, filename, **kwargs): | ||
"""A trajectory writer with the same properties as this trajectory.""" | ||
raise NotImplementedError( | ||
|
@@ -1299,6 +1503,14 @@ def _reopen(self): | |
""" | ||
pass | ||
|
||
def _apply_limits(self, frame): | ||
if frame < 0: | ||
frame += len(self) | ||
if frame < 0 or frame >= len(self): | ||
raise IndexError("Index {} exceeds length of trajectory ({})." | ||
"".format(frame, len(self))) | ||
return frame | ||
|
||
def __getitem__(self, frame): | ||
"""Return the Timestep corresponding to *frame*. | ||
|
||
|
@@ -1312,39 +1524,23 @@ def __getitem__(self, frame): | |
---- | ||
*frame* is a 0-based frame index. | ||
""" | ||
|
||
def apply_limits(frame): | ||
if frame < 0: | ||
frame += len(self) | ||
if frame < 0 or frame >= len(self): | ||
raise IndexError("Index {} exceeds length of trajectory ({})." | ||
"".format(frame, len(self))) | ||
return frame | ||
|
||
if isinstance(frame, numbers.Integral): | ||
frame = apply_limits(frame) | ||
frame = self._apply_limits(frame) | ||
return self._read_frame_with_aux(frame) | ||
elif isinstance(frame, (list, np.ndarray)): | ||
if isinstance(frame[0], (bool, np.bool_)): | ||
# Avoid having list of bools | ||
frame = np.asarray(frame, dtype=np.bool) | ||
# Convert bool array to int array | ||
frame = np.arange(len(self))[frame] | ||
|
||
def listiter(frames): | ||
for f in frames: | ||
if not isinstance(f, numbers.Integral): | ||
raise TypeError("Frames indices must be integers") | ||
yield self._read_frame_with_aux(apply_limits(f)) | ||
|
||
return listiter(frame) | ||
return FrameIteratorIndices(self, frame) | ||
elif isinstance(frame, slice): | ||
start, stop, step = self.check_slice_indices( | ||
frame.start, frame.stop, frame.step) | ||
if start == 0 and stop == len(self) and step == 1: | ||
return self.__iter__() | ||
return FrameIteratorAll(self) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like the |
||
else: | ||
return self._sliced_iter(start, stop, step) | ||
return FrameIteratorSliced(self, frame) | ||
else: | ||
raise TypeError("Trajectories must be an indexed using an integer," | ||
" slice or list of indices") | ||
|
@@ -1464,7 +1660,7 @@ def check_slice_indices(self, start, stop, step): | |
if start < 0: | ||
start = 0 | ||
|
||
if step < 0 and start > nframes: | ||
if step < 0 and start >= nframes: | ||
start = nframes - 1 | ||
|
||
if stop is None: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this how `range` calculates its length, and is this what you get when you do `array[start:stop:step]`? Just asking so that we don't overlook any edge cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is how python implements it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No concerns, 👍