diff --git a/pghoard/rohmu/compressor.py b/pghoard/rohmu/compressor.py
index 540cd959..8a503835 100644
--- a/pghoard/rohmu/compressor.py
+++ b/pghoard/rohmu/compressor.py
@@ -8,6 +8,7 @@
 from .errors import InvalidConfigurationError
 from .filewrap import Sink, Stream
 from .snappyfile import SnappyFile
+from .zstdfile import open as zstd_open
 import lzma
 
 try:
@@ -15,6 +16,11 @@
 except ImportError:
     snappy = None
 
+try:
+    import zstandard as zstd
+except ImportError:
+    zstd = None
+
 
 def CompressionFile(dst_fp, algorithm, level=0):
     """This looks like a class to users, but is actually a function that instantiates a class based on algorithm."""
@@ -24,6 +30,9 @@ def CompressionFile(dst_fp, algorithm, level=0):
     if algorithm == "snappy":
         return SnappyFile(dst_fp, "wb")
 
+    if algorithm == "zstd":
+        return zstd_open(dst_fp, "wb", level)
+
     if algorithm:
         raise InvalidConfigurationError("invalid compression algorithm: {!r}".format(algorithm))
 
@@ -39,6 +48,8 @@ def __init__(self, src_fp, algorithm, level=0):
             self._compressor = lzma.LZMACompressor(lzma.FORMAT_XZ, -1, level, None)
         elif algorithm == "snappy":
             self._compressor = snappy.StreamCompressor()
+        elif algorithm == "zstd":
+            self._compressor = zstd.ZstdCompressor(level=level).compressobj()
         else:
-            InvalidConfigurationError("invalid compression algorithm: {!r}".format(algorithm))
+            raise InvalidConfigurationError("invalid compression algorithm: {!r}".format(algorithm))
 
@@ -57,6 +68,9 @@ def DecompressionFile(src_fp, algorithm):
     if algorithm == "snappy":
         return SnappyFile(src_fp, "rb")
 
+    if algorithm == "zstd":
+        return zstd_open(src_fp, "rb")
+
     if algorithm:
         raise InvalidConfigurationError("invalid compression algorithm: {!r}".format(algorithm))
 
@@ -73,6 +87,8 @@ def _create_decompressor(self, alg):
             return snappy.StreamDecompressor()
         elif alg == "lzma":
             return lzma.LZMADecompressor()
+        elif alg == "zstd":
+            return zstd.ZstdDecompressor().decompressobj()
         raise InvalidConfigurationError("invalid compression algorithm: {!r}".format(alg))
 
     def write(self, data):
