Skip to content

Commit

Permalink
Signals on_headers_received and on_content_received aio-libs#2313
Browse files Browse the repository at this point in the history
  • Loading branch information
pfreixes committed Oct 25, 2017
1 parent 3115fac commit 5ba42d9
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 15 deletions.
26 changes: 23 additions & 3 deletions aiohttp/_http_parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ cdef class HttpParser:
object _last_error
bint _auto_decompress

object _on_headers_received
object _on_content_received
object _trace_context

Py_buffer py_buf

def __cinit__(self):
Expand All @@ -82,7 +86,9 @@ cdef class HttpParser:
object protocol, object loop, object timer=None,
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
response_with_body=True, auto_decompress=True):
response_with_body=True, auto_decompress=True,
on_headers_received=None, on_content_received=None,
trace_context=None):
cparser.http_parser_init(self._cparser, mode)
self._cparser.data = <void*>self
self._cparser.content_length = 0
Expand Down Expand Up @@ -122,6 +128,10 @@ cdef class HttpParser:
self._csettings.on_chunk_header = cb_on_chunk_header
self._csettings.on_chunk_complete = cb_on_chunk_complete

self._on_headers_received = on_headers_received
self._on_content_received = on_content_received
self._trace_context = trace_context

self._last_error = None

cdef _process_header(self):
Expand Down Expand Up @@ -215,10 +225,16 @@ cdef class HttpParser:

self._messages.append((msg, payload))

if self._on_headers_received is not None:
self._on_headers_received.send(self._trace_context)

cdef _on_message_complete(self):
self._payload.feed_eof()
self._payload = None

if self._on_content_received is not None:
self._on_content_received.send(self._trace_context)

cdef _on_chunk_header(self):
self._payload.begin_http_chunk_receiving()

Expand Down Expand Up @@ -339,10 +355,14 @@ cdef class HttpResponseParserC(HttpParser):
size_t max_line_size=8190, size_t max_headers=32768,
size_t max_field_size=8190, payload_exception=None,
response_with_body=True, read_until_eof=False,
auto_decompress=True):
auto_decompress=True, on_headers_received=None,
on_content_received=None, trace_context=None):
self._init(cparser.HTTP_RESPONSE, protocol, loop, timer,
max_line_size, max_headers, max_field_size,
payload_exception, response_with_body, auto_decompress)
payload_exception, response_with_body, auto_decompress,
on_headers_received=on_headers_received,
on_content_received=on_content_received,
trace_context=trace_context)

cdef object _on_status_complete(self):
if self._buf:
Expand Down
7 changes: 2 additions & 5 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,6 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
self._on_request_start = Signal()
self._on_request_end = Signal()
self._on_request_exception = Signal()

self._on_request_queued_start = FuncSignal()
self._on_request_queued_start = FuncSignal()
self._on_request_createconn_start = FuncSignal()
self._on_request_createconn_end = FuncSignal()
self._on_request_redirect = FuncSignal()

self._on_request_headers_sent = FuncSignal()
Expand Down Expand Up @@ -297,6 +292,8 @@ def _request(self, method, url, *,
ssl_context=ssl_context, proxy_headers=proxy_headers,
on_headers_sent=self.on_request_headers_sent,
on_content_sent=self.on_request_content_sent,
on_headers_received=self.on_request_headers_received,
on_content_received=self.on_request_content_received,
trace_context=trace_context)

# connection timeout
Expand Down
10 changes: 8 additions & 2 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,21 @@ def set_response_params(self, *, timer=None,
skip_payload=False,
skip_status_codes=(),
read_until_eof=False,
auto_decompress=True):
auto_decompress=True,
on_headers_received=None,
on_content_received=None,
trace_context=None):
self._skip_payload = skip_payload
self._skip_status_codes = skip_status_codes
self._read_until_eof = read_until_eof
self._parser = HttpResponseParser(
self, self._loop, timer=timer,
payload_exception=ClientPayloadError,
read_until_eof=read_until_eof,
auto_decompress=auto_decompress)
auto_decompress=auto_decompress,
on_headers_received=on_headers_received,
on_content_received=on_content_received,
trace_context=trace_context)

if self._tail:
data, self._tail = self._tail, b''
Expand Down
20 changes: 16 additions & 4 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def __init__(self, method, url, *,
timer=None, session=None, auto_decompress=True,
verify_ssl=None, fingerprint=None, ssl_context=None,
proxy_headers=None, on_headers_sent=None,
on_content_sent=None, trace_context=None):
on_content_sent=None, on_headers_received=None,
on_content_received=None, trace_context=None):

