Skip to content

Commit

Permalink
Address comments from @orbeckst on #1934
Browse files Browse the repository at this point in the history
  • Loading branch information
jbarnoud committed Jul 12, 2018
1 parent 6542374 commit 06e386f
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 25 deletions.
4 changes: 2 additions & 2 deletions package/MDAnalysis/coordinates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ def __len__(self):

def __iter__(self):
for i in range(self.start, self.stop, self.step):
yield self.trajectory._read_frame_with_aux(i)
yield self.trajectory[i]
self.trajectory.rewind()

def __getitem__(self, frame):
Expand Down Expand Up @@ -1528,7 +1528,7 @@ def __getitem__(self, frame):
frame = self._apply_limits(frame)
return self._read_frame_with_aux(frame)
elif isinstance(frame, (list, np.ndarray)):
if isinstance(frame[0], (bool, np.bool_)):
if len(frame) != 0 and isinstance(frame[0], (bool, np.bool_)):
# Avoid having list of bools
frame = np.asarray(frame, dtype=np.bool)
# Convert bool array to int array
Expand Down
74 changes: 52 additions & 22 deletions package/MDAnalysis/core/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2535,17 +2535,27 @@ def write(self, filename=None, file_format="PDB",
"""Write `AtomGroup` to a file.
The output can either be a coordinate file or a selection, depending on
the `format`. Only single-frame coordinate files are supported. If you
need to write out a trajectory, see :mod:`MDAnalysis.coordinates`.
the format.
Examples
--------
>>> ag = u.atoms
>>> ag.write('selection.ndx') # Write a gromacs index file
>>> ag.write('coordinates.pdb') # Write the current frame as PDB
>>> # Write the trajectory in XTC format
>>> ag.write('trajectory.xtc', frames='all')
>>> # Write every other frame of the trajectory in PBD format
>>> ag.write('trajectory.pdb', frames=u.trajectory[::2])
Parameters
----------
filename : str, optional
``None``: create TRJNAME_FRAME.FORMAT from filenamefmt [``None``]
file_format : str, optional
PDB, CRD, GRO, VMD (tcl), PyMol (pml), Gromacs (ndx) CHARMM (str)
Jmol (spt); case-insensitive and can also be supplied as the
filename extension [PDB]
The name or extension of a coordinate, trajectory, or selection
file format such as PDB, CRD, GRO, VMD (tcl), PyMol (pml), Gromacs
(ndx) CHARMM (str) or Jmol (spt); case-insensitive [PDB]
filenamefmt : str, optional
format string for default filename; use substitution tokens
'trjname' and 'frame' ["%(trjname)s_%(frame)d"]
Expand All @@ -2555,25 +2565,33 @@ def write(self, filename=None, file_format="PDB",
file. ``"all"``: write out all bonds, both the original defined and
those guessed by MDAnalysis. ``None``: do not write out bonds.
Default is ``"conect"``.
frames:
frames: array-like or slice or FrameIteratorBase or str, optional
An ensemble of frames to write. The ensemble can be an list or
array of frame indices, a mask of booleans, an instance of
:class:`slice`, or an indexed trajectory. By default, 'frames' is
set to ``None`` and only the current frame is written.
:class:`slice`, or the value returned when a trajectory is indexed.
By default, `frames` is set to ``None`` and only the current frame
is written. If `frames` is set to "all", then all the frame from
trajectory are written.
.. versionchanged:: 0.9.0 Merged with write_selection. This method can
now write both selections out.
.. versionchanged:: 0.19.0
Can write multiframe trajectories with the 'frames' argument.
"""
# TODO: Add a 'verbose' option alongside 'frames'.

# check that AtomGroup actually has any atoms (Issue #434)
if len(self.atoms) == 0:
raise IndexError("Cannot write an AtomGroup with 0 atoms")

trj = self.universe.trajectory # unified trajectory API
if frames is None:
if frames is None or frames == 'all':
trj_frames = trj[::]
elif isinstance(frames, numbers.Integral):
# We accept everything that indexes a trajectory and returns a
# subset of it. Though, numbers return a Timestep instead.
raise TypeError('The "frames" argument cannot be a number.')
else:
try:
test_trajectory = frames.trajectory
Expand All @@ -2587,20 +2605,31 @@ def write(self, filename=None, file_format="PDB",
'AtomGroup.'.format(frames)
)
trj_frames = frames
if len(trj_frames) > 1 and kwargs.get("multiframe") == False:
raise ValueError(
'Cannot explicitely set "multiframe" to False and request '
'more than 1 frame with the "frames" keyword argument.'
)

if len(trj_frames) == 1:
kwargs.setdefault("multiframe", False)

if filename is None:
trjname, ext = os.path.splitext(os.path.basename(trj.filename))
filename = filenamefmt.format(trjname=trjname, frame=trj.frame)
filename = util.filename(filename, ext=file_format.lower(), keep=True)

# Some writer behave differently when they are given a "multiframe"
# argument. It is the case of the PDB writer tht writes models when
# "multiframe" is True.
# We want to honor what the user provided with the argument if
# provided explicitly. If not, then we need to figure out if we write
# multiple frames or not.
multiframe = kwargs.pop('multiframe', None)
if len(trj_frames) > 1 and multiframe == False:
raise ValueError(
'Cannot explicitely set "multiframe" to False and request '
'more than 1 frame with the "frames" keyword argument.'
)
elif multiframe is None:
if frames is None:
# By default we only write the current frame.
multiframe = False
else:
multiframe = len(trj_frames) > 1

# From the following blocks, one must pass.
# Both can't pass as the extensions don't overlap.
# Try and select a Class using get_ methods (becomes `writer`)
Expand All @@ -2613,19 +2642,20 @@ def write(self, filename=None, file_format="PDB",
format = format or file_format
format = format.strip().upper()

multiframe = kwargs.pop('multiframe', None)

writer = get_writer_for(filename, format=format, multiframe=multiframe)
#MDAnalysis.coordinates.writer(filename, **kwargs)
except (ValueError, TypeError):
pass
else:
with writer(filename, n_atoms=self.n_atoms, **kwargs) as w:
if frames is None:
w.write(self.atoms)
else:
for _ in trj_frames:
w.write(self.atoms)
current_frame = trj.ts.frame
try:
for _ in trj_frames:
w.write(self.atoms)
finally:
trj[current_frame]
return

try:
Expand Down
13 changes: 12 additions & 1 deletion testsuite/MDAnalysisTests/coordinates/test_reader_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ class TestMultiFrameReader(_Multi):
(1, 5, -1), # Stop less than start
(-100, None, None),
(100, None, None), # Outside of range of trajectory
(-2, 10, -2)
(-2, 10, -2),
(0, 0, 1), # empty
(10, 1, 2), # empty
])
def test_slice(self, start, stop, step, reader):
"""Compare the slice applied to trajectory, to slice of list"""
Expand Down Expand Up @@ -226,6 +228,8 @@ def test_getitem(self, slice_cls, sl, reader):
slice(10, 0, -1),
slice(2, 7, 2),
slice(7, 2, -2),
slice(7, 2, 1), # empty
slice(0, 0, 1), # empty
])
def test_getitem_len(self, sl, reader):
traj_iterable = reader[sl]
Expand All @@ -234,6 +238,12 @@ def test_getitem_len(self, sl, reader):
ref = self.reference[sl]
assert len(traj_iterable) == len(ref)

