Skip to content

Commit

Permalink
Accept duck typed python streams. (#185)
Browse files Browse the repository at this point in the history
  • Loading branch information
graebm authored Oct 20, 2020
1 parent e67a5e2 commit 53d9cff
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 35 deletions.
20 changes: 10 additions & 10 deletions awscrt/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def port(self):
def request(self, request, on_response=None, on_body=None):
"""Create :class:`HttpClientStream` to carry out the request/response exchange.
NOTE: The stream sends no data until :meth:`HttpClientStream.activate()`
NOTE: The HTTP stream sends no data until :meth:`HttpClientStream.activate()`
is called. Call activate() when you're ready for callbacks and events to fire.
Args:
Expand All @@ -188,7 +188,7 @@ def request(self, request, on_response=None, on_body=None):
on_response: Optional callback invoked once main response headers are received.
The function should take the following arguments and return nothing:
* `http_stream` (:class:`HttpClientStream`): Stream carrying
* `http_stream` (:class:`HttpClientStream`): HTTP stream carrying
out this request/response exchange.
* `status_code` (int): Response status code.
Expand All @@ -198,21 +198,21 @@ def request(self, request, on_response=None, on_body=None):
* `**kwargs` (dict): Forward compatibility kwargs.
An exception raise by this function will cause the stream to end in error.
An exception raise by this function will cause the HTTP stream to end in error.
This callback is always invoked on the connection's event-loop thread.
on_body: Optional callback invoked 0+ times as response body data is received.
The function should take the following arguments and return nothing:
* `http_stream` (:class:`HttpClientStream`): Stream carrying
* `http_stream` (:class:`HttpClientStream`): HTTP stream carrying
out this request/response exchange.
* `chunk` (buffer): Response body data (not necessarily
a whole "chunk" of chunked encoding).
* `**kwargs` (dict): Forward-compatibility kwargs.
An exception raise by this function will cause the stream to end in error.
An exception raise by this function will cause the HTTP stream to end in error.
This callback is always invoked on the connection's event-loop thread.
Returns:
Expand Down Expand Up @@ -245,11 +245,11 @@ def _on_body(self, chunk):


class HttpClientStream(HttpStreamBase):
"""Stream that sends a request and receives a response.
"""HTTP stream that sends a request and receives a response.
Create an HttpClientStream with :meth:`HttpClientConnection.request()`.
NOTE: The stream sends no data until :meth:`HttpClientStream.activate()`
NOTE: The HTTP stream sends no data until :meth:`HttpClientStream.activate()`
is called. Call activate() when you're ready for callbacks and events to fire.
Attributes:
Expand Down Expand Up @@ -288,7 +288,7 @@ def response_status_code(self):
def activate(self):
"""Begin sending the request.
The stream does nothing until this is called. Call activate() when you
The HTTP stream does nothing until this is called. Call activate() when you
are ready for its callbacks and events to fire.
"""
_awscrt.http_client_stream_activate(self)
Expand Down Expand Up @@ -332,7 +332,7 @@ def headers(self):

@property
def body_stream(self):
"""InputStream: Stream of outgoing body."""
"""InputStream: Binary stream of outgoing body."""
return _awscrt.http_message_get_body_stream(self._binding)

@body_stream.setter
Expand All @@ -352,7 +352,7 @@ class HttpRequest(HttpMessageBase):
path (str): HTTP path-and-query value. Default value is "/".
headers (Optional[HttpHeaders]): Optional headers. If None specified,
an empty :class:`HttpHeaders` is created.
body_string(Optional[Union[InputStream, io.IOBase]]): Optional body as stream.
body_stream(Optional[Union[InputStream, io.IOBase]]): Optional body as binary stream.
"""

__slots__ = ()
Expand Down
48 changes: 36 additions & 12 deletions awscrt/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import _awscrt
from awscrt import NativeResource
from enum import IntEnum
import io
import threading


Expand Down Expand Up @@ -490,28 +489,55 @@ def is_alpn_available():


class InputStream(NativeResource):
"""InputStream allows `awscrt` native code to read from Python I/O classes.
"""InputStream allows `awscrt` native code to read from Python binary I/O classes.
Args:
stream (io.IOBase): Python I/O stream to wrap.
stream (io.IOBase): Python binary I/O stream to wrap.
"""
__slots__ = ()
__slots__ = ('_stream')
# TODO: Implement IOBase interface so Python can read from this class as well.

def __init__(self, stream):
assert isinstance(stream, io.IOBase)
# duck-type instead of checking inheritance from IOBase.
# At the least, stream must have read()
if not callable(getattr(stream, 'read', None)):
raise TypeError('I/O stream type expected')
assert not isinstance(stream, InputStream)

super().__init__()
self._binding = _awscrt.input_stream_new(stream)
self._stream = stream
self._binding = _awscrt.input_stream_new(self)

def _read_into_memoryview(self, m):
# Read into memoryview m.
# Return number of bytes read, or None if no data available.
try:
# prefer the most efficient read methods,
if hasattr(self._stream, 'readinto1'):
return self._stream.readinto1(m)
if hasattr(self._stream, 'readinto'):
return self._stream.readinto(m)

if hasattr(self._stream, 'read1'):
data = self._stream.read1(len(m))
else:
data = self._stream.read(len(m))
n = len(data)
m[:n] = data
return n
except BlockingIOError:
return None

def _seek(self, offset, whence):
return self._stream.seek(offset, whence)

@classmethod
def wrap(cls, stream, allow_none=False):
"""
Given some stream type, returns an :class:`InputStream`.
Args:
stream (Union[io.IOBase, InputStream, None]): I/O stream to wrap.
stream (Union[io.IOBase, InputStream, None]): Binary I/O stream to wrap.
allow_none (bool): Whether to allow `stream` to be None.
If False (default), and `stream` is None, an exception is raised.
Expand All @@ -520,10 +546,8 @@ def wrap(cls, stream, allow_none=False):
Otherwise, an :class:`InputStream` which wraps the `stream` is returned.
If `allow_none` is True, and `stream` is None, then None is returned.
"""
if isinstance(stream, InputStream):
return stream
if isinstance(stream, io.IOBase):
return cls(stream)
if stream is None and allow_none:
return None
raise TypeError('I/O stream type expected')
if isinstance(stream, InputStream):
return stream
return cls(stream)
29 changes: 17 additions & 12 deletions source/io.c
Original file line number Diff line number Diff line change
Expand Up @@ -627,13 +627,13 @@ struct aws_input_stream_py_impl {

bool is_end_of_stream;

/* Dependencies that must outlive this */
PyObject *io;
/* Weak reference proxy to python self. */
PyObject *self_proxy;
};

static void s_aws_input_stream_py_destroy(struct aws_input_stream *stream) {
struct aws_input_stream_py_impl *impl = stream->impl;
Py_DECREF(impl->io);
Py_XDECREF(impl->self_proxy);
aws_mem_release(stream->allocator, stream);
}

Expand All @@ -653,7 +653,7 @@ static int s_aws_input_stream_py_seek(
return AWS_OP_ERR; /* Python has shut down. Nothing matters anymore, but don't crash */
}

method_result = PyObject_CallMethod(impl->io, "seek", "(li)", offset, basis);
method_result = PyObject_CallMethod(impl->self_proxy, "_seek", "(li)", offset, basis);
if (!method_result) {
aws_result = aws_py_raise_error();
goto done;
Expand Down Expand Up @@ -689,7 +689,7 @@ int s_aws_input_stream_py_read(struct aws_input_stream *stream, struct aws_byte_
goto done;
}

method_result = PyObject_CallMethod(impl->io, "readinto", "(O)", memory_view);
method_result = PyObject_CallMethod(impl->self_proxy, "_read_into_memoryview", "(O)", memory_view);
if (!method_result) {
aws_result = aws_py_raise_error();
goto done;
Expand Down Expand Up @@ -745,9 +745,9 @@ static struct aws_input_stream_vtable s_aws_input_stream_py_vtable = {
.destroy = s_aws_input_stream_py_destroy,
};

static struct aws_input_stream *aws_input_stream_new_from_py(PyObject *io) {
static struct aws_input_stream *aws_input_stream_new_from_py(PyObject *py_self) {

if (!io || (io == Py_None)) {
if (!py_self || (py_self == Py_None)) {
aws_raise_error(AWS_ERROR_INVALID_ARGUMENT);
return NULL;
}
Expand All @@ -761,10 +761,15 @@ static struct aws_input_stream *aws_input_stream_new_from_py(PyObject *io) {
impl->base.allocator = alloc;
impl->base.vtable = &s_aws_input_stream_py_vtable;
impl->base.impl = impl;
impl->io = io;
Py_INCREF(impl->io);
impl->self_proxy = PyWeakref_NewProxy(py_self, NULL);
if (!impl->self_proxy) {
goto error;
}

return &impl->base;
error:
aws_input_stream_destroy(&impl->base);
return NULL;
}

/**
Expand All @@ -783,12 +788,12 @@ static void s_input_stream_capsule_destructor(PyObject *py_capsule) {
PyObject *aws_py_input_stream_new(PyObject *self, PyObject *args) {
(void)self;

PyObject *py_io;
if (!PyArg_ParseTuple(args, "O", &py_io)) {
PyObject *py_self;
if (!PyArg_ParseTuple(args, "O", &py_self)) {
return NULL;
}

struct aws_input_stream *stream = aws_input_stream_new_from_py(py_io);
struct aws_input_stream *stream = aws_input_stream_new_from_py(py_self);
if (!stream) {
return PyErr_AwsLastError();
}
Expand Down
59 changes: 58 additions & 1 deletion test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# SPDX-License-Identifier: Apache-2.0.

from __future__ import absolute_import
from awscrt.io import ClientBootstrap, ClientTlsContext, DefaultHostResolver, EventLoopGroup, TlsConnectionOptions, TlsContextOptions
from awscrt.io import ClientBootstrap, ClientTlsContext, DefaultHostResolver, EventLoopGroup, InputStream, TlsConnectionOptions, TlsContextOptions
from test import NativeResourceTest, TIMEOUT
import io
import unittest


Expand Down Expand Up @@ -84,5 +85,61 @@ def test_server_name(self):
conn_opt.set_server_name('localhost')


class MockPythonStream:
"""For testing duck-typed stream classes.
Doesn't inherit from io.IOBase. Doesn't implement readinto()"""

def __init__(self, src_data):
self.data = bytes(src_data)
self.len = len(src_data)
self.pos = 0

def seek(self, where):
self.pos = where

def tell(self):
return self.pos

def read(self, amount=None):
if amount is None:
amount = self.len - self.pos
else:
amount = min(amount, self.len - self.pos)
prev_pos = self.pos
self.pos += amount
return self.data[prev_pos: self.pos]


class InputStreamTest(NativeResourceTest):
def _test(self, python_stream, expected):
input_stream = InputStream(python_stream)
result = bytearray()
fixed_mv_len = 4
fixed_mv = memoryview(bytearray(fixed_mv_len))
while True:
read_len = input_stream._read_into_memoryview(fixed_mv)
if read_len is None:
continue
if read_len == 0:
break
if read_len > 0:
self.assertLessEqual(read_len, fixed_mv_len)
result += fixed_mv[0:read_len]

self.assertEqual(expected, result)

def test_read_official_io(self):
# Read from a class defined in the io module
src_data = b'a long string here'
python_stream = io.BytesIO(src_data)
self._test(python_stream, src_data)

def test_read_duck_typed_io(self):
# Read from a class defined in the io module
src_data = b'a man a can a planal canada'
python_stream = MockPythonStream(src_data)
self._test(python_stream, src_data)


if __name__ == '__main__':
unittest.main()

0 comments on commit 53d9cff

Please sign in to comment.