From 53d9cff027f88426f3e5aef65e90f162471453b4 Mon Sep 17 00:00:00 2001 From: Michael Graeb Date: Tue, 20 Oct 2020 09:43:52 -0700 Subject: [PATCH] Accept duck typed python streams. (#185) --- awscrt/http.py | 20 ++++++++--------- awscrt/io.py | 48 ++++++++++++++++++++++++++++++---------- source/io.c | 29 ++++++++++++++---------- test/test_io.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++- 4 files changed, 121 insertions(+), 35 deletions(-) diff --git a/awscrt/http.py b/awscrt/http.py index b24a2a3d4..da87617f2 100644 --- a/awscrt/http.py +++ b/awscrt/http.py @@ -179,7 +179,7 @@ def port(self): def request(self, request, on_response=None, on_body=None): """Create :class:`HttpClientStream` to carry out the request/response exchange. - NOTE: The stream sends no data until :meth:`HttpClientStream.activate()` + NOTE: The HTTP stream sends no data until :meth:`HttpClientStream.activate()` is called. Call activate() when you're ready for callbacks and events to fire. Args: @@ -188,7 +188,7 @@ def request(self, request, on_response=None, on_body=None): on_response: Optional callback invoked once main response headers are received. The function should take the following arguments and return nothing: - * `http_stream` (:class:`HttpClientStream`): Stream carrying + * `http_stream` (:class:`HttpClientStream`): HTTP stream carrying out this request/response exchange. * `status_code` (int): Response status code. @@ -198,13 +198,13 @@ def request(self, request, on_response=None, on_body=None): * `**kwargs` (dict): Forward compatibility kwargs. - An exception raise by this function will cause the stream to end in error. + An exception raise by this function will cause the HTTP stream to end in error. This callback is always invoked on the connection's event-loop thread. on_body: Optional callback invoked 0+ times as response body data is received. The function should take the following arguments and return nothing: - * `http_stream` (:class:`HttpClientStream`): Stream carrying + * `http_stream` (:class:`HttpClientStream`): HTTP stream carrying out this request/response exchange. * `chunk` (buffer): Response body data (not necessarily @@ -212,7 +212,7 @@ def request(self, request, on_response=None, on_body=None): * `**kwargs` (dict): Forward-compatibility kwargs. - An exception raise by this function will cause the stream to end in error. + An exception raise by this function will cause the HTTP stream to end in error. This callback is always invoked on the connection's event-loop thread. Returns: @@ -245,11 +245,11 @@ def _on_body(self, chunk): class HttpClientStream(HttpStreamBase): - """Stream that sends a request and receives a response. + """HTTP stream that sends a request and receives a response. Create an HttpClientStream with :meth:`HttpClientConnection.request()`. - NOTE: The stream sends no data until :meth:`HttpClientStream.activate()` + NOTE: The HTTP stream sends no data until :meth:`HttpClientStream.activate()` is called. Call activate() when you're ready for callbacks and events to fire. Attributes: @@ -288,7 +288,7 @@ def response_status_code(self): def activate(self): """Begin sending the request. - The stream does nothing until this is called. Call activate() when you + The HTTP stream does nothing until this is called. Call activate() when you are ready for its callbacks and events to fire. """ _awscrt.http_client_stream_activate(self) @@ -332,7 +332,7 @@ def headers(self): @property def body_stream(self): - """InputStream: Stream of outgoing body.""" + """InputStream: Binary stream of outgoing body.""" return _awscrt.http_message_get_body_stream(self._binding) @body_stream.setter @@ -352,7 +352,7 @@ class HttpRequest(HttpMessageBase): path (str): HTTP path-and-query value. Default value is "/". headers (Optional[HttpHeaders]): Optional headers. If None specified, an empty :class:`HttpHeaders` is created. - body_string(Optional[Union[InputStream, io.IOBase]]): Optional body as stream. + body_stream(Optional[Union[InputStream, io.IOBase]]): Optional body as binary stream. """ __slots__ = () diff --git a/awscrt/io.py b/awscrt/io.py index ba970a896..32da474b4 100644 --- a/awscrt/io.py +++ b/awscrt/io.py @@ -12,7 +12,6 @@ import _awscrt from awscrt import NativeResource from enum import IntEnum -import io import threading @@ -490,20 +489,47 @@ def is_alpn_available(): class InputStream(NativeResource): - """InputStream allows `awscrt` native code to read from Python I/O classes. + """InputStream allows `awscrt` native code to read from Python binary I/O classes. Args: - stream (io.IOBase): Python I/O stream to wrap. + stream (io.IOBase): Python binary I/O stream to wrap. """ - __slots__ = () + __slots__ = ('_stream') # TODO: Implement IOBase interface so Python can read from this class as well. def __init__(self, stream): - assert isinstance(stream, io.IOBase) + # duck-type instead of checking inheritance from IOBase. + # At the least, stream must have read() + if not callable(getattr(stream, 'read', None)): + raise TypeError('I/O stream type expected') assert not isinstance(stream, InputStream) super().__init__() - self._binding = _awscrt.input_stream_new(stream) + self._stream = stream + self._binding = _awscrt.input_stream_new(self) + + def _read_into_memoryview(self, m): + # Read into memoryview m. + # Return number of bytes read, or None if no data available. + try: + # prefer the most efficient read methods, + if hasattr(self._stream, 'readinto1'): + return self._stream.readinto1(m) + if hasattr(self._stream, 'readinto'): + return self._stream.readinto(m) + + if hasattr(self._stream, 'read1'): + data = self._stream.read1(len(m)) + else: + data = self._stream.read(len(m)) + n = len(data) + m[:n] = data + return n + except BlockingIOError: + return None + + def _seek(self, offset, whence): + return self._stream.seek(offset, whence) @classmethod def wrap(cls, stream, allow_none=False): @@ -511,7 +537,7 @@ def wrap(cls, stream, allow_none=False): Given some stream type, returns an :class:`InputStream`. Args: - stream (Union[io.IOBase, InputStream, None]): I/O stream to wrap. + stream (Union[io.IOBase, InputStream, None]): Binary I/O stream to wrap. allow_none (bool): Whether to allow `stream` to be None. If False (default), and `stream` is None, an exception is raised. @@ -520,10 +546,8 @@ def wrap(cls, stream, allow_none=False): Otherwise, an :class:`InputStream` which wraps the `stream` is returned. If `allow_none` is True, and `stream` is None, then None is returned. """ - if isinstance(stream, InputStream): - return stream - if isinstance(stream, io.IOBase): - return cls(stream) if stream is None and allow_none: return None - raise TypeError('I/O stream type expected') + if isinstance(stream, InputStream): + return stream + return cls(stream) diff --git a/source/io.c b/source/io.c index 7042262e9..44075a55e 100644 --- a/source/io.c +++ b/source/io.c @@ -627,13 +627,13 @@ struct aws_input_stream_py_impl { bool is_end_of_stream; - /* Dependencies that must outlive this */ - PyObject *io; + /* Weak reference proxy to python self. */ + PyObject *self_proxy; }; static void s_aws_input_stream_py_destroy(struct aws_input_stream *stream) { struct aws_input_stream_py_impl *impl = stream->impl; - Py_DECREF(impl->io); + Py_XDECREF(impl->self_proxy); aws_mem_release(stream->allocator, stream); } @@ -653,7 +653,7 @@ static int s_aws_input_stream_py_seek( return AWS_OP_ERR; /* Python has shut down. Nothing matters anymore, but don't crash */ } - method_result = PyObject_CallMethod(impl->io, "seek", "(li)", offset, basis); + method_result = PyObject_CallMethod(impl->self_proxy, "_seek", "(li)", offset, basis); if (!method_result) { aws_result = aws_py_raise_error(); goto done; @@ -689,7 +689,7 @@ int s_aws_input_stream_py_read(struct aws_input_stream *stream, struct aws_byte_ goto done; } - method_result = PyObject_CallMethod(impl->io, "readinto", "(O)", memory_view); + method_result = PyObject_CallMethod(impl->self_proxy, "_read_into_memoryview", "(O)", memory_view); if (!method_result) { aws_result = aws_py_raise_error(); goto done; @@ -745,9 +745,9 @@ static struct aws_input_stream_vtable s_aws_input_stream_py_vtable = { .destroy = s_aws_input_stream_py_destroy, }; -static struct aws_input_stream *aws_input_stream_new_from_py(PyObject *io) { +static struct aws_input_stream *aws_input_stream_new_from_py(PyObject *py_self) { - if (!io || (io == Py_None)) { + if (!py_self || (py_self == Py_None)) { aws_raise_error(AWS_ERROR_INVALID_ARGUMENT); return NULL; } @@ -761,10 +761,15 @@ static struct aws_input_stream *aws_input_stream_new_from_py(PyObject *io) { impl->base.allocator = alloc; impl->base.vtable = &s_aws_input_stream_py_vtable; impl->base.impl = impl; - impl->io = io; - Py_INCREF(impl->io); + impl->self_proxy = PyWeakref_NewProxy(py_self, NULL); + if (!impl->self_proxy) { + goto error; + } return &impl->base; +error: + aws_input_stream_destroy(&impl->base); + return NULL; } /** @@ -783,12 +788,12 @@ static void s_input_stream_capsule_destructor(PyObject *py_capsule) { PyObject *aws_py_input_stream_new(PyObject *self, PyObject *args) { (void)self; - PyObject *py_io; - if (!PyArg_ParseTuple(args, "O", &py_io)) { + PyObject *py_self; + if (!PyArg_ParseTuple(args, "O", &py_self)) { return NULL; } - struct aws_input_stream *stream = aws_input_stream_new_from_py(py_io); + struct aws_input_stream *stream = aws_input_stream_new_from_py(py_self); if (!stream) { return PyErr_AwsLastError(); } diff --git a/test/test_io.py b/test/test_io.py index a24fc8921..6ddf087fb 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -2,8 +2,9 @@ # SPDX-License-Identifier: Apache-2.0. from __future__ import absolute_import -from awscrt.io import ClientBootstrap, ClientTlsContext, DefaultHostResolver, EventLoopGroup, TlsConnectionOptions, TlsContextOptions +from awscrt.io import ClientBootstrap, ClientTlsContext, DefaultHostResolver, EventLoopGroup, InputStream, TlsConnectionOptions, TlsContextOptions from test import NativeResourceTest, TIMEOUT +import io import unittest @@ -84,5 +85,61 @@ def test_server_name(self): conn_opt.set_server_name('localhost') +class MockPythonStream: + """For testing duck-typed stream classes. + Doesn't inherit from io.IOBase. Doesn't implement readinto()""" + + def __init__(self, src_data): + self.data = bytes(src_data) + self.len = len(src_data) + self.pos = 0 + + def seek(self, where): + self.pos = where + + def tell(self): + return self.pos + + def read(self, amount=None): + if amount is None: + amount = self.len - self.pos + else: + amount = min(amount, self.len - self.pos) + prev_pos = self.pos + self.pos += amount + return self.data[prev_pos: self.pos] + + +class InputStreamTest(NativeResourceTest): + def _test(self, python_stream, expected): + input_stream = InputStream(python_stream) + result = bytearray() + fixed_mv_len = 4 + fixed_mv = memoryview(bytearray(fixed_mv_len)) + while True: + read_len = input_stream._read_into_memoryview(fixed_mv) + if read_len is None: + continue + if read_len == 0: + break + if read_len > 0: + self.assertLessEqual(read_len, fixed_mv_len) + result += fixed_mv[0:read_len] + + self.assertEqual(expected, result) + + def test_read_official_io(self): + # Read from a class defined in the io module + src_data = b'a long string here' + python_stream = io.BytesIO(src_data) + self._test(python_stream, src_data) + + def test_read_duck_typed_io(self): + # Read from a class defined in the io module + src_data = b'a man a can a planal canada' + python_stream = MockPythonStream(src_data) + self._test(python_stream, src_data) + + if __name__ == '__main__': unittest.main()