From 071cf455c1a309b2481f8d977c03b0dda283a2d9 Mon Sep 17 00:00:00 2001 From: Pau Freixes Date: Wed, 31 Jan 2018 00:13:16 +0100 Subject: [PATCH] Use custom classes to pass client signals parameters #2686 This will allow aiohttp development add new parameters in the future without break the signals signature. --- CHANGES/2686.feature | 1 + aiohttp/connector.py | 14 +-- aiohttp/tracing.py | 163 ++++++++++++++++++++++++++--------- docs/client_advanced.rst | 10 +-- docs/tracing_reference.rst | 159 +++++++++++++++++++++++++++++----- tests/test_client_session.py | 35 ++++---- tests/test_connector.py | 39 ++++++--- 7 files changed, 318 insertions(+), 103 deletions(-) create mode 100644 CHANGES/2686.feature diff --git a/CHANGES/2686.feature b/CHANGES/2686.feature new file mode 100644 index 00000000000..eab2eb89b6f --- /dev/null +++ b/CHANGES/2686.feature @@ -0,0 +1 @@ +Use custom classes to pass client signals parameters diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 7eed3168eca..acb898f7cd1 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -660,14 +660,14 @@ async def _resolve_host(self, host, port, traces=None): if traces: for trace in traces: - await trace.send_dns_resolvehost_start() + await trace.send_dns_resolvehost_start(host) res = (await self._resolver.resolve( host, port, family=self._family)) if traces: for trace in traces: - await trace.send_dns_resolvehost_end() + await trace.send_dns_resolvehost_end(host) return res @@ -678,26 +678,26 @@ async def _resolve_host(self, host, port, traces=None): if traces: for trace in traces: - await trace.send_dns_cache_hit() + await trace.send_dns_cache_hit(host) return self._cached_hosts.next_addrs(key) if key in self._throttle_dns_events: if traces: for trace in traces: - await trace.send_dns_cache_hit() + await trace.send_dns_cache_hit(host) await self._throttle_dns_events[key].wait() else: if traces: for trace in traces: - await trace.send_dns_cache_miss() + await trace.send_dns_cache_miss(host) self._throttle_dns_events[key] = \ EventResultOrError(self._loop) try: if traces: for trace in traces: - await trace.send_dns_resolvehost_start() + await trace.send_dns_resolvehost_start(host) addrs = await \ asyncio.shield(self._resolver.resolve(host, @@ -706,7 +706,7 @@ async def _resolve_host(self, host, port, traces=None): loop=self._loop) if traces: for trace in traces: - await trace.send_dns_resolvehost_end() + await trace.send_dns_resolvehost_end(host) self._cached_hosts.add(key, addrs) self._throttle_dns_events[key].set() diff --git a/aiohttp/tracing.py b/aiohttp/tracing.py index cbfd1597352..f6397e66eb6 100644 --- a/aiohttp/tracing.py +++ b/aiohttp/tracing.py @@ -1,9 +1,21 @@ from types import SimpleNamespace +import attr +from multidict import CIMultiDict +from yarl import URL + +from .client_reqrep import ClientResponse from .signals import Signal -__all__ = ('TraceConfig',) +__all__ = ( + 'TraceConfig', 'TraceRequestStartParams', 'TraceRequestEndParams', + 'TraceRequestExceptionParams', 'TraceConnectionQueuedStartParams', + 'TraceConnectionQueuedEndParams', 'TraceConnectionCreateStartParams', + 'TraceConnectionCreateEndParams', 'TraceConnectionReuseconnParams', + 'TraceDnsResolveHostStartParams', 'TraceDnsResolveHostEndParams', + 'TraceDnsCacheHitParams', 'TraceDnsCacheMissParams' +) class TraceConfig: @@ -100,6 +112,90 @@ def on_dns_cache_miss(self): return self._on_dns_cache_miss +@attr.s(frozen=True, slots=True) +class TraceRequestStartParams: + """ Parameters sent by the `on_request_start` signal""" + method = attr.ib(type=str) + url = attr.ib(type=URL) + headers = attr.ib(type=CIMultiDict) + + +@attr.s(frozen=True, slots=True) +class TraceRequestEndParams: + """ Parameters sent by the `on_request_end` signal""" + method = attr.ib(type=str) + url = attr.ib(type=URL) + headers = attr.ib(type=CIMultiDict) + resp = attr.ib(type=ClientResponse) + + +@attr.s(frozen=True, slots=True) +class TraceRequestExceptionParams: + """ Parameters sent by the `on_request_exception` signal""" + method = attr.ib(type=str) + url = attr.ib(type=URL) + headers = attr.ib(type=CIMultiDict) + exception = attr.ib(type=Exception) + + +@attr.s(frozen=True, slots=True) +class TraceRequestRedirectParams: + """ Parameters sent by the `on_request_redirect` signal""" + method = attr.ib(type=str) + url = attr.ib(type=URL) + headers = attr.ib(type=CIMultiDict) + resp = attr.ib(type=ClientResponse) + + +@attr.s(frozen=True, slots=True) +class TraceConnectionQueuedStartParams: + """ Parameters sent by the `on_connection_queued_start` signal""" + + +@attr.s(frozen=True, slots=True) +class TraceConnectionQueuedEndParams: + """ Parameters sent by the `on_connection_queued_end` signal""" + + +@attr.s(frozen=True, slots=True) +class TraceConnectionCreateStartParams: + """ Parameters sent by the `on_connection_create_start` signal""" + + +@attr.s(frozen=True, slots=True) +class TraceConnectionCreateEndParams: + """ Parameters sent by the `on_connection_create_end` signal""" + + +@attr.s(frozen=True, slots=True) +class TraceConnectionReuseconnParams: + """ Parameters sent by the `on_connection_reuseconn` signal""" + + +@attr.s(frozen=True, slots=True) +class TraceDnsResolveHostStartParams: + """ Parameters sent by the `on_dns_resolvehost_start` signal""" + host = attr.ib(type=str) + + +@attr.s(frozen=True, slots=True) +class TraceDnsResolveHostEndParams: + """ Parameters sent by the `on_dns_resolvehost_end` signal""" + host = attr.ib(type=str) + + +@attr.s(frozen=True, slots=True) +class TraceDnsCacheHitParams: + """ Parameters sent by the `on_dns_cache_hit` signal""" + host = attr.ib(type=str) + + +@attr.s(frozen=True, slots=True) +class TraceDnsCacheMissParams: + """ Parameters sent by the `on_dns_cache_miss` signal""" + host = attr.ib(type=str) + + class Trace: """ Internal class used to keep together the main dependencies used at the moment of send a signal.""" @@ -109,106 +205,93 @@ def __init__(self, session, trace_config, trace_config_ctx): self._trace_config_ctx = trace_config_ctx self._session = session - async def send_request_start(self, *args, **kwargs): + async def send_request_start(self, method, url, headers): return await self._trace_config.on_request_start.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceRequestStartParams(method, url, headers) ) - async def send_request_end(self, *args, **kwargs): + async def send_request_end(self, method, url, headers, response): return await self._trace_config.on_request_end.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceRequestEndParams(method, url, headers, response) ) - async def send_request_exception(self, *args, **kwargs): + async def send_request_exception(self, method, url, headers, exception): return await self._trace_config.on_request_exception.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceRequestExceptionParams(method, url, headers, exception) ) - async def send_request_redirect(self, *args, **kwargs): + async def send_request_redirect(self, method, url, headers, response): return await self._trace_config._on_request_redirect.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceRequestRedirectParams(method, url, headers, response) ) - async def send_connection_queued_start(self, *args, **kwargs): + async def send_connection_queued_start(self): return await self._trace_config.on_connection_queued_start.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceConnectionQueuedStartParams() ) - async def send_connection_queued_end(self, *args, **kwargs): + async def send_connection_queued_end(self): return await self._trace_config.on_connection_queued_end.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceConnectionQueuedEndParams() ) - async def send_connection_create_start(self, *args, **kwargs): + async def send_connection_create_start(self): return await self._trace_config.on_connection_create_start.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceConnectionCreateStartParams() ) - async def send_connection_create_end(self, *args, **kwargs): + async def send_connection_create_end(self): return await self._trace_config.on_connection_create_end.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceConnectionCreateEndParams() ) - async def send_connection_reuseconn(self, *args, **kwargs): + async def send_connection_reuseconn(self): return await self._trace_config.on_connection_reuseconn.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceConnectionReuseconnParams() ) - async def send_dns_resolvehost_start(self, *args, **kwargs): + async def send_dns_resolvehost_start(self, host): return await self._trace_config.on_dns_resolvehost_start.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceDnsResolveHostStartParams(host) ) - async def send_dns_resolvehost_end(self, *args, **kwargs): + async def send_dns_resolvehost_end(self, host): return await self._trace_config.on_dns_resolvehost_end.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceDnsResolveHostEndParams(host) ) - async def send_dns_cache_hit(self, *args, **kwargs): + async def send_dns_cache_hit(self, host): return await self._trace_config.on_dns_cache_hit.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceDnsCacheHitParams(host) ) - async def send_dns_cache_miss(self, *args, **kwargs): + async def send_dns_cache_miss(self, host): return await self._trace_config.on_dns_cache_miss.send( self._session, self._trace_config_ctx, - *args, - **kwargs + TraceDnsCacheMissParams(host) ) diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 2c4b19624ee..135877252a8 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -225,10 +225,10 @@ disabled. The following snippet shows how the start and the end signals of a request flow can be followed:: async def on_request_start( - session, trace_config_ctx, method, host, port, headers): + session, trace_config_ctx, params): print("Starting request") - async def on_request_end(session, trace_config_ctx, resp): + async def on_request_end(session, trace_config_ctx, params): print("Ending request") trace_config = aiohttp.TraceConfig() @@ -259,10 +259,10 @@ share the state through to the different signals that belong to the same request and to the same :class:`TraceConfig` class, perhaps:: async def on_request_start( - session, trace_config_ctx, method, host, port, headers): + session, trace_config_ctx, params): trace_config_ctx.start = session.loop.time() - async def on_request_end(session, trace_config_ctx, resp): + async def on_request_end(session, trace_config_ctx, params): elapsed = session.loop.time() - trace_config_ctx.start print("Request took {}".format(elapsed)) @@ -280,7 +280,7 @@ factory. This param is useful to pass data that is only available at request time, perhaps:: async def on_request_start( - session, trace_config_ctx, method, host, port, headers): + session, trace_config_ctx, params): print(trace_config_ctx.trace_request_ctx) diff --git a/docs/tracing_reference.rst b/docs/tracing_reference.rst index 71689bde02f..af381155156 100644 --- a/docs/tracing_reference.rst +++ b/docs/tracing_reference.rst @@ -29,10 +29,10 @@ the request flow. .. attribute:: on_request_start Property that gives access to the signals that will be executed when a - request starts, based on the :class:`~signals.Signal` implementation. + request starts, based on the :class:`aiohttp.signals.Signal` implementation. The coroutines listening will receive as a param the ``session``, - ``trace_config_ctx``, ``method``, ``url`` and ``headers``. + ``trace_config_ctx`` and :class:`aiohttp.TraceRequestStartParams` params. .. versionadded:: 3.0 @@ -42,7 +42,7 @@ the request flow. redirect happens during a request flow. The coroutines that are listening will receive the ``session``, - ``trace_config_ctx``, ``method``, ``url``, ``headers`` and ``resp`` params. + ``trace_config_ctx`` and :class:`aiohttp.TraceRequestRedirectParams` params. .. versionadded:: 3.0 @@ -52,7 +52,7 @@ the request flow. request ends. The coroutines that are listening will receive the ``session``, - ``trace_config_ctx``, ``method``, ``url``, ``headers`` and ``resp`` params + ``trace_config_ctx`` and :class:`aiohttp.TraceRequestEndParams` params. .. versionadded:: 3.0 @@ -62,7 +62,7 @@ the request flow. request finishes with an exception. The coroutines listening will receive the ``session``, - ``trace_config_ctx``, ``method``, ``url``, ``headers`` and ``exception`` params. + ``trace_config_ctx`` and :class:`aiohttp.TraceRequestExceptionParams` params. .. versionadded:: 3.0 @@ -71,8 +71,8 @@ the request flow. Property that gives access to the signals that will be executed when a request has been queued waiting for an available connection. - The coroutines that are listening will receive the ``session`` and - ``trace_config_ctx`` params. + The coroutines that are listening will receive the ``session``, + ``trace_config_ctx`` and :class:`aiohttp.TraceConnectionQueuedStartParams` params. .. versionadded:: 3.0 @@ -81,8 +81,8 @@ the request flow. Property that gives access to the signals that will be executed when a request that was queued already has an available connection. - The coroutines that are listening will receive the ``session`` and - ``trace_config_ctx`` params. + The coroutines that are listening will receive the ``session``, + ``trace_config_ctx`` and :class:`aiohttp.TraceConnectionQueuedEndParams` params. .. versionadded:: 3.0 @@ -91,8 +91,8 @@ the request flow. Property that gives access to the signals that will be executed when a request creates a new connection. - The coroutines listening will receive the ``session`` and - ``trace_config_ctx`` params. + The coroutines listening will receive the ``session``, + ``trace_config_ctx`` and :class:`aiohttp.TraceConnectionCreateStartParams` params. .. versionadded:: 3.0 @@ -101,8 +101,8 @@ the request flow. Property that gives access to the signals that will be executed when a request that created a new connection finishes its creation. - The coroutines listening will receive the ``session`` and - ``trace_config_ctx`` params. + The coroutines listening will receive the ``session``, + ``trace_config_ctx`` and :class:`aiohttp.TraceConnectionCreateEndParams` params. .. versionadded:: 3.0 @@ -111,8 +111,8 @@ the request flow. Property that gives access to the signals that will be executed when a request reuses a connection. - The coroutines listening will receive the ``session`` and - ``trace_config_ctx`` params. + The coroutines listening will receive the ``session``, + ``trace_config_ctx`` and :class:`aiohttp.TraceConnectionReuseconnParams` params. .. versionadded:: 3.0 @@ -121,8 +121,8 @@ the request flow. Property that gives access to the signals that will be executed when a request starts to resolve the domain related with the request. - The coroutines listening will receive the ``session`` and - ``trace_config_ctx`` params. + The coroutines listening will receive the ``session``, + ``trace_config_ctx`` and :class:`aiohttp.TraceDnsResolveHostStartParams` params. .. versionadded:: 3.0 @@ -131,8 +131,8 @@ the request flow. Property that gives access to the signals that will be executed when a request finishes to resolve the domain related with the request. - The coroutines listening will receive the ``session`` and ``trace_config_ctx`` - params. + The coroutines listening will receive the ``session``, + ``trace_config_ctx`` and :class:`aiohttp.TraceDnsResolveHostEndParams` params. .. versionadded:: 3.0 @@ -142,8 +142,8 @@ the request flow. request was able to use a cached DNS resolution for the domain related with the request. - The coroutines listening will receive the ``session`` and - ``trace_config_ctx`` params. + The coroutines listening will receive the ``session``, + ``trace_config_ctx`` and :class:`aiohttp.TraceDnsCacheHitParams` params. .. versionadded:: 3.0 @@ -153,7 +153,120 @@ the request flow. request was not able to use a cached DNS resolution for the domain related with the request. - The coroutines listening will receive the ``session`` and - ``trace_config_ctx`` params. + The coroutines listening will receive the ``session``, + ``trace_config_ctx`` and :class:`aiohttp.TraceDnsCacheMissParams` params. .. versionadded:: 3.0 + +.. class:: TraceRequestStartParams + + .. attribute:: method + + Method that will be used to make the request. + + .. attribute:: url + + URL that will be used for the request. + + .. attribute:: headers + + Headers that will be used for the request, can be mutated. + +.. class:: TraceRequestEndParams + + .. attribute:: method + + Method used to make the request. + + .. attribute:: url + + URL used for the request. + + .. attribute:: headers + + Headers used for the request. + + .. attribute:: resp + + Response :class:`ClientReponse`. + + +.. class:: TraceRequestExceptionParams + + .. attribute:: method + + Method used to make the request. + + .. attribute:: url + + URL used for the request. + + .. attribute:: headers + + Headers used for the request. + + .. attribute:: exception + + Exception raised during the request. + +.. class:: TraceRequestRedirectParams + + .. attribute:: method + + Method used to get this redirect request. + + .. attribute:: url + + URL used for this redirect request. + + .. attribute:: headers + + Headers used for this redirect. + + .. attribute:: resp + + Response :class:`ClientReponse` got from the redirect. + +.. class:: TraceConnectionQueuedStartParams + + There are no attributes right now. + +.. class:: TraceConnectionQueuedEndParams + + There are no attributes right now. + +.. class:: TraceConnectionCreateStartParams + + There are no attributes right now. + +.. class:: TraceConnectionCreateEndParams + + There are no attributes right now. + +.. class:: TraceConnectionReuseconnParams + + There are no attributes right now. + +.. class:: TraceDnsResolveHostStartParams + + .. attribute:: Host + + Host that will be resolved. + +.. class:: TraceDnsResolveHostEndParams + + .. attribute:: Host + + Host that has been resolved. + +.. class:: TraceDnsCacheHitParams + + .. attribute:: Host + + Host found in the cache. + +.. class:: TraceDnsCacheMissParams + + .. attribute:: Host + + Host didn't find the cache. diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 3ed641e3d7a..972106779bc 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -479,18 +479,22 @@ async def test_request_tracing(loop): on_request_start.assert_called_once_with( session, trace_config_ctx, - hdrs.METH_GET, - URL("http://example.com"), - CIMultiDict() + aiohttp.TraceRequestStartParams( + hdrs.METH_GET, + URL("http://example.com"), + CIMultiDict() + ) ) on_request_end.assert_called_once_with( session, trace_config_ctx, - hdrs.METH_GET, - URL("http://example.com"), - CIMultiDict(), - resp + aiohttp.TraceRequestEndParams( + hdrs.METH_GET, + URL("http://example.com"), + CIMultiDict(), + resp + ) ) assert not on_request_redirect.called @@ -524,10 +528,12 @@ async def test_request_tracing_exception(loop): on_request_exception.assert_called_once_with( session, mock.ANY, - hdrs.METH_GET, - URL("http://example.com"), - CIMultiDict(), - error + aiohttp.TraceRequestExceptionParams( + hdrs.METH_GET, + URL("http://example.com"), + CIMultiDict(), + error + ) ) assert not on_request_end.called @@ -544,11 +550,8 @@ def __init__(self, *args, **kwargs): async def new_headers( session, trace_config_ctx, - method, - url, - headers, - trace_request_ctx=None): - headers['foo'] = 'bar' + data): + data.headers['foo'] = 'bar' trace_config = aiohttp.TraceConfig() trace_config.on_request_start.append(new_headers) diff --git a/tests/test_connector.py b/tests/test_connector.py index ddfe4871985..f8536de9fce 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -666,14 +666,17 @@ async def test_tcp_connector_dns_tracing(loop, dns_response): on_dns_resolvehost_start.assert_called_once_with( session, trace_config_ctx, + aiohttp.TraceDnsResolveHostStartParams('localhost') ) - on_dns_resolvehost_start.assert_called_once_with( + on_dns_resolvehost_end.assert_called_once_with( session, trace_config_ctx, + aiohttp.TraceDnsResolveHostEndParams('localhost') ) on_dns_cache_miss.assert_called_once_with( session, trace_config_ctx, + aiohttp.TraceDnsCacheMissParams('localhost') ) assert not on_dns_cache_hit.called @@ -685,6 +688,7 @@ async def test_tcp_connector_dns_tracing(loop, dns_response): on_dns_cache_hit.assert_called_once_with( session, trace_config_ctx, + aiohttp.TraceDnsCacheHitParams('localhost') ) @@ -738,21 +742,25 @@ async def test_tcp_connector_dns_tracing_cache_disabled(loop, dns_response): on_dns_resolvehost_start.assert_has_calls([ mock.call( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceDnsResolveHostStartParams('localhost') ), mock.call( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceDnsResolveHostStartParams('localhost') ) ]) on_dns_resolvehost_end.assert_has_calls([ mock.call( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceDnsResolveHostEndParams('localhost') ), mock.call( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceDnsResolveHostEndParams('localhost') ) ]) @@ -793,11 +801,13 @@ async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response): await asyncio.sleep(0, loop=loop) on_dns_cache_hit.assert_called_once_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceDnsCacheHitParams('localhost') ) on_dns_cache_miss.assert_called_once_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceDnsCacheMissParams('localhost') ) @@ -933,11 +943,13 @@ async def test_connect_tracing(loop): await conn.connect(req, traces=traces) on_connection_create_start.assert_called_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceConnectionCreateStartParams() ) on_connection_create_end.assert_called_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceConnectionCreateEndParams() ) @@ -1333,11 +1345,13 @@ async def f(): ) on_connection_queued_start.assert_called_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceConnectionQueuedStartParams() ) on_connection_queued_end.assert_called_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceConnectionQueuedEndParams() ) connection2.release() @@ -1381,7 +1395,8 @@ async def test_connect_reuseconn_tracing(loop, key): on_connection_reuseconn.assert_called_with( session, - trace_config_ctx + trace_config_ctx, + aiohttp.TraceConnectionReuseconnParams() ) conn.close()