if verify_ssl is False and ssl_context is not None:
raise ValueError(
Expand Down Expand Up @@ -121,6 +122,8 @@ def __init__(self, method, url, *,

self._on_headers_sent = on_headers_sent
self._on_content_sent = on_content_sent
self._on_headers_received = on_headers_received
self._on_content_received = on_content_received
self._trace_context = trace_context

if loop.get_debug():
Expand Down Expand Up @@ -482,7 +485,10 @@ def send(self, conn):
writer=self._writer, continue100=self._continue, timer=self._timer,
request_info=self.request_info,
auto_decompress=self._auto_decompress,
session=self._session, trace_context=self._trace_context
session=self._session,
on_headers_received=self._on_headers_received,
on_content_received=self._on_content_received,
trace_context=self._trace_context
)

self.response._post_init(self.loop, self._session)
Expand Down Expand Up @@ -527,7 +533,8 @@ class ClientResponse(HeadersMixin):
def __init__(self, method, url, *,
writer=None, continue100=None, timer=None,
request_info=None, auto_decompress=True,
session=None, trace_context=None):
session=None, on_headers_received=None,
on_content_received=None, trace_context=None):
assert isinstance(url, URL)

self.method = method
Expand All @@ -543,6 +550,9 @@ def __init__(self, method, url, *,
self._request_info = request_info
self._timer = timer if timer is not None else TimerNoop()
self._auto_decompress = auto_decompress
self._on_headers_received = on_headers_received
self._on_content_received = on_content_received
self._trace_context = trace_context

@property
def url(self):
Expand Down Expand Up @@ -629,7 +639,9 @@ def start(self, connection, read_until_eof=False):
skip_payload=self.method.lower() == 'head',
skip_status_codes=(204, 304),
read_until_eof=read_until_eof,
auto_decompress=self._auto_decompress)
auto_decompress=self._auto_decompress,
on_headers_received=self._on_headers_received,
on_content_received=self._on_content_received)

with self._timer:
while True:
Expand Down
14 changes: 13 additions & 1 deletion aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ def __init__(self, protocol=None, loop=None,
timer=None, code=None, method=None, readall=False,
payload_exception=None,
response_with_body=True, read_until_eof=False,
auto_decompress=True):
auto_decompress=True,
on_headers_received=None, on_content_received=None,
trace_context=None):
self.protocol = protocol
self.loop = loop
self.max_line_size = max_line_size
Expand All @@ -88,6 +90,10 @@ def __init__(self, protocol=None, loop=None,
self._payload_parser = None
self._auto_decompress = auto_decompress

self._on_headers_received = on_headers_received
self._on_content_received = on_content_received
self._trace_context = trace_context

def feed_eof(self):
if self._payload_parser is not None:
self._payload_parser.feed_eof()
Expand Down Expand Up @@ -143,6 +149,9 @@ def feed_data(self, data,
finally:
self._lines.clear()

if self._on_headers_received is not None:
self._on_headers_received.send(self._trace_context)

# payload length
length = msg.headers.get(CONTENT_LENGTH)
if length is not None:
Expand Down Expand Up @@ -240,6 +249,9 @@ def feed_data(self, data,
else:
data = EMPTY

if self._on_headers_received is not None:
self._on_content_received.send(self._trace_context)

return messages, self._upgraded, data

def parse_headers(self, lines):
Expand Down
8 changes: 8 additions & 0 deletions tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,10 +596,18 @@ def __init__(self, *args, **kwargs):
super(MyClientRequest, self).__init__(*args, **kwargs)
MyClientRequest.on_headers_sent = self._on_headers_sent
MyClientRequest.on_content_sent = self._on_content_sent
MyClientRequest.on_headers_received = self._on_headers_received
MyClientRequest.on_content_received = self._on_content_received
MyClientRequest.trace_context = self._trace_context

trace_context = mock.Mock()

session = aiohttp.ClientSession(loop=loop, request_class=MyClientRequest)
yield from session.get('http://example.com', trace_context=trace_context)
assert MyClientRequest.on_headers_sent == session.on_request_headers_sent
assert MyClientRequest.on_content_sent == session.on_request_content_sent
assert MyClientRequest.on_headers_received ==\
session.on_request_headers_received
assert MyClientRequest.on_content_received ==\
session.on_request_content_received
assert MyClientRequest.trace_context == trace_context
23 changes: 23 additions & 0 deletions tests/test_http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,29 @@ def test_url_parse_non_strict_mode(parser):
assert payload.is_eof()


@pytest.mark.parametrize("parser_cls", RESPONSE_PARSERS)
def test_tracing_response(protocol, loop, parser_cls):
trace_context = mock.Mock()
on_headers_received = mock.Mock()
on_content_received = mock.Mock()
parser = parser_cls(
protocol,
loop,
on_headers_received=on_headers_received,
on_content_received=on_content_received,
trace_context=trace_context
)

headers = b'HTTP/1.1 200 Ok\r\nContent-Length: 4\r\n\r\n'
parser.feed_data(headers)
on_headers_received.send.assert_called_with(trace_context)
assert not on_content_received.called

body = b'body'
parser.feed_data(body)
on_content_received.send.assert_called_with(trace_context)


class TestParsePayload(unittest.TestCase):

def setUp(self):
Expand Down

0 comments on commit 5ba42d9

Please sign in to comment.