Skip to content

Commit

Permalink
Merge pull request #2221 from Fenil3510/develop
Browse files Browse the repository at this point in the history
Writes compressed output of given format (fixes #2216)
  • Loading branch information
orbeckst authored Apr 5, 2019
2 parents 7dfb65a + 2c92916 commit 1be9f73
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 14 deletions.
1 change: 1 addition & 0 deletions package/AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ Chronological list of authors
- Daniele Padula
2019
- Ninad Bhat
- Fenil Suchak

External code
-------------
Expand Down
2 changes: 2 additions & 0 deletions package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ Enhancements
distance_array. (Issue #2103, PR #2209)
* updated analysis.distances.contact_matrix to use capped_distance. (Issue #2102,
PR #2215)
* added functionality to write files in compressed form (gz, bz2). (Issue #2216,
PR #2221)

Changes
* added official support for Python 3.7 (PR #1963)
Expand Down
2 changes: 1 addition & 1 deletion package/MDAnalysis/coordinates/GRO.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def __init__(self, filename, convert_units=None, n_atoms=None, **kwargs):
w.write(u.atoms)
"""
self.filename = util.filename(filename, ext='gro')
self.filename = util.filename(filename, ext='gro', keep=True)
self.n_atoms = n_atoms
self.reindex = kwargs.pop('reindex', True)

Expand Down
19 changes: 7 additions & 12 deletions package/MDAnalysis/core/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2975,7 +2975,7 @@ def improper(self):
"improper only makes sense for a group with exactly 4 atoms")
return topologyobjects.ImproperDihedral(self.ix, self.universe)

def write(self, filename=None, file_format="PDB",
def write(self, filename=None, file_format=None,
filenamefmt="{trjname}_{frame}", frames=None, **kwargs):
"""Write `AtomGroup` to a file.
Expand Down Expand Up @@ -3054,8 +3054,9 @@ def write(self, filename=None, file_format="PDB",
if filename is None:
trjname, ext = os.path.splitext(os.path.basename(trj.filename))
filename = filenamefmt.format(trjname=trjname, frame=trj.frame)
filename = util.filename(filename, ext=file_format.lower(), keep=True)

filename = util.filename(filename,
ext=file_format if file_format is not None else 'PDB',
keep=True)
# Some writers behave differently when they are given a "multiframe"
# argument. This is the case for the PDB writer, which writes models when
# "multiframe" is True.
Expand All @@ -3080,14 +3081,7 @@ def write(self, filename=None, file_format="PDB",
# Try and select a Class using get_ methods (becomes `writer`)
# Once (and if!) class is selected, use it in with block
try:
# format keyword works differently in get_writer and get_selection_writer
# here it overrides everything, in get_sel it is just a default
# apply sparingly here!
format = os.path.splitext(filename)[1][1:] # strip initial dot!
format = format or file_format
format = format.strip().upper()

writer = get_writer_for(filename, format=format, multiframe=multiframe)
writer = get_writer_for(filename, format=file_format, multiframe=multiframe)
except (ValueError, TypeError):
pass
else:
Expand All @@ -3106,7 +3100,8 @@ def write(self, filename=None, file_format="PDB",
try:
# here `file_format` is only used as default,
# anything pulled off `filename` will be used preferentially
writer = get_selection_writer_for(filename, file_format)
writer = get_selection_writer_for(filename,
file_format if file_format is not None else 'PDB')
except (TypeError, NotImplementedError):
pass
else:
Expand Down
1 change: 1 addition & 0 deletions package/MDAnalysis/lib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def filename(name, ext=None, keep=False):
Also permits :class:`NamedStream` to pass through.
"""
if ext is not None:
ext = ext.lower()
if not ext.startswith(os.path.extsep):
ext = os.path.extsep + ext
root, origext = os.path.splitext(name)
Expand Down
32 changes: 31 additions & 1 deletion testsuite/MDAnalysisTests/core/test_atomgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_write_frame_iterator(self, u, tmpdir, frames):

assert_array_almost_equal(new_positions, ref_positions)

@pytest.mark.parametrize('extension', ('xtc', 'dcd', 'pdb', 'xyz'))
@pytest.mark.parametrize('extension', ('xtc', 'dcd', 'pdb', 'xyz', 'PDB'))
def test_write_frame_none(self, u, tmpdir, extension):
destination = str(tmpdir / 'test.' + extension)
u.atoms.write(destination, frames=None)
Expand All @@ -167,6 +167,26 @@ def test_write_frame_none(self, u, tmpdir, extension):
u.atoms.positions[None, ...], new_positions, decimal=2
)

@pytest.mark.parametrize('extension', ('xtc', 'dcd', 'pdb', 'xyz', 'PDB'))
def test_compressed_write_frame_none(self, u, tmpdir, extension):
    """Round-trip a ``frames=None`` write through compressed files.

    For each compression suffix (``.gz`` and ``.bz2``) the atoms of `u` are
    written to ``test.<extension><suffix>`` in `tmpdir`, the file is loaded
    into a new Universe, and the positions read back are compared with the
    positions of `u` to two decimal places.
    """
    for suffix in ('.gz', '.bz2'):
        target = str(tmpdir / 'test.' + extension + suffix)
        u.atoms.write(target, frames=None)

        reloaded = mda.Universe(target)
        frames_read = [ts.positions for ts in reloaded.trajectory]
        reread_positions = np.stack(frames_read)

        # The reference gets a leading axis so that it lines up with the
        # stacked array built from the re-read trajectory.
        expected_positions = u.atoms.positions[None, ...]
        assert_array_almost_equal(
            expected_positions, reread_positions, decimal=2
        )

def test_compressed_write_frames_all(self, u, tmpdir):
    """Round-trip a ``frames='all'`` write through compressed DCD files.

    For each compression suffix (``.gz`` and ``.bz2``) the atoms of `u` are
    written to ``test.dcd<suffix>`` in `tmpdir`, the file is loaded into a
    new Universe, and the stacked per-frame positions of both trajectories
    are compared.
    """
    base_path = str(tmpdir / 'test.dcd')
    for suffix in ('.gz', '.bz2'):
        target = base_path + suffix
        u.atoms.write(target, frames='all')
        reloaded = mda.Universe(target)

        # Collect the reference first, then the re-read copy, one entry
        # per trajectory frame.
        expected_frames = [ts.positions for ts in u.trajectory]
        expected = np.stack(expected_frames)
        reread_frames = [ts.positions for ts in reloaded.trajectory]
        reread = np.stack(reread_frames)

        assert_array_almost_equal(reread, expected)

def test_write_frames_all(self, u, tmpdir):
destination = str(tmpdir / 'test.dcd')
u.atoms.write(destination, frames='all')
Expand Down Expand Up @@ -238,6 +258,16 @@ def test_write_atoms(self, universe, outfile):
err_msg=("atom coordinate mismatch between original and {0!s} file"
"".format(self.ext)))

def test_compressed_write_atoms(self, universe, outfile):
    """Write all atoms to compressed files and compare the coordinates.

    For each compression suffix (``.gz`` and ``.bz2``) the atoms of
    `universe` are written to ``outfile + suffix``, the file is reloaded
    through ``self.universe_from_tmp``, and the positions are compared at
    ``self.precision``.
    """
    mismatch_msg = (
        "atom coordinate mismatch between original and {0!s} file"
    )
    for suffix in ('.gz', '.bz2'):
        compressed_path = outfile + suffix
        universe.atoms.write(compressed_path)
        reloaded = self.universe_from_tmp(compressed_path)
        assert_almost_equal(
            universe.atoms.positions,
            reloaded.atoms.positions,
            self.precision,
            err_msg=mismatch_msg.format(self.ext),
        )

def test_write_empty_atomgroup(self, universe, outfile):
sel = universe.select_atoms('name doesntexist')
with pytest.raises(IndexError):
Expand Down

0 comments on commit 1be9f73

Please sign in to comment.