Skip to content

Commit

Permalink
decompressionreader: add closefd argument to ZstdDecompressionReader
Browse files Browse the repository at this point in the history
This completes our implementation of closefd support on all our IO
classes to enable the consumer to have complete control over whether the
inner stream should be closed on our close. See prior commits and #76 for
more context.
  • Loading branch information
indygreg committed Dec 26, 2020
1 parent e5e68ab commit d5c6685
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 17 deletions.
12 changes: 7 additions & 5 deletions NEWS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ Backwards Compatibility Notes
``ZstdDecompressionReader.readlines()`` now accept an integer argument.
This makes them conform with the IO interface. The methods still raise
``io.UnsupportedOperation``.
* ``ZstdCompressionReader.__enter__`` now raises ``ValueError`` if the
instance was already closed.
* ``ZstdCompressionReader.__enter__`` and ``ZstdDecompressionReader.__enter__``
now raise ``ValueError`` if the instance was already closed.

Bug Fixes
---------
Expand Down Expand Up @@ -148,10 +148,12 @@ Changes
package. Previously, there were modules in other packages. (#115)
* C source code is now automatically formatted with ``clang-format``.
* ``ZstdCompressor.stream_writer()``, ``ZstdCompressor.stream_reader()``,
and ``ZstdDecompressor.stream_writer()`` now accept a ``closefd``
``ZstdDecompressor.stream_writer()``, and
``ZstdDecompressor.stream_reader()`` now accept a ``closefd``
argument to control whether the underlying stream should be closed
when the ``ZstdCompressionWriter``, ``ZstdCompressReader``, or
``ZstdDecompressionWriter`` is closed.
when the ``ZstdCompressionWriter``, ``ZstdCompressionReader``,
``ZstdDecompressionWriter``, or ``ZstdDecompressionReader`` is closed.
(#76)

0.14.1 (released 2020-12-05)
============================
Expand Down
19 changes: 18 additions & 1 deletion c-ext/decompressionreader.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ decompressionreader_enter(ZstdDecompressionReader *self) {
return NULL;
}

if (self->closed) {
PyErr_SetString(PyExc_ValueError, "stream is closed");
return NULL;
}

self->entered = 1;

Py_INCREF(self);
Expand All @@ -68,7 +73,10 @@ static PyObject *decompressionreader_exit(ZstdDecompressionReader *self,
}

self->entered = 0;
self->closed = 1;

if (NULL == PyObject_CallMethod((PyObject *)self, "close", NULL)) {
return NULL;
}

/* Release resources. */
Py_CLEAR(self->reader);
Expand All @@ -95,7 +103,16 @@ static PyObject *decompressionreader_seekable(PyObject *self) {
}

/*
 * close() implementation for ZstdDecompressionReader.
 *
 * Marks the reader as closed and, when the instance was created with a true
 * ``closefd``, forwards the call to the inner stream's own close() method if
 * it has one.
 *
 * The call is idempotent: once ``closed`` is set, later calls return None
 * without touching the inner stream again.
 *
 * Returns None, or whatever the inner stream's close() returned. Returns
 * NULL with an exception set if calling the inner close() failed.
 */
static PyObject *decompressionreader_close(ZstdDecompressionReader *self) {
    if (self->closed) {
        Py_RETURN_NONE;
    }

    /* Flag as closed before calling out so a failing or re-entrant inner
     * close() cannot trigger a second attempt. */
    self->closed = 1;

    /* self->reader may be NULL (the struct carries a separate ``buffer``
     * member for buffer sources), and PyObject_HasAttrString() must not be
     * handed a NULL object, so check it explicitly. */
    if (self->closefd && self->reader != NULL &&
        PyObject_HasAttrString(self->reader, "close")) {
        return PyObject_CallMethod(self->reader, "close", NULL);
    }

    Py_RETURN_NONE;
}

Expand Down
10 changes: 7 additions & 3 deletions c-ext/decompressor.c
Original file line number Diff line number Diff line change
Expand Up @@ -576,15 +576,18 @@ PyDoc_STRVAR(
static ZstdDecompressionReader *
Decompressor_stream_reader(ZstdDecompressor *self, PyObject *args,
PyObject *kwargs) {
static char *kwlist[] = {"source", "read_size", "read_across_frames", NULL};
static char *kwlist[] = {"source", "read_size", "read_across_frames",
"closefd", NULL};

PyObject *source;
size_t readSize = ZSTD_DStreamInSize();
PyObject *readAcrossFrames = NULL;
PyObject *closefd = NULL;
ZstdDecompressionReader *result;

if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kO:stream_reader", kwlist,
&source, &readSize, &readAcrossFrames)) {
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|kOO:stream_reader",
kwlist, &source, &readSize,
&readAcrossFrames, &closefd)) {
return NULL;
}

Expand Down Expand Up @@ -624,6 +627,7 @@ Decompressor_stream_reader(ZstdDecompressor *self, PyObject *args,
Py_INCREF(self);
result->readAcrossFrames =
readAcrossFrames ? PyObject_IsTrue(readAcrossFrames) : 0;
result->closefd = closefd ? PyObject_IsTrue(closefd) : 0;

return result;
}
Expand Down
2 changes: 2 additions & 0 deletions c-ext/python-zstandard.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ typedef struct {
int readAcrossFrames;
/* Buffer to read from (if reading from a buffer). */
Py_buffer buffer;
/* Whether to close the inner object on close() */
int closefd;

/* Whether the context manager is active. */
int entered;
Expand Down
3 changes: 3 additions & 0 deletions docs/decompressor.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ When ``False`` (the default), a read will complete when the end of a
zstd *frame* is encountered. When ``True``, a read can potentially
return data spanning multiple zstd *frames*.

The ``closefd`` keyword argument defines whether to close the underlying stream
when ``close()`` is called on this instance. The default is ``False``.

Streaming Writer Interface
==========================

Expand Down
68 changes: 65 additions & 3 deletions tests/test_decompressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,68 @@ def test_read_stream_small_chunks(self):

self.assertEqual(b"".join(chunks), source)

def test_close(self):
foo = zstd.ZstdCompressor().compress(b"foo" * 1024)

buffer = io.BytesIO(foo)
dctx = zstd.ZstdDecompressor()
reader = dctx.stream_reader(buffer)

reader.read(3)
self.assertFalse(reader.closed)
self.assertFalse(buffer.closed)
reader.close()
self.assertTrue(reader.closed)
self.assertFalse(buffer.closed)

with self.assertRaisesRegex(ValueError, "stream is closed"):
reader.read(b"")

with self.assertRaisesRegex(ValueError, "stream is closed"):
with reader:
pass

# Context manager exit should not close stream.
buffer = io.BytesIO(foo)
reader = dctx.stream_reader(buffer)

with reader:
reader.read(3)

self.assertTrue(reader.closed)
self.assertFalse(buffer.closed)

def test_close_closefd_true(self):
foo = zstd.ZstdCompressor().compress(b"foo" * 1024)

buffer = io.BytesIO(foo)
dctx = zstd.ZstdDecompressor()
reader = dctx.stream_reader(buffer, closefd=True)

reader.read(3)
self.assertFalse(reader.closed)
self.assertFalse(buffer.closed)
reader.close()
self.assertTrue(reader.closed)
self.assertTrue(buffer.closed)

with self.assertRaisesRegex(ValueError, "stream is closed"):
reader.read(b"")

with self.assertRaisesRegex(ValueError, "stream is closed"):
with reader:
pass

# Context manager exit should close stream.
buffer = io.BytesIO(foo)
reader = dctx.stream_reader(buffer, closefd=True)

with reader:
reader.read(3)

self.assertTrue(reader.closed)
self.assertTrue(buffer.closed)

def test_read_after_exit(self):
cctx = zstd.ZstdCompressor()
frame = cctx.compress(b"foo" * 60)
Expand Down Expand Up @@ -550,9 +612,9 @@ def test_read_after_error(self):
with reader:
reader.read(0)

with reader:
with self.assertRaisesRegex(ValueError, "stream is closed"):
reader.read(100)
with self.assertRaisesRegex(ValueError, "stream is closed"):
with reader:
pass

def test_partial_read(self):
# Inspired by https://github.com/indygreg/python-zstandard/issues/71.
Expand Down
2 changes: 2 additions & 0 deletions zstandard/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ class ZstdDecompressor(object):
source: Union[IO[bytes], ByteString],
read_size: int = 0,
read_across_frames: bool = False,
*,
closefd=False,
) -> ZstdDecompressionReader: ...
def decompressobj(self, write_size: int = 0) -> ZstdDecompressionObj: ...
def read_to_iter(
Expand Down
32 changes: 27 additions & 5 deletions zstandard/backend_cffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1940,11 +1940,20 @@ def flush(self, length=0):


class ZstdDecompressionReader(object):
def __init__(self, decompressor, source, read_size, read_across_frames):
def __init__(
self,
decompressor,
source,
read_size,
read_across_frames,
*,
closefd=False,
):
self._decompressor = decompressor
self._source = source
self._read_size = read_size
self._read_across_frames = bool(read_across_frames)
self._closefd = bool(closefd)
self._entered = False
self._closed = False
self._bytes_decompressed = 0
Expand All @@ -1958,15 +1967,20 @@ def __enter__(self):
if self._entered:
raise ValueError("cannot __enter__ multiple times")

if self._closed:
raise ValueError("stream is closed")

self._entered = True
return self

def __exit__(self, exc_type, exc_value, exc_tb):
self._entered = False
self._closed = True
self._source = None
self._decompressor = None

self.close()

self._source = None

return False

def readable(self):
Expand Down Expand Up @@ -1997,8 +2011,14 @@ def flush(self):
return None

def close(self):
if self._closed:
return None

self._closed = True
return None

f = getattr(self._source, "close", None)
if self._closefd and f:
f()

@property
def closed(self):
Expand Down Expand Up @@ -2481,10 +2501,12 @@ def stream_reader(
source,
read_size=DECOMPRESSION_RECOMMENDED_INPUT_SIZE,
read_across_frames=False,
*,
closefd=False,
):
self._ensure_dctx()
return ZstdDecompressionReader(
self, source, read_size, read_across_frames
self, source, read_size, read_across_frames, closefd=closefd
)

def decompressobj(self, write_size=DECOMPRESSION_RECOMMENDED_OUTPUT_SIZE):
Expand Down

0 comments on commit d5c6685

Please sign in to comment.