From 8eb4029e6575258f4fc13cb2729b9d01d68e57c0 Mon Sep 17 00:00:00 2001 From: Gregory Szorc Date: Sun, 17 Feb 2019 12:43:45 -0800 Subject: [PATCH] cffi: consume remaining data in input buffer Before, the CFFI implementation of ZstdDecompressionReader.read() may skip input data left over in an internal buffer. This would result in feeding incorrect input into the decompressor, which would likely manifest as a malformed zstd data error. As part of fixing this, we added a test to reproduce the failure. And we also improved the fuzzing coverage of this method. Closes #71. --- NEWS.rst | 10 +++ tests/common.py | 14 +++- tests/test_decompressor.py | 17 +++++ tests/test_decompressor_fuzzing.py | 111 +++++++++++++++++++++++++---- zstandard/cffi.py | 6 ++ 5 files changed, 144 insertions(+), 14 deletions(-) diff --git a/NEWS.rst b/NEWS.rst index f574f2d8..a7068cee 100644 --- a/NEWS.rst +++ b/NEWS.rst @@ -118,9 +118,19 @@ Backwards Compatibility Nodes ``import zstandard`` to cause an appropriate backend module to be loaded automatically. +Bug Fixes +--------- + +* CFFI backend could encounter an error when calling + ``ZstdDecompressionReader.read()`` if there was data remaining in an + internal buffer. The issue has been fixed. (#71) + Changes ------- +* CFFI's ``ZstdDecompressionReader.read()`` now properly handles data + remaining in any internal buffer. Before, repeated ``read()`` could + result in *random* errors. #71. * Upgraded various Python packages in CI environment. * Upgrade to hypothesis 4.5.11. * In the CFFI backend, ``CompressionReader`` and ``DecompressionReader`` diff --git a/tests/common.py b/tests/common.py index f65320c2..8bffd7a7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -142,6 +142,13 @@ def random_input_data(): except OSError: pass + # Also add some actual random data. + _source_files.append(os.urandom(100)) + _source_files.append(os.urandom(1000)) + _source_files.append(os.urandom(10000)) + _source_files.append(os.urandom(100000)) + _source_files.append(os.urandom(1000000)) + return _source_files @@ -165,11 +172,14 @@ def generate_samples(): if hypothesis: - default_settings = hypothesis.settings() + default_settings = hypothesis.settings(deadline=1000) hypothesis.settings.register_profile('default', default_settings) - ci_settings = hypothesis.settings(max_examples=2500) + ci_settings = hypothesis.settings(deadline=10000, max_examples=2500) hypothesis.settings.register_profile('ci', ci_settings) + expensive_settings = hypothesis.settings(deadline=None, max_examples=10000) + hypothesis.settings.register_profile('expensive', expensive_settings) + hypothesis.settings.load_profile( os.environ.get('HYPOTHESIS_PROFILE', 'default')) diff --git a/tests/test_decompressor.py b/tests/test_decompressor.py index ed847db5..8682f7ef 100644 --- a/tests/test_decompressor.py +++ b/tests/test_decompressor.py @@ -533,6 +533,23 @@ def test_read_after_error(self): with self.assertRaisesRegexp(ValueError, 'stream is closed'): reader.read(100) + def test_partial_read(self): + # Inspired by https://github.com/indygreg/python-zstandard/issues/71. + buffer = io.BytesIO() + cctx = zstd.ZstdCompressor() + writer = cctx.stream_writer(buffer) + writer.write(bytearray(os.urandom(1000000))) + writer.flush(zstd.FLUSH_FRAME) + buffer.seek(0) + + dctx = zstd.ZstdDecompressor() + reader = dctx.stream_reader(buffer) + + while True: + chunk = reader.read(8192) + if not chunk: + break + @make_cffi class TestDecompressor_decompressobj(unittest.TestCase): diff --git a/tests/test_decompressor_fuzzing.py b/tests/test_decompressor_fuzzing.py index 2357917f..3aefa3e8 100644 --- a/tests/test_decompressor_fuzzing.py +++ b/tests/test_decompressor_fuzzing.py @@ -24,20 +24,29 @@ class TestDecompressor_stream_reader_fuzzing(unittest.TestCase): suppress_health_check=[hypothesis.HealthCheck.large_base_example]) @hypothesis.given(original=strategies.sampled_from(random_input_data()), level=strategies.integers(min_value=1, max_value=5), - source_read_size=strategies.integers(1, 16384), + streaming=strategies.booleans(), + source_read_size=strategies.integers(1, 1048576), read_sizes=strategies.data()) - def test_stream_source_read_variance(self, original, level, source_read_size, - read_sizes): + def test_stream_source_read_variance(self, original, level, streaming, + source_read_size, read_sizes): cctx = zstd.ZstdCompressor(level=level) - frame = cctx.compress(original) + + if streaming: + source = io.BytesIO() + writer = cctx.stream_writer(source) + writer.write(original) + writer.flush(zstd.FLUSH_FRAME) + source.seek(0) + else: + frame = cctx.compress(original) + source = io.BytesIO(frame) dctx = zstd.ZstdDecompressor() - source = io.BytesIO(frame) chunks = [] with dctx.stream_reader(source, read_size=source_read_size) as reader: while True: - read_size = read_sizes.draw(strategies.integers(1, 16384)) + read_size = read_sizes.draw(strategies.integers(1, 131072)) chunk = reader.read(read_size) if not chunk: break @@ -46,23 +55,67 @@ def test_stream_source_read_variance(self, original, level, source_read_size, self.assertEqual(b''.join(chunks), original) + # Similar to above except we have a constant read() size. + @hypothesis.settings( + suppress_health_check=[hypothesis.HealthCheck.large_base_example]) + @hypothesis.given(original=strategies.sampled_from(random_input_data()), + level=strategies.integers(min_value=1, max_value=5), + streaming=strategies.booleans(), + source_read_size=strategies.integers(1, 1048576), + read_size=strategies.integers(1, 131072)) + def test_stream_source_read_size(self, original, level, streaming, + source_read_size, read_size): + cctx = zstd.ZstdCompressor(level=level) + + if streaming: + source = io.BytesIO() + writer = cctx.stream_writer(source) + writer.write(original) + writer.flush(zstd.FLUSH_FRAME) + source.seek(0) + else: + frame = cctx.compress(original) + source = io.BytesIO(frame) + + dctx = zstd.ZstdDecompressor() + + chunks = [] + reader = dctx.stream_reader(source, read_size=source_read_size) + while True: + chunk = reader.read(read_size) + if not chunk: + break + + chunks.append(chunk) + + self.assertEqual(b''.join(chunks), original) + @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.large_base_example]) @hypothesis.given(original=strategies.sampled_from(random_input_data()), level=strategies.integers(min_value=1, max_value=5), - source_read_size=strategies.integers(1, 16384), + streaming=strategies.booleans(), + source_read_size=strategies.integers(1, 1048576), read_sizes=strategies.data()) - def test_buffer_source_read_variance(self, original, level, source_read_size, - read_sizes): + def test_buffer_source_read_variance(self, original, level, streaming, + source_read_size, read_sizes): cctx = zstd.ZstdCompressor(level=level) - frame = cctx.compress(original) + + if streaming: + source = io.BytesIO() + writer = cctx.stream_writer(source) + writer.write(original) + writer.flush(zstd.FLUSH_FRAME) + frame = source.getvalue() + else: + frame = cctx.compress(original) dctx = zstd.ZstdDecompressor() chunks = [] with dctx.stream_reader(frame, read_size=source_read_size) as reader: while True: - read_size = read_sizes.draw(strategies.integers(1, 16384)) + read_size = read_sizes.draw(strategies.integers(1, 131072)) chunk = reader.read(read_size) if not chunk: break @@ -71,12 +124,46 @@ def test_buffer_source_read_variance(self, original, level, source_read_size, self.assertEqual(b''.join(chunks), original) + # Similar to above except we have a constant read() size. + @hypothesis.settings( + suppress_health_check=[hypothesis.HealthCheck.large_base_example]) + @hypothesis.given(original=strategies.sampled_from(random_input_data()), + level=strategies.integers(min_value=1, max_value=5), + streaming=strategies.booleans(), + source_read_size=strategies.integers(1, 1048576), + read_size=strategies.integers(1, 131072)) + def test_buffer_source_constant_read_size(self, original, level, streaming, + source_read_size, read_size): + cctx = zstd.ZstdCompressor(level=level) + + if streaming: + source = io.BytesIO() + writer = cctx.stream_writer(source) + writer.write(original) + writer.flush(zstd.FLUSH_FRAME) + frame = source.getvalue() + else: + frame = cctx.compress(original) + + dctx = zstd.ZstdDecompressor() + chunks = [] + + reader = dctx.stream_reader(frame, read_size=source_read_size) + while True: + chunk = reader.read(read_size) + if not chunk: + break + + chunks.append(chunk) + + self.assertEqual(b''.join(chunks), original) + @hypothesis.settings( suppress_health_check=[hypothesis.HealthCheck.large_base_example]) @hypothesis.given( original=strategies.sampled_from(random_input_data()), level=strategies.integers(min_value=1, max_value=5), - source_read_size=strategies.integers(1, 16384), + source_read_size=strategies.integers(1, 1048576), seek_amounts=strategies.data(), read_sizes=strategies.data()) def test_relative_seeks(self, original, level, source_read_size, seek_amounts, diff --git a/zstandard/cffi.py b/zstandard/cffi.py index 39fb43fb..9a424f1d 100644 --- a/zstandard/cffi.py +++ b/zstandard/cffi.py @@ -1697,9 +1697,15 @@ def decompress(): return ffi.buffer(out_buffer.dst, out_buffer.pos)[:] def get_input(): + # We have data left over in the input buffer. Use it. + if self._in_buffer.pos < self._in_buffer.size: + return + + # All input data exhausted. Nothing to do. if self._finished_input: return + # Else populate the input buffer from our source. if hasattr(self._source, 'read'): data = self._source.read(self._read_size)