Skip to content

Commit

Permalink
decompressionreader: implement readinto()
Browse files Browse the repository at this point in the history
This is part of the io.BufferedIOBase interface. It allows writing
output into a pre-allocated buffer.
  • Loading branch information
indygreg committed Feb 18, 2019
1 parent 80204ca commit 54f4d2a
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 0 deletions.
2 changes: 2 additions & 0 deletions NEWS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ Bug Fixes
Changes
-------

* ``ZstdDecompressionReader`` has gained a ``readinto(b)`` method for
reading decompressed output into an existing buffer.
* ``ZstdDecompressor.stream_reader()`` now accepts a ``read_across_frames``
argument to control behavior when the input data has multiple zstd
*frames*. When ``False`` (the default for backwards compatibility), a
Expand Down
73 changes: 73 additions & 0 deletions c-ext/decompressionreader.c
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,78 @@ static PyObject* reader_read(ZstdDecompressionReader* self, PyObject* args, PyOb
return result;
}


/*
 * readinto(b): decompress data directly into the caller-supplied buffer.
 *
 * `args` must hold a single writable buffer object (parsed with "w*").
 * Returns a Python int with the number of bytes written into that buffer,
 * or NULL with an exception set when the stream is closed, the arguments
 * are invalid, the buffer is not a C-contiguous buffer of at most one
 * dimension, or a helper reports failure.
 */
static PyObject* reader_readinto(ZstdDecompressionReader* self, PyObject* args) {
    Py_buffer dest;
    ZSTD_outBuffer output;
    PyObject* result = NULL;
    int status;

    if (self->closed) {
        PyErr_SetString(PyExc_ValueError, "stream is closed");
        return NULL;
    }

    /* Checked before the arguments are parsed: no buffer is held yet, so a
       plain return is safe here. */
    if (self->finishedOutput) {
        return PyLong_FromLong(0);
    }

    if (!PyArg_ParseTuple(args, "w*:readinto", &dest)) {
        return NULL;
    }

    /* From this point `dest` is held; every exit must pass through the
       `finally` label so the buffer is released. */
    if (!PyBuffer_IsContiguous(&dest, 'C') || dest.ndim > 1) {
        PyErr_SetString(PyExc_ValueError,
            "destination buffer should be contiguous and have at most one dimension");
        goto finally;
    }

    output.dst = dest.buf;
    output.size = dest.len;
    output.pos = 0;

    /* Alternate between decompressing what is already buffered and pulling
       in more input, until a result is ready or the input runs out. */
    do {
        status = decompress_input(self, &output);

        if (-1 == status) {
            goto finally;
        }

        /* 1 means the helper wants the current output handed back now. */
        if (1 == status) {
            break;
        }

        /* Only 0 (keep going) is expected to remain. */
        assert(0 == status);

        status = read_input(self);

        if (-1 == status) {
            goto finally;
        }

        /* 0 and 1 are treated alike; anything else is unexpected. */
        assert(0 == status || 1 == status);
    } while (self->input.size);

    /* Reached either via the break above or because no input is left (EOF);
       both cases report whatever has been written so far. */
    self->bytesDecompressed += output.pos;
    result = PyLong_FromSize_t(output.pos);

finally:
    PyBuffer_Release(&dest);

    return result;
}

static PyObject* reader_readall(PyObject* self) {
PyErr_SetNone(PyExc_NotImplementedError);
return NULL;
Expand Down Expand Up @@ -426,6 +498,7 @@ static PyMethodDef reader_methods[] = {
PyDoc_STR("Returns True") },
{ "read", (PyCFunction)reader_read, METH_VARARGS | METH_KEYWORDS,
PyDoc_STR("read compressed data") },
{ "readinto", (PyCFunction)reader_readinto, METH_VARARGS, NULL },
{ "readall", (PyCFunction)reader_readall, METH_NOARGS, PyDoc_STR("Not implemented") },
{ "readline", (PyCFunction)reader_readline, METH_NOARGS, PyDoc_STR("Not implemented") },
{ "readlines", (PyCFunction)reader_readlines, METH_NOARGS, PyDoc_STR("Not implemented") },
Expand Down
32 changes: 32 additions & 0 deletions tests/test_decompressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,38 @@ def test_read_multiple_frames(self):
reader = dctx.stream_reader(source, read_across_frames=True)
self.assertEqual(reader.read(128), b'foobar')

def test_readinto(self):
    """Check ``readinto()`` against read-only, large, and small buffers."""
    frame = zstd.ZstdCompressor().compress(b'foo')

    dctx = zstd.ZstdDecompressor()

    # A read-only destination (bytes) must be rejected. The exception
    # type differs between backends, hence the broad assertion.
    reader = dctx.stream_reader(frame)
    with self.assertRaises(Exception):
        reader.readinto(b'foobar')

    # Destination with plenty of room: the whole payload arrives in one
    # call and a further call reports 0 without touching the data.
    dest = bytearray(1024)
    reader = dctx.stream_reader(frame)
    self.assertEqual(reader.readinto(dest), 3)
    self.assertEqual(dest[0:3], b'foo')
    self.assertEqual(reader.readinto(dest), 0)
    self.assertEqual(dest[0:3], b'foo')

    # Same destination size, but the reader is created with read_size=1.
    dest = bytearray(1024)
    reader = dctx.stream_reader(frame, read_size=1)
    self.assertEqual(reader.readinto(dest), 3)
    self.assertEqual(dest[0:3], b'foo')

    # Destination smaller than the payload: only as many bytes as fit
    # are written.
    dest = bytearray(2)
    reader = dctx.stream_reader(frame)
    self.assertEqual(reader.readinto(dest), 2)
    self.assertEqual(dest[:], b'fo')


@make_cffi
class TestDecompressor_decompressobj(unittest.TestCase):
Expand Down
29 changes: 29 additions & 0 deletions zstandard/cffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1743,6 +1743,35 @@ def read(self, size):
self._bytes_decompressed += out_buffer.pos
return ffi.buffer(out_buffer.dst, out_buffer.pos)[:]

def readinto(self, b):
    """Decompress into the existing buffer ``b``.

    Returns the number of bytes written into ``b``; returns 0 once the
    reader has already finished producing output. Raises ``ValueError``
    if the stream is closed.
    """
    if self._closed:
        raise ValueError('stream is closed')

    if self._finished_output:
        return 0

    # TODO use writable=True once we require CFFI >= 1.12.
    target = ffi.from_buffer(b)
    # Zero-length copy into ``b``: moves no data. Presumably present so
    # that a read-only ``b`` is rejected up front -- confirm against the
    # CFFI ``memmove`` documentation.
    ffi.memmove(b, b'', 0)

    out_buffer = ffi.new('ZSTD_outBuffer *')
    out_buffer.dst = target
    out_buffer.size = len(target)
    out_buffer.pos = 0

    # Always read and decompress at least once, then keep going until
    # the decompress step signals that output is ready or the input has
    # been exhausted.
    while True:
        self._read_input()
        if self._decompress_into_buffer(out_buffer) or self._finished_input:
            break

    self._bytes_decompressed += out_buffer.pos
    return out_buffer.pos

def seek(self, pos, whence=os.SEEK_SET):
if self._closed:
raise ValueError('stream is closed')
Expand Down

0 comments on commit 54f4d2a

Please sign in to comment.