diff --git a/pghoard/rohmu/zstdfile.py b/pghoard/rohmu/zstdfile.py
new file mode 100644
index 00000000..498fe774
--- /dev/null
+++ b/pghoard/rohmu/zstdfile.py
@@ -0,0 +1,100 @@
+"""
+rohmu - file-like interface for zstd
+
+Copyright (c) 2016 Ohmu Ltd
+See LICENSE for details
+"""
+
+from . import IO_BLOCK_SIZE
+from .filewrap import FileWrap
+import io
+
+try:
+    import zstandard as zstd
+except ImportError:
+    zstd = None
+
+
+class _ZstdFileWriter(FileWrap):
+    """Write-only file wrapper that zstd-compresses everything written to it into next_fp."""
+
+    def __init__(self, next_fp, level):
+        self._zstd = zstd.ZstdCompressor(level=level).compressobj()
+        super().__init__(next_fp)
+
+    def close(self):
+        """Write out whatever the compressor still buffers, flush next_fp and close this wrapper."""
+        if self.closed:
+            return
+        data = self._zstd.flush() or b""
+        if data:
+            self.next_fp.write(data)
+        self.next_fp.flush()
+        super().close()
+
+    def write(self, data):
+        """Compress data and pass the result on to next_fp; returns the number of input bytes consumed."""
+        self._check_not_closed()
+        compressed_data = self._zstd.compress(data)
+        # The compressor may buffer its input and return nothing yet; don't pass empty writes on to
+        # next_fp, just like close() doesn't for an empty flush
+        if compressed_data:
+            self.next_fp.write(compressed_data)
+        self.offset += len(data)
+        return len(data)
+
+    def writable(self):
+        return True
+
+
+class _ZstdFileReader(FileWrap):
+    """Read-only file wrapper that yields the decompressed contents of the zstd stream in next_fp."""
+
+    def __init__(self, next_fp):
+        self._zstd = zstd.ZstdDecompressor().decompressobj()
+        super().__init__(next_fp)
+        self._done = False
+
+    def close(self):
+        if self.closed:
+            return
+        super().close()
+
+    def read(self, size=-1):  # pylint: disable=unused-argument
+        """Return the next chunk of decompressed data, or b"" once the input is exhausted."""
+        # NOTE: size arg is ignored, random size output is returned
+        self._check_not_closed()
+        while not self._done:
+            compressed = self.next_fp.read(IO_BLOCK_SIZE)
+            if not compressed:
+                self._done = True
+                output = self._zstd.flush() or b""
+            else:
+                output = self._zstd.decompress(compressed)
+
+            # Input may be consumed without producing output yet, keep reading until there is some
+            if output:
+                self.offset += len(output)
+                return output
+
+        return b""
+
+    def readable(self):
+        return True
+
+
+def open(fp, mode, level=0):  # pylint: disable=redefined-builtin
+    """Return a zstd wrapper around fp: a compressing writer for mode "wb", a decompressing reader for "rb".
+
+    level is the zstd compression level and is only used for writing.  Raises io.UnsupportedOperation if
+    the zstandard module is not available or if mode is anything other than "wb" or "rb"."""
+    if zstd is None:
+        raise io.UnsupportedOperation("zstd is not available")
+
+    if mode == "wb":
+        return _ZstdFileWriter(fp, level)
+
+    if mode == "rb":
+        return _ZstdFileReader(fp)
+
+    raise io.UnsupportedOperation("unsupported mode for zstd")
diff --git a/requirements-zstd.txt b/requirements-zstd.txt
new file mode 100644
index 00000000..713da120
--- /dev/null
+++ b/requirements-zstd.txt
@@ -0,0 +1,2 @@
+-r requirements.txt
+zstandard
diff --git a/setup.py b/setup.py
index d337197b..1a095d7c 100644
--- a/setup.py
+++ b/setup.py
@@ -23,6 +23,7 @@
         "python-dateutil",
         "python-snappy >= 0.5",
         "requests >= 1.2.0",
+        "zstandard >= 0.11.1",
     ],
     extras_require={},
     dependency_links=[],
diff --git a/test/test_compressor.py b/test/test_compressor.py
index 5fde01a3..e53000d8 100644
--- a/test/test_compressor.py
+++ b/test/test_compressor.py
@@ -11,6 +11,7 @@
 from pghoard.compressor import CompressorThread
 from pghoard.rohmu import compressor, IO_BLOCK_SIZE, rohmufile
 from pghoard.rohmu.snappyfile import snappy, SnappyFile
+from pghoard.rohmu.compressor import zstd
 from queue import Queue
 import io
 import lzma
@@ -460,3 +461,17 @@ def test_snappy_read(self, tmpdir):
 
         full = b"".join(out)
         assert full == b"hello, world"
+
+
+@pytest.mark.skipif(not zstd, reason="zstd not installed")
+class TestZstdCompression(CompressionCase):
+    algorithm = "zstd"
+
+    def compress(self, data):
+        return zstd.ZstdCompressor().compress(data)
+
+    def decompress(self, data):
+        return zstd.ZstdDecompressor().decompressobj().decompress(data)
+
+    def make_compress_stream(self, src_fp):
+        return compressor.CompressionStream(src_fp, "zstd")