diff --git a/AUTHORS b/AUTHORS index 8d55219..6878b7b 100644 --- a/AUTHORS +++ b/AUTHORS @@ -17,3 +17,4 @@ Contributors: * Dominik Mierzejewski * Tyler Luchko * Giacomo Fiorin +* Eloy Félix diff --git a/CHANGELOG b/CHANGELOG index e75bd5b..2c2a4ed 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -13,6 +13,15 @@ The rules for this file: * accompany each entry with github issue/PR number (Issue #xyz) ------------------------------------------------------------------------------ +??/??/2019 eloyfelix + + * 0.6.0 + + Enhancements + + * Allow parsing/writing gzipped DX files + + 05/16/2019 giacomofiorin, orbeckst * 0.5.0 diff --git a/gridData/OpenDX.py b/gridData/OpenDX.py index c3df43b..9333445 100644 --- a/gridData/OpenDX.py +++ b/gridData/OpenDX.py @@ -165,6 +165,7 @@ import re from six import next from six.moves import range +import gzip import warnings @@ -177,17 +178,24 @@ def __init__(self,classid): self.component = None # component type self.D = None # dimensions - def write(self,file,optstring="",quote=False): + def write(self, stream, optstring="", quote=False): """write the 'object' line; additional args are packed in string""" classid = str(self.id) if quote: classid = '"'+classid+'"' # Only use a *single* space between tokens; both chimera's and pymol's DX parser # does not properly implement the OpenDX specs and produces garbage with multiple # spaces. 
(Chimera 1.4.1, PyMOL 1.3) - file.write('object '+classid+' class '+str(self.name)+' '+\ - optstring+'\n') + to_write = 'object '+classid+' class '+str(self.name)+' '+optstring+'\n' + self._write_line(stream, to_write) - def read(self,file): + @staticmethod + def _write_line(stream, line="", quote=False): + """write a line to the file""" + if isinstance(stream, gzip.GzipFile): + line = line.encode() + stream.write(line) + + def read(self, stream): raise NotImplementedError('Reading is currently not supported.') def ndformat(self,s): @@ -227,12 +235,14 @@ def __init__(self,classid,shape=None,origin=None,delta=None,**kwargs): # anything more complicated raise NotImplementedError('Only regularly spaced grids allowed, ' 'not delta={}'.format(self.delta)) - def write(self,file): - DXclass.write(self,file, - ('counts '+self.ndformat(' %d')) % tuple(self.shape)) - file.write('origin %f %f %f\n' % tuple(self.origin)) + def write(self, stream): + super(gridpositions, self).write( + stream, ('counts '+self.ndformat(' %d')) % tuple(self.shape)) + self._write_line(stream, 'origin %f %f %f\n' % tuple(self.origin)) for delta in self.delta: - file.write(('delta '+self.ndformat(' %f')+'\n') % tuple(delta)) + self._write_line( + stream, ('delta '+self.ndformat(' %f')+'\n') % tuple(delta)) + def edges(self): """Edges of the grid cells, origin at centre of 0,0,..,0 grid cell. @@ -251,9 +261,11 @@ def __init__(self,classid,shape=None,**kwargs): self.name = 'gridconnections' self.component = 'connections' self.shape = numpy.asarray(shape) # D dimensional shape - def write(self,file): - DXclass.write(self,file, - ('counts '+self.ndformat(' %d')) % tuple(self.shape)) + + def write(self, stream): + super(gridconnections, self).write( + stream, ('counts '+self.ndformat(' %d')) % tuple(self.shape)) + class array(DXclass): """OpenDX array class. 
@@ -350,12 +362,12 @@ def __init__(self, classid, array=None, type=None, typequote='"', self.type = type self.typequote = typequote - def write(self, file): + def write(self, stream): """Write the *class array* section. Parameters ---------- - file : file + stream : stream Raises ------ @@ -370,9 +382,9 @@ def write(self, file): "Use the type= keyword argument.").format( self.type, list(self.dx_types.keys()))) typelabel = (self.typequote+self.type+self.typequote) - DXclass.write(self,file, - 'type {0} rank 0 items {1} data follows'.format( - typelabel, self.array.size)) + super(array, self).write(stream, 'type {0} rank 0 items {1} data follows'.format( + typelabel, self.array.size)) + # grid data, serialized as a C array (z fastest varying) # (flat iterator is equivalent to: for x: for y: for z: grid[x,y,z]) # VMD's DX reader requires exactly 3 values per line @@ -385,12 +397,12 @@ def write(self, file): while 1: try: for i in range(values_per_line): - file.write(fmt_string.format(next(values)) + "\t") - file.write('\n') + self._write_line(stream, fmt_string.format(next(values)) + "\t") + self._write_line(stream, '\n') except StopIteration: - file.write('\n') + self._write_line(stream, '\n') break - file.write('attribute "dep" string "positions"\n') + self._write_line(stream, 'attribute "dep" string "positions"\n') class field(DXclass): """OpenDX container class @@ -459,6 +471,13 @@ def __init__(self,classid='0',components=None,comments=None): self.components = components self.comments= comments + def _openfile_writing(self, filename): + """Returns a regular or gz file stream for writing""" + if filename.endswith('.gz'): + return gzip.open(filename, 'wb') + else: + return open(filename, 'w') + def write(self, filename): """Write the complete dx object to the file. 
@@ -471,19 +490,20 @@ def write(self, filename): """ # comments (VMD chokes on lines of len > 80, so truncate) maxcol = 80 - with open(str(filename), 'w') as outfile: + with self._openfile_writing(str(filename)) as outfile: for line in self.comments: comment = '# '+str(line) - outfile.write(comment[:maxcol]+'\n') + self._write_line(outfile, comment[:maxcol]+'\n') # each individual object - for component,object in self.sorted_components(): + for component, object in self.sorted_components(): object.write(outfile) # the field object itself - DXclass.write(self,outfile,quote=True) - for component,object in self.sorted_components(): - outfile.write('component "%s" value %s\n' % (component,str(object.id))) + super(field, self).write(outfile, quote=True) + for component, object in self.sorted_components(): + self._write_line(outfile, 'component "%s" value %s\n' % ( + component, str(object.id))) - def read(self,file): + def read(self, stream): """Read DX field from file. dx = OpenDX.field.read(dxfile) @@ -491,7 +511,7 @@ def read(self,file): The classid is discarded and replaced with the one from the file. """ DXfield = self - p = DXParser(file) + p = DXParser(stream) p.parse(DXfield) def add(self,component,DXobj): @@ -652,7 +672,7 @@ def __init__(self, filename): } - def parse(self,DXfield): + def parse(self, DXfield): """Parse the dx file and construct a DX field object with component classes. 
A :class:`field` instance *DXfield* must be provided to be @@ -678,8 +698,13 @@ def parse(self,DXfield): self.currentobject = None # containers for data self.objects = [] # | self.tokens = [] # token buffer - with open(self.filename, 'r') as self.dxfile: - self.use_parser('general') # parse the whole file and populate self.objects + + if self.filename.endswith('.gz'): + with gzip.open(self.filename, 'rt') as self.dxfile: + self.use_parser('general') + else: + with open(self.filename, 'r') as self.dxfile: + self.use_parser('general') # parse the whole file and populate self.objects # assemble field from objects for o in self.objects: diff --git a/gridData/core.py b/gridData/core.py index 4440833..72a019c 100644 --- a/gridData/core.py +++ b/gridData/core.py @@ -380,7 +380,11 @@ def _guess_format(self, filename, file_format=None, export=True): else: available = self._loaders if file_format is None: - file_format = os.path.splitext(filename)[1][1:] + splitted = os.path.splitext(filename) + if splitted[1][1:] in ('gz', ): + file_format = os.path.splitext(splitted[0])[1][1:] + else: + file_format = splitted[1][1:] file_format = file_format.upper() if not file_format: file_format = self.default_format @@ -544,6 +548,8 @@ def _export_dx(self, filename, type=None, typequote='"', **kwargs): data=OpenDX.array(3, self.grid, type=type, typequote=typequote), ) dx = OpenDX.field('density', components=components, comments=comments) + if ext == '.gz': + filename = root + ext dx.write(filename) def save(self, filename): diff --git a/gridData/tests/datafiles/__init__.py b/gridData/tests/datafiles/__init__.py index 4e3f21f..e38acf1 100644 --- a/gridData/tests/datafiles/__init__.py +++ b/gridData/tests/datafiles/__init__.py @@ -5,6 +5,7 @@ __all__ = ["DX", "CCP4", "gOpenMol"] DX = resource_filename(__name__, 'test.dx') +DXGZ = resource_filename(__name__, 'test.dx.gz') CCP4 = resource_filename(__name__, 'test.ccp4') # from http://www.ebi.ac.uk/pdbe/coordinates/files/1jzv.ccp4 # (see 
issue #57) diff --git a/gridData/tests/datafiles/test.dx.gz b/gridData/tests/datafiles/test.dx.gz new file mode 100644 index 0000000..7795203 Binary files /dev/null and b/gridData/tests/datafiles/test.dx.gz differ diff --git a/gridData/tests/test_dx.py b/gridData/tests/test_dx.py index 1296e72..35fb1a4 100644 --- a/gridData/tests/test_dx.py +++ b/gridData/tests/test_dx.py @@ -10,8 +10,9 @@ from . import datafiles -def test_read_dx(): - g = Grid(datafiles.DX) +@pytest.mark.parametrize("infile", [datafiles.DX, datafiles.DXGZ]) +def test_read_dx(infile): + g = Grid(infile) POINTS = 8 ref = np.ones(POINTS) ref[4] = 1e-6 @@ -21,7 +22,7 @@ def test_read_dx(): assert_equal(g.delta, np.ones(3)) assert_equal(g.origin, np.array([20.1, 3., -10.])) - +@pytest.mark.parametrize("outfile", ["grid.dx", "grid.dx.gz"]) @pytest.mark.parametrize("nptype,dxtype", [ ("float16", "float"), ("float32", "float"), @@ -35,7 +36,7 @@ def test_read_dx(): ("int8", "signed byte"), ("uint8", "byte"), ]) -def test_write_dx(tmpdir, nptype, dxtype, counts=100, ndim=3): +def test_write_dx(tmpdir, nptype, dxtype, outfile, counts=100, ndim=3): # conversion from numpy array to DX file h, edges = np.histogramdd(np.random.random((counts, ndim)), bins=10) @@ -47,7 +48,6 @@ def test_write_dx(tmpdir, nptype, dxtype, counts=100, ndim=3): assert_equal(g.grid.sum(), counts) with tmpdir.as_cwd(): - outfile = "grid.dx" g.export(outfile) g2 = Grid(outfile) @@ -66,9 +66,10 @@ def test_write_dx(tmpdir, nptype, dxtype, counts=100, ndim=3): assert_equal(out_dxtype, dxtype) +@pytest.mark.parametrize("outfile", ["grid.dx", "grid.dx.gz"]) @pytest.mark.parametrize('nptype', ("complex64", "complex128", "bool_")) @pytest.mark.filterwarnings("ignore:array dtype.name =") -def test_write_dx_ValueError(tmpdir, nptype, counts=100, ndim=3): +def test_write_dx_ValueError(tmpdir, nptype, outfile, counts=100, ndim=3): h, edges = np.histogramdd(np.random.random((counts, ndim)), bins=10) g = Grid(h, edges) @@ -77,8 +78,5 @@ def 
test_write_dx_ValueError(tmpdir, nptype, counts=100, ndim=3): with pytest.raises(ValueError): with tmpdir.as_cwd(): - outfile = "grid.dx" g.export(outfile) - - diff --git a/setup.py b/setup.py index 6226115..9a5c100 100644 --- a/setup.py +++ b/setup.py @@ -38,8 +38,8 @@ 'Topic :: Software Development :: Libraries :: Python Modules', ], packages=find_packages(exclude=[]), - package_data={'gridData': ['tests/datafiles/*.dx', 'tests/datafiles/*.ccp4', - 'tests/datafiles/*.plt']}, + package_data={'gridData': ['tests/datafiles/*.dx', 'tests/datafiles/*.dx.gz', + 'tests/datafiles/*.ccp4', 'tests/datafiles/*.plt']}, install_requires=['numpy>=1.0.3', 'six', 'scipy'], tests_require=['pytest', 'numpy'], zip_safe=True,