diff --git a/CHANGES/2313.feature b/CHANGES/2313.feature new file mode 100644 index 00000000000..fd3ee1d6c15 --- /dev/null +++ b/CHANGES/2313.feature @@ -0,0 +1 @@ +ClientSession publishes a set of signals to track the HTTP request execution. diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index ba51666c9b9..59b0db4b62e 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -14,6 +14,7 @@ from .payload import * # noqa from .payload_streamer import * # noqa from .resolver import * # noqa +from .tracing import * # noqa try: from .worker import GunicornWebWorker, GunicornUVLoopWebWorker # noqa @@ -30,6 +31,7 @@ payload.__all__ + # noqa payload_streamer.__all__ + # noqa streams.__all__ + # noqa + tracing.__all__ + # noqa ('hdrs', 'HttpVersion', 'HttpVersion10', 'HttpVersion11', 'WSMsgType', 'WSCloseCode', 'WebSocketError', 'WSMessage', diff --git a/aiohttp/client.py b/aiohttp/client.py index 6a0370fe846..a453358b806 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -29,6 +29,7 @@ from .http import WS_KEY, WebSocketReader, WebSocketWriter from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse from .streams import FlowControlDataQueue +from .tracing import Trace __all__ = (client_exceptions.__all__ + # noqa @@ -57,7 +58,8 @@ def __init__(self, *, connector=None, loop=None, cookies=None, version=http.HttpVersion11, cookie_jar=None, connector_owner=True, raise_for_status=False, read_timeout=sentinel, conn_timeout=None, - auto_decompress=True, trust_env=False): + auto_decompress=True, trust_env=False, + trace_configs=None): implicit_loop = False if loop is None: @@ -96,6 +98,7 @@ def __init__(self, *, connector=None, loop=None, cookies=None, if cookies is not None: self._cookie_jar.update_cookies(cookies) + self._connector = connector self._connector_owner = connector_owner self._default_auth = auth @@ -124,6 +127,10 @@ def __init__(self, *, connector=None, loop=None, cookies=None, self._response_class = response_class self._ws_response_class = ws_response_class + self._trace_configs = trace_configs or [] + for trace_config in self._trace_configs: + trace_config.freeze() + def __del__(self, _warnings=warnings): if not self.closed: _warnings.warn("Unclosed client session {!r}".format(self), @@ -159,7 +166,8 @@ def _request(self, method, url, *, verify_ssl=None, fingerprint=None, ssl_context=None, - proxy_headers=None): + proxy_headers=None, + trace_request_ctx=None): # NOTE: timeout clamps existing connect and read timeouts. We cannot # set the default to None because we need to detect if the user wants @@ -216,6 +224,22 @@ def _request(self, method, url, *, handle = tm.start() url = URL(url) + + traces = [ + Trace( + trace_config, + self, + trace_request_ctx=trace_request_ctx) + for trace_config in self._trace_configs + ] + + for trace in traces: + yield from trace.send_request_start( + method, + url, + headers + ) + timer = tm.timer() try: with timer: @@ -264,7 +288,10 @@ def _request(self, method, url, *, # connection timeout try: with CeilTimeout(self._conn_timeout, loop=self._loop): - conn = yield from self._connector.connect(req) + conn = yield from self._connector.connect( + req, + traces=traces + ) except asyncio.TimeoutError as exc: raise ServerTimeoutError( 'Connection timeout ' @@ -289,6 +316,15 @@ def _request(self, method, url, *, # redirects if resp.status in ( 301, 302, 303, 307, 308) and allow_redirects: + + for trace in traces: + yield from trace.send_request_redirect( + method, + url, + headers, + resp + ) + redirects += 1 history.append(resp) if max_redirects and redirects >= max_redirects: @@ -352,15 +388,30 @@ def _request(self, method, url, *, handle.cancel() resp._history = tuple(history) + + for trace in traces: + yield from trace.send_request_end( + method, + url, + headers, + resp + ) return resp - except Exception: + except Exception as e: # cleanup timer tm.close() if handle: handle.cancel() handle = None + for trace in traces: + yield from trace.send_request_exception( + method, + url, + headers, + e + ) raise def ws_connect(self, url, *, diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index a74719c1efc..6d2eaed86d7 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -472,8 +472,7 @@ def send(self, conn): self.method, self.original_url, writer=self._writer, continue100=self._continue, timer=self._timer, request_info=self.request_info, - auto_decompress=self._auto_decompress - ) + auto_decompress=self._auto_decompress) self.response._post_init(self.loop, self._session) return self.response diff --git a/aiohttp/connector.py b/aiohttp/connector.py index a74015e5f28..1041ed5bf32 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -347,7 +347,7 @@ def closed(self): """ return self._closed - async def connect(self, req): + async def connect(self, req, traces=None): """Get from pool or create new connection.""" key = req.connection_key @@ -375,6 +375,11 @@ async def connect(self, req): # This connection will now count towards the limit. waiters = self._waiters[key] waiters.append(fut) + + if traces: + for trace in traces: + await trace.send_connection_queued_start() + try: await fut finally: @@ -383,13 +388,25 @@ async def connect(self, req): if not waiters: del self._waiters[key] + if traces: + for trace in traces: + await trace.send_connection_queued_end() + proto = self._get(key) if proto is None: placeholder = _TransportPlaceholder() self._acquired.add(placeholder) self._acquired_per_host[key].add(placeholder) + + if traces: + for trace in traces: + await trace.send_connection_create_start() + try: - proto = await self._create_connection(req) + proto = await self._create_connection( + req, + traces=traces + ) if self._closed: proto.close() raise ClientConnectionError("Connector is closed.") @@ -405,6 +422,14 @@ async def connect(self, req): self._acquired.remove(placeholder) self._acquired_per_host[key].remove(placeholder) + if traces: + for trace in traces: + await trace.send_connection_create_end() + else: + if traces: + for trace in traces: + await trace.send_connection_reuseconn() + self._acquired.add(proto) self._acquired_per_host[key].add(proto) return Connection(self, key, proto, self._loop) @@ -497,7 +522,7 @@ def _release(self, key, protocol, *, should_close=False): self._cleanup_handle = helpers.weakref_handle( self, '_cleanup', self._keepalive_timeout, self._loop) - async def _create_connection(self, req): + async def _create_connection(self, req, traces=None): raise NotImplementedError() @@ -685,31 +710,63 @@ def clear_dns_cache(self, host=None, port=None): else: self._cached_hosts.clear() - async def _resolve_host(self, host, port): + async def _resolve_host(self, host, port, traces=None): if is_ip_address(host): return [{'hostname': host, 'host': host, 'port': port, 'family': self._family, 'proto': 0, 'flags': 0}] if not self._use_dns_cache: - return (await self._resolver.resolve( + + if traces: + for trace in traces: + await trace.send_dns_resolvehost_start() + + res = (await self._resolver.resolve( host, port, family=self._family)) + if traces: + for trace in traces: + await trace.send_dns_resolvehost_end() + + return res + key = (host, port) if (key in self._cached_hosts) and \ (not self._cached_hosts.expired(key)): + + if traces: + for trace in traces: + await trace.send_dns_cache_hit() + return self._cached_hosts.next_addrs(key) if key in self._throttle_dns_events: + if traces: + for trace in traces: + await trace.send_dns_cache_hit() await self._throttle_dns_events[key].wait() else: - self._throttle_dns_events[key] = EventResultOrError(self._loop) + if traces: + for trace in traces: + await trace.send_dns_cache_miss() + self._throttle_dns_events[key] = \ + EventResultOrError(self._loop) try: + + if traces: + for trace in traces: + await trace.send_dns_resolvehost_start() + addrs = await \ asyncio.shield(self._resolver.resolve(host, port, family=self._family), loop=self._loop) + if traces: + for trace in traces: + await trace.send_dns_resolvehost_end() + self._cached_hosts.add(key, addrs) self._throttle_dns_events[key].set() except Exception as e: @@ -722,15 +779,21 @@ async def _resolve_host(self, host, port): return self._cached_hosts.next_addrs(key) - async def _create_connection(self, req): + async def _create_connection(self, req, traces=None): """Create connection. Has same keyword arguments as BaseEventLoop.create_connection. """ if req.proxy: - _, proto = await self._create_proxy_connection(req) + _, proto = await self._create_proxy_connection( + req, + traces=None + ) else: - _, proto = await self._create_direct_connection(req) + _, proto = await self._create_direct_connection( + req, + traces=None + ) return proto @@ -787,12 +850,16 @@ async def _wrap_create_connection(self, *args, raise client_error(req.connection_key, exc) from exc async def _create_direct_connection(self, req, - *, client_error=ClientConnectorError): + *, client_error=ClientConnectorError, + traces=None): sslcontext = self._get_ssl_context(req) fingerprint, hashfunc = self._get_fingerprint_and_hashfunc(req) try: - hosts = await self._resolve_host(req.url.raw_host, req.port) + hosts = await self._resolve_host( + req.url.raw_host, + req.port, + traces=traces) except OSError as exc: # in case of proxy it is not ClientProxyConnectionError # it is problem of resolving proxy ip itself @@ -841,7 +908,7 @@ async def _create_direct_connection(self, req, else: raise last_exc - async def _create_proxy_connection(self, req): + async def _create_proxy_connection(self, req, traces=None): headers = {} if req.proxy_headers is not None: headers = req.proxy_headers @@ -943,7 +1010,7 @@ def path(self): """Path to unix socket.""" return self._path - async def _create_connection(self, req): + async def _create_connection(self, req, traces=None): try: _, proto = await self._loop.create_unix_connection( self._factory, self._path) diff --git a/aiohttp/tracing.py b/aiohttp/tracing.py new file mode 100644 index 00000000000..06eecc2a863 --- /dev/null +++ b/aiohttp/tracing.py @@ -0,0 +1,227 @@ +from types import SimpleNamespace + +from .signals import Signal + + +__all__ = ('TraceConfig',) + + +class TraceConfig: + """First-class used to trace requests launched via ClientSession + objects.""" + + def __init__(self, trace_config_ctx_class=SimpleNamespace): + self._on_request_start = Signal(self) + self._on_request_end = Signal(self) + self._on_request_exception = Signal(self) + self._on_request_redirect = Signal(self) + self._on_connection_queued_start = Signal(self) + self._on_connection_queued_end = Signal(self) + self._on_connection_create_start = Signal(self) + self._on_connection_create_end = Signal(self) + self._on_connection_reuseconn = Signal(self) + self._on_dns_resolvehost_start = Signal(self) + self._on_dns_resolvehost_end = Signal(self) + self._on_dns_cache_hit = Signal(self) + self._on_dns_cache_miss = Signal(self) + + self._trace_config_ctx_class = trace_config_ctx_class + + def trace_config_ctx(self): + """ Return a new trace_config_ctx instance """ + return self._trace_config_ctx_class() + + def freeze(self): + self._on_request_start.freeze() + self._on_request_end.freeze() + self._on_request_exception.freeze() + self._on_request_redirect.freeze() + self._on_connection_queued_start.freeze() + self._on_connection_queued_end.freeze() + self._on_connection_create_start.freeze() + self._on_connection_create_end.freeze() + self._on_connection_reuseconn.freeze() + self._on_dns_resolvehost_start.freeze() + self._on_dns_resolvehost_end.freeze() + self._on_dns_cache_hit.freeze() + self._on_dns_cache_miss.freeze() + + @property + def on_request_start(self): + return self._on_request_start + + @property + def on_request_end(self): + return self._on_request_end + + @property + def on_request_exception(self): + return self._on_request_exception + + @property + def on_request_redirect(self): + return self._on_request_redirect + + @property + def on_connection_queued_start(self): + return self._on_connection_queued_start + + @property + def on_connection_queued_end(self): + return self._on_connection_queued_end + + @property + def on_connection_create_start(self): + return self._on_connection_create_start + + @property + def on_connection_create_end(self): + return self._on_connection_create_end + + @property + def on_connection_reuseconn(self): + return self._on_connection_reuseconn + + @property + def on_dns_resolvehost_start(self): + return self._on_dns_resolvehost_start + + @property + def on_dns_resolvehost_end(self): + return self._on_dns_resolvehost_end + + @property + def on_dns_cache_hit(self): + return self._on_dns_cache_hit + + @property + def on_dns_cache_miss(self): + return self._on_dns_cache_miss + + +class Trace: + """ Internal class used to keep together the main dependencies used + at the moment of send a signal.""" + + def __init__(self, trace_config, session, trace_request_ctx=None): + self._trace_config = trace_config + self._trace_config_ctx = self._trace_config.trace_config_ctx() + self._trace_request_ctx = trace_request_ctx + self._session = session + + async def send_request_start(self, *args, **kwargs): + return await self._trace_config.on_request_start.send( + self._session, + self._trace_config_ctx, + *args, + trace_request_ctx=self._trace_request_ctx, + **kwargs + ) + + async def send_request_end(self, *args, **kwargs): + return await self._trace_config.on_request_end.send( + self._session, + self._trace_config_ctx, + *args, + trace_request_ctx=self._trace_request_ctx, + **kwargs + ) + + async def send_request_exception(self, *args, **kwargs): + return await self._trace_config.on_request_exception.send( + self._session, + self._trace_config_ctx, + *args, + trace_request_ctx=self._trace_request_ctx, + **kwargs + ) + + async def send_request_redirect(self, *args, **kwargs): + return await self._trace_config._on_request_redirect.send( + self._session, + self._trace_config_ctx, + *args, + trace_request_ctx=self._trace_request_ctx, + **kwargs + ) + + async def send_connection_queued_start(self, *args, **kwargs): + return await self._trace_config.on_connection_queued_start.send( + self._session, + self._trace_config_ctx, + *args, + trace_request_ctx=self._trace_request_ctx, + **kwargs + ) + + async def send_connection_queued_end(self, *args, **kwargs): + return await self._trace_config.on_connection_queued_end.send( + self._session, + self._trace_config_ctx, + *args, + trace_request_ctx=self._trace_request_ctx, + **kwargs + ) + + async def send_connection_create_start(self, *args, **kwargs): + return await self._trace_config.on_connection_create_start.send( + self._session, + self._trace_config_ctx, + *args, + trace_request_ctx=self._trace_request_ctx, + **kwargs + ) + + async def send_connection_create_end(self, *args, **kwargs): + return await self._trace_config.on_connection_create_end.send( + self._session, + self._trace_config_ctx, + *args, + trace_request_ctx=self._trace_request_ctx, + **kwargs + ) + + async def send_connection_reuseconn(self, *args, **kwargs): + return await self._trace_config.on_connection_reuseconn.send( + self._session, + self._trace_config_ctx, + *args, + trace_request_ctx=self._trace_request_ctx, + **kwargs + ) + + async def send_dns_resolvehost_start(self, *args, **kwargs): + return await self._trace_config.on_dns_resolvehost_start.send( + self._session, + self._trace_config_ctx, + *args, + trace_request_ctx=self._trace_request_ctx, + **kwargs + ) + + async def send_dns_resolvehost_end(self, *args, **kwargs): + return await self._trace_config.on_dns_resolvehost_end.send( + self._session, + self._trace_config_ctx, + *args, + trace_request_ctx=self._trace_request_ctx, + **kwargs + ) + + async def send_dns_cache_hit(self, *args, **kwargs): + return await self._trace_config.on_dns_cache_hit.send( + self._session, + self._trace_config_ctx, + *args, + trace_request_ctx=self._trace_request_ctx, + **kwargs + ) + + async def send_dns_cache_miss(self, *args, **kwargs): + return await self._trace_config.on_dns_cache_miss.send( + self._session, + self._trace_config_ctx, + *args, + trace_request_ctx=self._trace_request_ctx, + **kwargs + ) diff --git a/docs/client.rst b/docs/client.rst index c63210a0ff8..c7e3960a74b 100644 --- a/docs/client.rst +++ b/docs/client.rst @@ -14,3 +14,4 @@ The page contains all information about aiohttp Client API: client_quickstart client_advanced client_reference + tracing_reference diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 2eccd46f359..8829115a6ac 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -130,6 +130,70 @@ session:: jar = aiohttp.DummyCookieJar() session = aiohttp.ClientSession(cookie_jar=jar) +Client tracing +-------------- + +The execution flow of a specific request can be followed attaching listeners coroutines +to the signals provided by the :class:`TraceConfig` instance, this instance will be used +as a parameter for the :class:`ClientSession` constructor having as a result a client that +triggers the different signals supported by the :class:`TraceConfig`. By default any instance +of :class:`ClientSession` class comes with the signals ability disabled. The following +snippet shows how the start and the end signals of a request flow can be followed:: + + async def on_request_start( + session, trace_config_ctx, method, host, port, headers, request_trace_config_ctx=None): + print("Starting request") + + async def on_request_end(session, trace_config_ctx, resp, request_trace_config_ctx=None): + print("Ending request") + + trace_config = aiohttp.TraceConfig() + trace_config.on_request_start.append(on_request_start) + trace_config.on_request_end.append(on_request_end) + async with aiohttp.ClientSession(trace_configs=[trace_config]) as client: + client.get('http://example.com/some/redirect/') + +The `trace_configs` is a list that can contain instances of :class:`TraceConfig` class +that allow run the signals handlers coming from different :class:`TraceConfig` instances. +The following example shows how two different :class:`TraceConfig` that have a different +nature are installed to perform their job in each signal handle:: + + from .traceconfig import AuditRequest + from .traceconfig import XRay + + async with aiohttp.ClientSession(trace_configs=[AuditRequest(), XRay()]) as client: + client.get('http://example.com/some/redirect/') + + +All signals take as a parameters first, the :class:`ClientSession` instance used by +the specific request related to that signals and second, a :class:`SimpleNamespace` +instance called ``trace_config_ctx``. The ``trace_config_ctx`` object can be used to share +the state through to the different signals that belong to the same request and to +the same :class:`TraceConfig` class, perhaps:: + + async def on_request_start( + session, trace_config_ctx, method, host, port, headers, trace_request_ctx=None): + trace_config_ctx.start = session.loop.time() + + async def on_request_end( + session, trace_config_ctx, resp, trace_request_ctx=None): + elapsed = session.loop.time() - trace_config_ctx.start + print("Request took {}".format(elapsed)) + + +The ``trace_config_ctx`` param is by default a :class:`SimpleNampespace` that is initialized at +the beginning of the request flow. However, the factory used to create this object can be +overwritten using the ``trace_config_ctx_class`` constructor param of the +:class:`TraceConfig` class. + +The ``trace_request_ctx`` param can given at the beginning of the request execution and +will be passed as a keyword argument for all of the signals, as the following snippet shows:: + + session.get('http://example.com/some/redirect/', trace_request_ctx={'foo': 'bar'}) + + +.. seealso:: :ref:`aiohttp-tracing-reference` section for + more information about the different signals supported. Connectors ---------- diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 1c37083a80b..ecebfd5b00d 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -294,6 +294,12 @@ The client session supports the context manager protocol for self closing. .. versionadded:: 2.3 + :param trace_request_ctx: Object used to give as a kw param for the all + signals triggered by the ongoing request and for all :class:`TraceConfig` + configured. Default passes a None value. + + .. versionadded:: 3.0 + :return ClientResponse: a :class:`client response ` object. @@ -523,6 +529,7 @@ The client session supports the context manager protocol for self closing. Session is switched to closed state anyway. + Basic API --------- @@ -1154,6 +1161,7 @@ Response object object, :class:`aiohttp.RequestInfo` instance. + ClientWebSocketResponse ----------------------- diff --git a/docs/tracing_reference.rst b/docs/tracing_reference.rst new file mode 100644 index 00000000000..6a90c963a44 --- /dev/null +++ b/docs/tracing_reference.rst @@ -0,0 +1,156 @@ +.. _aiohttp-tracing-reference: + +Tracing Reference +================= + +.. module:: aiohttp +.. currentmodule:: aiohttp + + +TraceConfig +----------- + +Trace config is the configuration object used to trace requests launched by +a Client session object using different events related to different parts of +the request flow. + +.. class:: TraceConfig(trace_config_ctx_class=SimpleNamespace) + + :param trace_config_ctx_class: factory used to create trace contexts, + default class used :class:`SimpleNamespace` + + .. method:: trace_config_ctx() + + Return a new trace context. + + .. attribute:: on_request_start + + Property that gives access to the signals that will be executed when a + request starts, based on the :class:`~signals.Signal` implementation. + + The coroutines listening will receive as a param the ``session``, + ``trace_config_ctx``, ``method``, ``url`` and ``headers``. + + .. versionadded:: 3.0 + + .. attribute:: on_request_redirect + + Property that gives access to the signals that will be executed when a + redirect happens during a request flow. + + The coroutines that are listening will receive the ``session``, + ``trace_config_ctx``, ``method``, ``url``, ``headers`` and ``resp`` params. + + .. versionadded:: 3.0 + + .. attribute:: on_request_end + + Property that gives access to the signals that will be executed when a + request ends. + + The coroutines that are listening will receive the ``session``, + ``trace_config_ctx``, ``method``, ``url``, ``headers`` and ``resp`` params + + .. versionadded:: 3.0 + + .. attribute:: on_request_exception + + Property that gives access to the signals that will be executed when a + request finishes with an exception. + + The coroutines listening will receive the ``session``, + ``trace_config_ctx``, ``method``, ``url``, ``headers`` and ``exception`` params. + + .. versionadded:: 3.0 + + .. attribute:: on_connection_queued_start + + Property that gives access to the signals that will be executed when a + request has been queued waiting for an available connection. + + The coroutines that are listening will receive the ``session`` and + ``trace_config_ctx`` params. + + .. versionadded:: 3.0 + + .. attribute:: on_connection_queued_end + + Property that gives access to the signals that will be executed when a + request that was queued already has an available connection. + + The coroutines that are listening will receive the ``session`` and + ``trace_config_ctx`` params. + + .. versionadded:: 3.0 + + .. attribute:: on_connection_create_start + + Property that gives access to the signals that will be executed when a + request creates a new connection. + + The coroutines listening will receive the ``session`` and + ``trace_config_ctx`` params. + + .. versionadded:: 3.0 + + .. attribute:: on_connection_create_end + + Property that gives access to the signals that will be executed when a + request that created a new connection finishes its creation. + + The coroutines listening will receive the ``session`` and + ``trace_config_ctx`` params. + + .. versionadded:: 3.0 + + .. attribute:: on_connection_reuseconn + + Property that gives access to the signals that will be executed when a + request reuses a connection. + + The coroutines listening will receive the ``session`` and + ``trace_config_ctx`` params. + + .. versionadded:: 3.0 + + .. attribute:: on_dns_resolvehost_start + + Property that gives access to the signals that will be executed when a + request starts to resolve the domain related with the request. + + The coroutines listening will receive the ``session`` and + ``trace_config_ctx`` params. + + .. versionadded:: 3.0 + + .. attribute:: on_dns_resolvehost_end + + Property that gives access to the signals that will be executed when a + request finishes to resolve the domain related with the request. + + The coroutines listening will receive the ``session`` and ``trace_config_ctx`` + params. + + .. versionadded:: 3.0 + + .. attribute:: on_dns_cache_hit + + Property that gives access to the signals that will be executed when a + request was able to use a cached DNS resolution for the domain related + with the request. + + The coroutines listening will receive the ``session`` and + ``trace_config_ctx`` params. + + .. versionadded:: 3.0 + + .. attribute:: on_dns_cache_miss + + Property that gives access to the signals that will be executed when a + request was not able to use a cached DNS resolution for the domain related + with the request. + + The coroutines listening will receive the ``session`` and + ``trace_config_ctx`` params. + + .. versionadded:: 3.0 diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 00cb3059c5a..ea1dae0a465 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -500,7 +500,7 @@ def test_gen_netloc_no_port(make_request): '012345678901234567890' -async def test_connection_header(loop, conn): +def test_connection_header(loop, conn): req = ClientRequest('get', URL('http://python.org'), loop=loop) req.keep_alive = mock.Mock() req.headers.clear() @@ -1074,7 +1074,7 @@ def send(self, conn): called = True return resp - async def create_connection(req): + async def create_connection(req, traces=None): assert isinstance(req, CustomRequest) return mock.Mock() connector = BaseConnector(loop=loop) diff --git a/tests/test_client_session.py b/tests/test_client_session.py index e995fc37007..7796304613f 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -11,8 +11,9 @@ from yarl import URL import aiohttp -from aiohttp import web +from aiohttp import hdrs, web from aiohttp.client import ClientSession +from aiohttp.client_reqrep import ClientRequest from aiohttp.connector import BaseConnector, TCPConnector @@ -373,7 +374,7 @@ async def test_reraise_os_error(create_session): req.send = mock.Mock(side_effect=err) session = create_session(request_class=req_factory) - async def create_connection(req): + async def create_connection(req, traces=None): # return self.transport, self.protocol return mock.Mock() session._connector._create_connection = create_connection @@ -474,3 +475,115 @@ def test_client_session_implicit_loop_warn(): asyncio.set_event_loop(None) loop.close() + + +async def test_request_tracing(loop): + trace_config_ctx = mock.Mock() + trace_request_ctx = {} + on_request_start = mock.Mock(side_effect=asyncio.coroutine(mock.Mock())) + on_request_redirect = mock.Mock(side_effect=asyncio.coroutine(mock.Mock())) + on_request_end = mock.Mock(side_effect=asyncio.coroutine(mock.Mock())) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_class=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_request_start.append(on_request_start) + trace_config.on_request_end.append(on_request_end) + trace_config.on_request_redirect.append(on_request_redirect) + + session = aiohttp.ClientSession(loop=loop, trace_configs=[trace_config]) + + resp = await session.get( + 'http://example.com', + trace_request_ctx=trace_request_ctx + ) + + on_request_start.assert_called_once_with( + session, + trace_config_ctx, + hdrs.METH_GET, + URL("http://example.com"), + CIMultiDict(), + trace_request_ctx=trace_request_ctx + ) + + on_request_end.assert_called_once_with( + session, + trace_config_ctx, + hdrs.METH_GET, + URL("http://example.com"), + CIMultiDict(), + resp, + trace_request_ctx=trace_request_ctx + ) + assert not on_request_redirect.called + + +async def test_request_tracing_exception(loop): + on_request_end = mock.Mock(side_effect=asyncio.coroutine(mock.Mock())) + on_request_exception = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + + trace_config = aiohttp.TraceConfig() + trace_config.on_request_end.append(on_request_end) + trace_config.on_request_exception.append(on_request_exception) + + with mock.patch("aiohttp.client.TCPConnector.connect") as connect_patched: + error = Exception() + f = loop.create_future() + f.set_exception(error) + connect_patched.return_value = f + + session = aiohttp.ClientSession( + loop=loop, + trace_configs=[trace_config] + ) + + try: + await session.get('http://example.com') + except Exception: + pass + + on_request_exception.assert_called_once_with( + session, + mock.ANY, + hdrs.METH_GET, + URL("http://example.com"), + CIMultiDict(), + error, + trace_request_ctx=mock.ANY + ) + assert not on_request_end.called + + +async def test_request_tracing_interpose_headers(loop): + + class MyClientRequest(ClientRequest): + headers = None + + def __init__(self, *args, **kwargs): + super(MyClientRequest, self).__init__(*args, **kwargs) + MyClientRequest.headers = self.headers + + @asyncio.coroutine + def new_headers( + session, + trace_config_ctx, + method, + url, + headers, + trace_request_ctx=None): + headers['foo'] = 'bar' + + trace_config = aiohttp.TraceConfig() + trace_config.on_request_start.append(new_headers) + + session = aiohttp.ClientSession( + loop=loop, + request_class=MyClientRequest, + trace_configs=[trace_config] + ) + + await session.get('http://example.com') + assert MyClientRequest.headers['foo'] == 'bar' diff --git a/tests/test_connector.py b/tests/test_connector.py index 396c54418d4..40380d994d3 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -21,6 +21,7 @@ from aiohttp.client import ClientRequest from aiohttp.connector import Connection, _DNSCacheTable from aiohttp.test_utils import make_mocked_coro, unused_port +from aiohttp.tracing import Trace @pytest.fixture() @@ -377,7 +378,7 @@ async def test_tcp_connector_multiple_hosts_errors(loop): fingerprint=fingerprint, loop=loop) - async def _resolve_host(host, port): + async def _resolve_host(host, port, traces=None): return [{ 'hostname': host, 'host': ip, @@ -591,6 +592,203 @@ async def test_tcp_connector_dns_throttle_requests_cancelled_when_close( await f +async def test_tcp_connector_dns_tracing(loop, dns_response): + session = mock.Mock() + trace_config_ctx = mock.Mock() + trace_request_ctx = mock.Mock() + on_dns_resolvehost_start = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + on_dns_resolvehost_end = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + on_dns_cache_hit = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + on_dns_cache_miss = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_class=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_dns_resolvehost_start.append(on_dns_resolvehost_start) + trace_config.on_dns_resolvehost_end.append(on_dns_resolvehost_end) + trace_config.on_dns_cache_hit.append(on_dns_cache_hit) + trace_config.on_dns_cache_miss.append(on_dns_cache_miss) + trace_config.freeze() + traces = [ + Trace( + trace_config, + session, + trace_request_ctx=trace_request_ctx + ) + ] + + with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver: + conn = aiohttp.TCPConnector( + loop=loop, + use_dns_cache=True, + ttl_dns_cache=10 + ) + + m_resolver().resolve.return_value = dns_response() + + await conn._resolve_host( + 'localhost', + 8080, + traces=traces + ) + on_dns_resolvehost_start.assert_called_once_with( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ) + on_dns_resolvehost_start.assert_called_once_with( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ) + on_dns_cache_miss.assert_called_once_with( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ) + assert not on_dns_cache_hit.called + + await conn._resolve_host( + 'localhost', + 8080, + traces=traces + ) + on_dns_cache_hit.assert_called_once_with( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ) + + +async def test_tcp_connector_dns_tracing_cache_disabled(loop, dns_response): + session = mock.Mock() + trace_config_ctx = mock.Mock() + trace_request_ctx = mock.Mock() + on_dns_resolvehost_start = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + on_dns_resolvehost_end = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_class=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_dns_resolvehost_start.append(on_dns_resolvehost_start) + trace_config.on_dns_resolvehost_end.append(on_dns_resolvehost_end) + trace_config.freeze() + traces = [ + Trace( + trace_config, + session, + trace_request_ctx=trace_request_ctx + ) + ] + + with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver: + conn = aiohttp.TCPConnector( + loop=loop, + use_dns_cache=False + ) + + m_resolver().resolve.side_effect = [ + dns_response(), + dns_response() + ] + + await conn._resolve_host( + 'localhost', + 8080, + traces=traces + ) + + await conn._resolve_host( + 'localhost', + 8080, + traces=traces + ) + + on_dns_resolvehost_start.assert_has_calls([ + mock.call( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ), + mock.call( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ) + ]) + on_dns_resolvehost_end.assert_has_calls([ + mock.call( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ), + mock.call( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ) + ]) + + +async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response): + session = mock.Mock() + trace_config_ctx = mock.Mock() + trace_request_ctx = mock.Mock() + on_dns_cache_hit = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + on_dns_cache_miss = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_class=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_dns_cache_hit.append(on_dns_cache_hit) + trace_config.on_dns_cache_miss.append(on_dns_cache_miss) + trace_config.freeze() + traces = [ + Trace( + trace_config, + session, + trace_request_ctx=trace_request_ctx + ) + ] + + with mock.patch('aiohttp.connector.DefaultResolver') as m_resolver: + conn = aiohttp.TCPConnector( + loop=loop, + use_dns_cache=True, + ttl_dns_cache=10 + ) + m_resolver().resolve.return_value = dns_response() + loop.create_task(conn._resolve_host('localhost', 8080, traces=traces)) + loop.create_task(conn._resolve_host('localhost', 8080, traces=traces)) + await asyncio.sleep(0, loop=loop) + on_dns_cache_hit.assert_called_once_with( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ) + on_dns_cache_miss.assert_called_once_with( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ) + + def test_dns_error(loop): connector = aiohttp.TCPConnector(loop=loop) connector._resolve_host = make_mocked_coro( @@ -686,6 +884,54 @@ async def test_connect(loop): connection.close() +async def test_connect_tracing(loop): + session = mock.Mock() + trace_config_ctx = mock.Mock() + trace_request_ctx = mock.Mock() + on_connection_create_start = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + on_connection_create_end = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_class=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_connection_create_start.append(on_connection_create_start) + trace_config.on_connection_create_end.append(on_connection_create_end) + trace_config.freeze() + traces = [ + Trace( + trace_config, + session, + trace_request_ctx=trace_request_ctx + ) + ] + + proto = mock.Mock() + proto.is_connected.return_value = True + + req = ClientRequest('GET', URL('http://host:80'), loop=loop) + + conn = aiohttp.BaseConnector(loop=loop) + conn._create_connection = mock.Mock() + conn._create_connection.return_value = loop.create_future() + conn._create_connection.return_value.set_result(proto) + + await conn.connect(req, traces=traces) + on_connection_create_start.assert_called_with( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ) + on_connection_create_end.assert_called_with( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ) + + async def test_close_during_connect(loop): proto = mock.Mock() proto.is_connected.return_value = True @@ -977,6 +1223,110 @@ async def f(): conn.close() +async def test_connect_queued_operation_tracing(loop, key): + session = mock.Mock() + trace_config_ctx = mock.Mock() + trace_request_ctx = mock.Mock() + on_connection_queued_start = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + on_connection_queued_end = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_class=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_connection_queued_start.append(on_connection_queued_start) + trace_config.on_connection_queued_end.append(on_connection_queued_end) + trace_config.freeze() + traces = [ + Trace( + trace_config, + session, + trace_request_ctx=trace_request_ctx + ) + ] + + proto = mock.Mock() + proto.is_connected.return_value = True + + req = ClientRequest('GET', URL('http://localhost1:80'), + loop=loop, + response_class=mock.Mock()) + + conn = aiohttp.BaseConnector(loop=loop, limit=1) + conn._conns[key] = [(proto, loop.time())] + conn._create_connection = mock.Mock() + conn._create_connection.return_value = loop.create_future() + conn._create_connection.return_value.set_result(proto) + + connection1 = await conn.connect(req, traces=traces) + + async def f(): + connection2 = await conn.connect( + req, + traces=traces + ) + on_connection_queued_start.assert_called_with( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ) + on_connection_queued_end.assert_called_with( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ) + connection2.release() + + task = asyncio.ensure_future(f(), loop=loop) + await asyncio.sleep(0.01, loop=loop) + connection1.release() + await task + conn.close() + + +async def test_connect_reuseconn_tracing(loop, key): + session = mock.Mock() + trace_config_ctx = mock.Mock() + trace_request_ctx = mock.Mock() + on_connection_reuseconn = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock()) + ) + + trace_config = aiohttp.TraceConfig( + trace_config_ctx_class=mock.Mock(return_value=trace_config_ctx) + ) + trace_config.on_connection_reuseconn.append(on_connection_reuseconn) + trace_config.freeze() + traces = [ + Trace( + trace_config, + session, + trace_request_ctx=trace_request_ctx + ) + ] + + proto = mock.Mock() + proto.is_connected.return_value = True + + req = ClientRequest('GET', URL('http://localhost1:80'), + loop=loop, + response_class=mock.Mock()) + + conn = aiohttp.BaseConnector(loop=loop, limit=1) + conn._conns[key] = [(proto, loop.time())] + await conn.connect(req, traces=traces) + + on_connection_reuseconn.assert_called_with( + session, + trace_config_ctx, + trace_request_ctx=trace_request_ctx + ) + conn.close() + + async def test_connect_with_limit_and_limit_per_host(loop, key): proto = mock.Mock() proto.is_connected.return_value = True @@ -1136,7 +1486,7 @@ async def test_connect_with_limit_concurrent(loop): # Use a real coroutine for _create_connection; a mock would mask # problems that only happen when the method yields. - async def create_connection(req): + async def create_connection(req, traces=None): nonlocal num_connections num_connections += 1 await asyncio.sleep(0, loop=loop) @@ -1264,7 +1614,7 @@ async def test_error_on_connection(loop): fut = loop.create_future() exc = OSError() - async def create_connection(req): + async def create_connection(req, traces=None): nonlocal i i += 1 if i == 1: @@ -1306,7 +1656,7 @@ async def test_error_on_connection_with_cancelled_waiter(loop): fut2 = loop.create_future() exc = OSError() - async def create_connection(req): + async def create_connection(req, traces=None): nonlocal i i += 1 if i == 1: diff --git a/tests/test_proxy.py b/tests/test_proxy.py index cc066904e2b..1e1767d8fa1 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -122,7 +122,9 @@ def test_connect_req_verify_ssl_true(self, ClientRequestMock): (proto.transport, proto)) self.loop.run_until_complete(connector.connect(req)) - connector._create_proxy_connection.assert_called_with(req) + connector._create_proxy_connection.assert_called_with( + req, + traces=None) ((proxy_req,), _) = connector._create_direct_connection.call_args proxy_req.send.assert_called_with(mock.ANY) @@ -147,7 +149,9 @@ def test_connect_req_verify_ssl_false(self, ClientRequestMock): (proto.transport, proto)) self.loop.run_until_complete(connector.connect(req)) - connector._create_proxy_connection.assert_called_with(req) + connector._create_proxy_connection.assert_called_with( + req, + traces=None) ((proxy_req,), _) = connector._create_direct_connection.call_args proxy_req.send.assert_called_with(mock.ANY) @@ -183,7 +187,9 @@ def test_connect_req_fingerprint_ssl_context(self, ClientRequestMock): (transport, proto)) self.loop.run_until_complete(connector.connect(req)) - connector._create_proxy_connection.assert_called_with(req) + connector._create_proxy_connection.assert_called_with( + req, + traces=None) ((proxy_req,), _) = connector._create_direct_connection.call_args self.assertTrue(proxy_req.verify_ssl) self.assertEqual(proxy_req.fingerprint, req.fingerprint) @@ -676,7 +682,10 @@ def test_https_auth(self, ClientRequestMock): self.assertNotIn('AUTHORIZATION', proxy_req.headers) self.assertIn('PROXY-AUTHORIZATION', proxy_req.headers) - connector._resolve_host.assert_called_with('proxy.example.com', 80) + connector._resolve_host.assert_called_with( + 'proxy.example.com', + 80, + traces=None) self.loop.run_until_complete(proxy_req.close()) proxy_resp.close() diff --git a/tests/test_tracing.py b/tests/test_tracing.py new file mode 100644 index 00000000000..6235bd59bef --- /dev/null +++ b/tests/test_tracing.py @@ -0,0 +1,77 @@ +import asyncio +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from aiohttp.tracing import Trace, TraceConfig + + +class TestTraceConfig: + + def test_trace_config_ctx_default(self): + trace_config = TraceConfig() + assert isinstance(trace_config.trace_config_ctx(), SimpleNamespace) + + def test_trace_config_ctx_class(self): + trace_config = TraceConfig(trace_config_ctx_class=dict) + assert isinstance(trace_config.trace_config_ctx(), dict) + + def test_freeze(self): + trace_config = TraceConfig() + trace_config.freeze() + + assert trace_config.on_request_start.frozen + assert trace_config.on_request_end.frozen + assert trace_config.on_request_exception.frozen + assert trace_config.on_request_redirect.frozen + assert trace_config.on_connection_queued_start.frozen + assert trace_config.on_connection_queued_end.frozen + assert trace_config.on_connection_create_start.frozen + assert trace_config.on_connection_create_end.frozen + assert trace_config.on_connection_reuseconn.frozen + assert trace_config.on_dns_resolvehost_start.frozen + assert trace_config.on_dns_resolvehost_end.frozen + assert trace_config.on_dns_cache_hit.frozen + assert trace_config.on_dns_cache_miss.frozen + + +class TestTrace: + + @pytest.mark.parametrize('signal', [ + 'request_start', + 'request_end', + 'request_exception', + 'request_redirect', + 'connection_queued_start', + 'connection_queued_end', + 'connection_create_start', + 'connection_create_end', + 'connection_reuseconn', + 'dns_resolvehost_start', + 'dns_resolvehost_end', + 'dns_cache_hit', + 'dns_cache_miss' + ]) + async def test_send(self, loop, signal): + param = Mock() + session = Mock() + trace_request_ctx = Mock() + callback = Mock(side_effect=asyncio.coroutine(Mock())) + + trace_config = TraceConfig() + getattr(trace_config, "on_%s" % signal).append(callback) + trace_config.freeze() + trace = Trace( + trace_config, + session, + trace_request_ctx=trace_request_ctx + ) + await getattr(trace, "send_%s" % signal)(param) + + callback.assert_called_once_with( + session, + SimpleNamespace(), + param, + trace_request_ctx=trace_request_ctx + ) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index 21ccbe8da53..f5d5b22910c 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -10,7 +10,8 @@ from yarl import URL import aiohttp -from aiohttp import FormData, HttpVersion10, HttpVersion11, multipart, web +from aiohttp import (FormData, HttpVersion10, HttpVersion11, TraceConfig, + multipart, web) try: @@ -1578,6 +1579,47 @@ async def handler(request): assert resp.status == 200 +async def test_request_tracing(loop, test_client): + + on_request_start = mock.Mock(side_effect=asyncio.coroutine(mock.Mock())) + on_request_end = mock.Mock(side_effect=asyncio.coroutine(mock.Mock())) + on_request_redirect = mock.Mock(side_effect=asyncio.coroutine(mock.Mock())) + on_connection_create_start = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock())) + on_connection_create_end = mock.Mock( + side_effect=asyncio.coroutine(mock.Mock())) + + async def redirector(request): + raise web.HTTPFound(location=URL('/redirected')) + + async def redirected(request): + return web.Response() + + trace_config = TraceConfig() + + trace_config.on_request_start.append(on_request_start) + trace_config.on_request_end.append(on_request_end) + trace_config.on_request_redirect.append(on_request_redirect) + trace_config.on_connection_create_start.append( + on_connection_create_start) + trace_config.on_connection_create_end.append( + on_connection_create_end) + + app = web.Application() + app.router.add_get('/redirector', redirector) + app.router.add_get('/redirected', redirected) + + client = await test_client(app, trace_configs=[trace_config]) + + await client.get('/redirector', data="foo") + + assert on_request_start.called + assert on_request_end.called + assert on_request_redirect.called + assert on_connection_create_start.called + assert on_connection_create_end.called + + async def test_return_http_exception_deprecated(loop, test_client): async def handler(request):