@pytest.mark.parametrize('iter_type', (list, np.array))
def test_getitem_len_empty(self, reader, iter_type):
# Indexing a numpy array with an empty array tends to break.
traj_iterable = reader[iter_type([])]
assert len(traj_iterable) == 0

# All the sl1 slice must be 5 frames long so that the sl2 can be a mask
@pytest.mark.parametrize('sl1', [
[0, 1, 2, 3, 4],
Expand Down Expand Up @@ -271,6 +281,7 @@ def test_double_getitem(self, sl1, sl2, reader):
[True, False, ] * 5,
slice(None, None, 2),
slice(None, None, -2),
slice(None, None, None),
])
@pytest.mark.parametrize('idx2', [0, 2, 4, -1, -2, -4])
def test_double_getitem_int(self, sl1, idx2, reader):
Expand Down
32 changes: 32 additions & 0 deletions testsuite/MDAnalysisTests/core/test_atomgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,31 @@ def test_write_frame_iterator(self, u, tmpdir, frames):

assert_array_almost_equal(new_positions, ref_positions)

@pytest.mark.parametrize('extension', ('xtc', 'dcd', 'pdb', 'xyz'))
def test_write_frame_none(self, u, tmpdir, extension):
destination = str(tmpdir / 'test.' + extension)
u.atoms.write(destination, frames=None)
u_new = mda.Universe(destination)
new_positions = np.stack([ts.positions for ts in u_new.trajectory])
# Most format only save 3 decimals; XTC even has only 2.
assert_array_almost_equal(
u.atoms.positions[None, ...], new_positions, decimal=2
)

def test_write_frames_all(self, u, tmpdir):
destination = str(tmpdir / 'test.dcd')
u.atoms.write(destination, frames='all')
u_new = mda.Universe(destination)
ref_positions = np.stack([ts.positions for ts in u.trajectory])
new_positions = np.stack([ts.positions for ts in u_new.trajectory])
assert_array_almost_equal(new_positions, ref_positions)

@pytest.mark.parametrize('frames', ('invalid', 8, True, False, 3.2))
def test_write_frames_invalid(self, u, tmpdir, frames):
destination = str(tmpdir / 'test.dcd')
with pytest.raises(TypeError):
u.atoms.write(destination, frames=frames)

def test_incompatible_arguments(self, u, tmpdir):
destination = str(tmpdir / 'test.dcd')
with pytest.raises(ValueError):
Expand All @@ -166,6 +191,13 @@ def test_incompatible_trajectories(self, tmpdir):
with pytest.raises(ValueError):
u1.atoms.write(destination, frames=u2.trajectory)

def test_write_no_traj_move(self, u, tmpdir):
destination = str(tmpdir / 'test.dcd')
u.trajectory[10]
u.atoms.write(destination, frames=[1, 2, 3])
assert u.trajectory.ts.frame == 10


def test_write_selection(self, u, tmpdir):
with tmpdir.as_cwd():
u.atoms.write("test.vmd")
Expand Down

0 comments on commit 06e386f

Please sign in to comment.