diff --git a/package/AUTHORS b/package/AUTHORS index 8ba4857455f..8056f7f4f9f 100644 --- a/package/AUTHORS +++ b/package/AUTHORS @@ -117,6 +117,7 @@ Chronological list of authors - Daniele Padula 2019 - Ninad Bhat + - Fenil Suchak External code ------------- diff --git a/package/CHANGELOG b/package/CHANGELOG index a283ff2fed5..12fca2b7c43 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -49,6 +49,8 @@ Enhancements distance_array. (Issue #2103, PR #2209) * updated analysis.distances.contact_matrix to use capped_distance. (Issue #2102, PR #2215) + * added functionality to write files in compressed form (gz, bz2). (Issue #2216, + PR #2221) Changes * added official support for Python 3.7 (PR #1963) diff --git a/package/MDAnalysis/coordinates/GRO.py b/package/MDAnalysis/coordinates/GRO.py index 98e753b4303..3303f28cbff 100644 --- a/package/MDAnalysis/coordinates/GRO.py +++ b/package/MDAnalysis/coordinates/GRO.py @@ -326,7 +326,7 @@ def __init__(self, filename, convert_units=None, n_atoms=None, **kwargs): w.write(u.atoms) """ - self.filename = util.filename(filename, ext='gro') + self.filename = util.filename(filename, ext='gro', keep=True) self.n_atoms = n_atoms self.reindex = kwargs.pop('reindex', True) diff --git a/package/MDAnalysis/core/groups.py b/package/MDAnalysis/core/groups.py index f8788e6aca6..e9dcfecc458 100644 --- a/package/MDAnalysis/core/groups.py +++ b/package/MDAnalysis/core/groups.py @@ -2975,7 +2975,7 @@ def improper(self): "improper only makes sense for a group with exactly 4 atoms") return topologyobjects.ImproperDihedral(self.ix, self.universe) - def write(self, filename=None, file_format="PDB", + def write(self, filename=None, file_format=None, filenamefmt="{trjname}_{frame}", frames=None, **kwargs): """Write `AtomGroup` to a file. 
@@ -3054,8 +3054,9 @@ def write(self, filename=None, file_format="PDB", if filename is None: trjname, ext = os.path.splitext(os.path.basename(trj.filename)) filename = filenamefmt.format(trjname=trjname, frame=trj.frame) - filename = util.filename(filename, ext=file_format.lower(), keep=True) - + filename = util.filename(filename, + ext=file_format if file_format is not None else 'PDB', + keep=True) # Some writer behave differently when they are given a "multiframe" # argument. It is the case of the PDB writer tht writes models when # "multiframe" is True. @@ -3080,14 +3081,7 @@ def write(self, filename=None, file_format="PDB", # Try and select a Class using get_ methods (becomes `writer`) # Once (and if!) class is selected, use it in with block try: - # format keyword works differently in get_writer and get_selection_writer - # here it overrides everything, in get_sel it is just a default - # apply sparingly here! - format = os.path.splitext(filename)[1][1:] # strip initial dot! - format = format or file_format - format = format.strip().upper() - - writer = get_writer_for(filename, format=format, multiframe=multiframe) + writer = get_writer_for(filename, format=file_format, multiframe=multiframe) except (ValueError, TypeError): pass else: @@ -3106,7 +3100,8 @@ def write(self, filename=None, file_format="PDB", try: # here `file_format` is only used as default, # anything pulled off `filename` will be used preferentially - writer = get_selection_writer_for(filename, file_format) + writer = get_selection_writer_for(filename, + file_format if file_format is not None else 'PDB') except (TypeError, NotImplementedError): pass else: diff --git a/package/MDAnalysis/lib/util.py b/package/MDAnalysis/lib/util.py index 28c9457198c..18627e859d1 100644 --- a/package/MDAnalysis/lib/util.py +++ b/package/MDAnalysis/lib/util.py @@ -249,6 +249,7 @@ def filename(name, ext=None, keep=False): Also permits :class:`NamedStream` to pass through. 
""" if ext is not None: + ext = ext.lower() if not ext.startswith(os.path.extsep): ext = os.path.extsep + ext root, origext = os.path.splitext(name) diff --git a/testsuite/MDAnalysisTests/core/test_atomgroup.py b/testsuite/MDAnalysisTests/core/test_atomgroup.py index c2f223bf2ec..617cdb7b959 100644 --- a/testsuite/MDAnalysisTests/core/test_atomgroup.py +++ b/testsuite/MDAnalysisTests/core/test_atomgroup.py @@ -156,7 +156,7 @@ def test_write_frame_iterator(self, u, tmpdir, frames): assert_array_almost_equal(new_positions, ref_positions) - @pytest.mark.parametrize('extension', ('xtc', 'dcd', 'pdb', 'xyz')) + @pytest.mark.parametrize('extension', ('xtc', 'dcd', 'pdb', 'xyz', 'PDB')) def test_write_frame_none(self, u, tmpdir, extension): destination = str(tmpdir / 'test.' + extension) u.atoms.write(destination, frames=None) @@ -167,6 +167,26 @@ def test_write_frame_none(self, u, tmpdir, extension): u.atoms.positions[None, ...], new_positions, decimal=2 ) + @pytest.mark.parametrize('extension', ('xtc', 'dcd', 'pdb', 'xyz', 'PDB')) + def test_compressed_write_frame_none(self, u, tmpdir, extension): + for ext in ('.gz', '.bz2'): + destination = str(tmpdir / 'test.' 
+ extension + ext) + u.atoms.write(destination, frames=None) + u_new = mda.Universe(destination) + new_positions = np.stack([ts.positions for ts in u_new.trajectory]) + assert_array_almost_equal( + u.atoms.positions[None, ...], new_positions, decimal=2 + ) + + def test_compressed_write_frames_all(self, u, tmpdir): + for ext in ('.gz', '.bz2'): + destination = str(tmpdir / 'test.dcd') + ext + u.atoms.write(destination, frames='all') + u_new = mda.Universe(destination) + ref_positions = np.stack([ts.positions for ts in u.trajectory]) + new_positions = np.stack([ts.positions for ts in u_new.trajectory]) + assert_array_almost_equal(new_positions, ref_positions) + def test_write_frames_all(self, u, tmpdir): destination = str(tmpdir / 'test.dcd') u.atoms.write(destination, frames='all') @@ -238,6 +258,16 @@ def test_write_atoms(self, universe, outfile): err_msg=("atom coordinate mismatch between original and {0!s} file" "".format(self.ext))) + def test_compressed_write_atoms(self, universe, outfile): + for compressed_ext in ('.gz', '.bz2'): + universe.atoms.write(outfile + compressed_ext) + u2 = self.universe_from_tmp(outfile + compressed_ext) + assert_almost_equal( + universe.atoms.positions, u2.atoms.positions, + self.precision, + err_msg=("atom coordinate mismatch between original and {0!s} file" + "".format(self.ext))) + def test_write_empty_atomgroup(self, universe, outfile): sel = universe.select_atoms('name doesntexist') with pytest.raises(IndexError):