From fb6089dc2b4616f2e78e1f7d12bf49a23be8d7aa Mon Sep 17 00:00:00 2001
From: Niko Fink
Date: Fri, 16 Mar 2018 12:45:33 +0100
Subject: [PATCH] first draft of more generic timeouts throughout the request
 lifecycle, combined with tracing

---
 aiohttp/client.py        | 128 ++++++++++++++++-----------------------
 aiohttp/client_reqrep.py |  34 +++++------
 aiohttp/connector.py     |  82 +++++++++++--------------
 aiohttp/lifecycle.py     | 109 +++++++++++++++++++++++++++++++++
 aiohttp/tracing.py       |  19 ++++++
 5 files changed, 229 insertions(+), 143 deletions(-)
 create mode 100644 aiohttp/lifecycle.py

diff --git a/aiohttp/client.py b/aiohttp/client.py
index be8486af7ab..cd057a2ce12 100644
--- a/aiohttp/client.py
+++ b/aiohttp/client.py
@@ -10,6 +10,7 @@
 import warnings
 from collections.abc import Coroutine
 
+import attr
 from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr
 from yarl import URL
 
@@ -31,7 +32,7 @@
 from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse
 from .streams import FlowControlDataQueue
 from .tcp_helpers import tcp_cork, tcp_nodelay
-from .tracing import Trace
+from .lifecycle import RequestLifecycle, RequestTimeouts
 
 __all__ = (client_exceptions.__all__ +  # noqa
@@ -41,7 +42,7 @@
 
 
 # 5 Minute default read and connect timeout
-DEFAULT_TIMEOUT = 5 * 60
+DEFAULT_TIMEOUTS = RequestTimeouts(uber_timeout=5 * 60)
 
 
 class ClientSession:
@@ -69,7 +70,7 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
                  ws_response_class=ClientWebSocketResponse,
                  version=http.HttpVersion11, cookie_jar=None,
                  connector_owner=True, raise_for_status=False,
-                 read_timeout=sentinel, conn_timeout=None,
+                 timeout=None, read_timeout=sentinel, conn_timeout=None,
                  auto_decompress=True, trust_env=False,
                  trace_configs=None):
 
@@ -116,13 +117,21 @@ def __init__(self, *, connector=None, loop=None, cookies=None,
         self._default_auth = auth
         self._version = version
         self._json_serialize = json_serialize
-        self._read_timeout = (read_timeout if read_timeout is not sentinel
-                              else DEFAULT_TIMEOUT)
-        self._conn_timeout = conn_timeout
         self._raise_for_status = raise_for_status
         self._auto_decompress = auto_decompress
         self._trust_env = trust_env
 
+        if (read_timeout is not sentinel or conn_timeout is not None) and (timeout is not None):
+            raise ValueError("Cannot specify a RequestTimeouts config via the `timeout` parameter "
+                             "together with the legacy parameters `read_timeout` or `conn_timeout`. "
+                             "Please merge the timeout values into the timeout config object.")
+        elif timeout:
+            self._timeout = timeout
+        else:
+            self._timeout = attr.evolve(
+                DEFAULT_TIMEOUTS, connection_acquiring_timeout=conn_timeout,
+                uber_timeout=(read_timeout if read_timeout is not sentinel
+                              else DEFAULT_TIMEOUTS.uber_timeout))
+
         # Convert to list of tuples
         if headers:
             headers = CIMultiDict(headers)
@@ -191,7 +200,7 @@ async def _request(self, method, url, *,
                        read_until_eof=True,
                        proxy=None,
                        proxy_auth=None,
-                       timeout=sentinel,
+                       timeout=sentinel,  # sentinel -> inherit from session, None -> disable
                        verify_ssl=None,
                        fingerprint=None,
                        ssl_context=None,
@@ -199,10 +208,6 @@ async def _request(self, method, url, *,
                        proxy_headers=None,
                        trace_request_ctx=None):
 
-        # NOTE: timeout clamps existing connect and read timeouts. We cannot
-        # set the default to None because we need to detect if the user wants
-        # to use the existing timeouts by setting timeout to None.
-
         if self.closed:
             raise RuntimeError('Session is closed')
 
@@ -242,33 +247,13 @@ async def _request(self, method, url, *,
         except ValueError:
             raise InvalidURL(proxy)
 
-        # timeout is cumulative for all request operations
-        # (request, redirects, responses, data consuming)
-        tm = TimeoutHandle(
-            self._loop,
-            timeout if timeout is not sentinel else self._read_timeout)
-        handle = tm.start()
-
-        traces = [
-            Trace(
-                self,
-                trace_config,
-                trace_config.trace_config_ctx(
-                    trace_request_ctx=trace_request_ctx)
-            )
-            for trace_config in self._trace_configs
-        ]
-
-        for trace in traces:
-            await trace.send_request_start(
-                method,
-                url,
-                headers
-            )
-
-        timer = tm.timer()
+        if timeout is sentinel:
+            timeout = self._timeout
+        elif isinstance(timeout, (int, float)):
+            timeout = attr.evolve(self._timeout, uber_timeout=timeout)
+        lifecycle = RequestLifecycle(self, self._loop, self._trace_configs,
+                                     trace_request_ctx, timeout)
         try:
-            with timer:
+            with lifecycle.request_timer_context:
                 while True:
                     url, auth_from_url = strip_auth_from_url(url)
                     if auth and auth_from_url:
@@ -307,17 +292,16 @@ async def _request(self, method, url, *,
                         compress=compress, chunked=chunked,
                         expect100=expect100, loop=self._loop,
                         response_class=self._response_class,
-                        proxy=proxy, proxy_auth=proxy_auth, timer=timer,
+                        proxy=proxy, proxy_auth=proxy_auth, lifecycle=lifecycle,
                         session=self, auto_decompress=self._auto_decompress,
-                        ssl=ssl, proxy_headers=proxy_headers, traces=traces)
+                        ssl=ssl, proxy_headers=proxy_headers)
 
                     # connection timeout
                     try:
-                        with CeilTimeout(self._conn_timeout, loop=self._loop):
-                            conn = await self._connector.connect(
-                                req,
-                                traces=traces
-                            )
+                        conn = await self._connector.connect(
+                            req,
+                            lifecycle=lifecycle
+                        )
                     except asyncio.TimeoutError as exc:
                         raise ServerTimeoutError(
                             'Connection timeout '
@@ -347,13 +331,12 @@ async def _request(self, method, url, *,
 
                     if resp.status in (
                             301, 302, 303, 307, 308) and allow_redirects:
-                        for trace in traces:
-                            await trace.send_request_redirect(
-                                method,
-                                url,
-                                headers,
-                                resp
-                            )
+                        await lifecycle.send_request_redirect(
+                            method,
+                            url,
+                            headers,
+                            resp
+                        )
 
                         redirects += 1
                         history.append(resp)
@@ -411,37 +394,30 @@ async def _request(self, method, url, *,
                 resp.raise_for_status()
 
             # register connection
-            if handle is not None:
-                if resp.connection is not None:
-                    resp.connection.add_callback(handle.cancel)
-                else:
-                    handle.cancel()
+            # XXX this will only be required if there are still valid
+            # open timeouts after request_end
+            if resp.connection is not None:
+                resp.connection.add_callback(lifecycle.clear_timeouts)
+            else:
+                lifecycle.clear_timeouts()
 
             resp._history = tuple(history)
 
-            for trace in traces:
-                await trace.send_request_end(
-                    method,
-                    url,
-                    headers,
-                    resp
-                )
+            await lifecycle.send_request_end(
+                method,
+                url,
+                headers,
+                resp
+            )
             return resp
 
         except BaseException as e:
-            # cleanup timer
-            tm.close()
-            if handle:
-                handle.cancel()
-                handle = None
-
-            for trace in traces:
-                await trace.send_request_exception(
-                    method,
-                    url,
-                    headers,
-                    e
-                )
+            await lifecycle.send_request_exception(
+                method,
+                url,
+                headers,
+                e
+            )
+            lifecycle.clear_timeouts()
            raise
 
     def ws_connect(self, url, *,
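[Note] The caller-side effect of the client.py changes above, as a minimal sketch. It assumes the draft API in this patch lands as shown; `RequestTimeouts` and its field names come from aiohttp/lifecycle.py further down, and `attr` is the attrs package:

    import asyncio

    import aiohttp
    import attr
    from aiohttp.lifecycle import RequestTimeouts

    async def main():
        # Replace the session-wide defaults wholesale with a timeout config ...
        timeouts = RequestTimeouts(uber_timeout=10 * 60,
                                   connection_create_timeout=30)
        async with aiohttp.ClientSession(timeout=timeouts) as session:
            # ... requests inherit it by default; a bare number still means
            # "overall timeout", and attr.evolve derives a one-off config.
            one_off = attr.evolve(timeouts, uber_timeout=60)
            async with session.get('http://example.com', timeout=one_off) as resp:
                await resp.text()

    asyncio.run(main())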
diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py
index 0cf5e49502a..bf48bed21d8 100644
--- a/aiohttp/client_reqrep.py
+++ b/aiohttp/client_reqrep.py
@@ -175,10 +175,8 @@ def __init__(self, method, url, *,
                  chunked=None, expect100=False,
                  loop=None, response_class=None,
                  proxy=None, proxy_auth=None,
-                 timer=None, session=None, auto_decompress=True,
-                 ssl=None,
-                 proxy_headers=None,
-                 traces=None):
+                 lifecycle=None, session=None, auto_decompress=True,
+                 ssl=None, proxy_headers=None):
 
         if loop is None:
             loop = asyncio.get_event_loop()
@@ -199,7 +197,7 @@ def __init__(self, method, url, *,
         self.loop = loop
         self.length = None
         self.response_class = response_class or ClientResponse
-        self._timer = timer if timer is not None else TimerNoop()
+        self._lifecycle = lifecycle
         self._auto_decompress = auto_decompress
         self._ssl = ssl
 
@@ -219,9 +217,6 @@ def __init__(self, method, url, *,
         if data or self.method not in self.GET_METHODS:
             self.update_transfer_encoding()
         self.update_expect_continue(expect100)
-        if traces is None:
-            traces = []
-        self._traces = traces
 
     def is_ssl(self):
         return self.url.scheme in ('https', 'wss')
@@ -527,10 +522,10 @@ async def send(self, conn):
 
         self.response = self.response_class(
             self.method, self.original_url,
-            writer=self._writer, continue100=self._continue, timer=self._timer,
+            writer=self._writer, continue100=self._continue,
             request_info=self.request_info,
             auto_decompress=self._auto_decompress,
-            traces=self._traces,
+            lifecycle=self._lifecycle,
             loop=self.loop,
             session=self._session
         )
@@ -550,8 +545,8 @@ def terminate(self):
             self._writer = None
 
     async def _on_chunk_request_sent(self, chunk):
-        for trace in self._traces:
-            await trace.send_request_chunk_sent(chunk)
+        if self._lifecycle:
+            await self._lifecycle.send_request_chunk_sent(chunk)
 
 
 class ClientResponse(HeadersMixin):
@@ -573,9 +568,9 @@ class ClientResponse(HeadersMixin):
     _closed = True  # to allow __del__ for non-initialized properly response
 
     def __init__(self, method, url, *,
-                 writer, continue100, timer,
+                 writer, continue100, lifecycle,
                  request_info, auto_decompress,
-                 traces, loop, session):
+                 loop, session):
         assert isinstance(url, URL)
 
         self.method = method
@@ -589,10 +584,9 @@ def __init__(self, method, url, *,
         self._closed = True
         self._history = ()
         self._request_info = request_info
-        self._timer = timer if timer is not None else TimerNoop()
+        self._lifecycle = lifecycle
         self._auto_decompress = auto_decompress  # True by default
         self._cache = {}  # required for @reify method decorator
-        self._traces = traces
         self._loop = loop
         self._session = session  # store a reference to session #1985
         if loop.get_debug():
@@ -683,12 +677,12 @@ async def start(self, connection, read_until_eof=False):
         self._connection = connection
 
         connection.protocol.set_response_params(
-            timer=self._timer,
+            timer=self._lifecycle.request_timer_context,
             skip_payload=self.method.lower() == 'head',
             read_until_eof=read_until_eof,
             auto_decompress=self._auto_decompress)
 
-        with self._timer:
+        with self._lifecycle.request_timer_context:
             while True:  # read response
                 try:
@@ -813,8 +807,8 @@ async def read(self):
         if self._body is None:
             try:
                 self._body = await self.content.read()
-                for trace in self._traces:
-                    await trace.send_response_chunk_received(self._body)
+                if self._lifecycle:
+                    await self._lifecycle.send_response_chunk_received(self._body)
             except BaseException:
                 self.close()
                 raise
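[Note] The chunk hooks above still surface through the existing aiohttp.TraceConfig API; the lifecycle object only centralizes the dispatch. A sketch of a listener for the `on_response_chunk_received` signal, which already exists in aiohttp/tracing.py:

    import aiohttp

    async def on_response_chunk_received(session, trace_ctx, params):
        # params is a TraceResponseChunkReceivedParams; params.chunk holds the bytes
        print('received %d bytes' % len(params.chunk))

    trace_config = aiohttp.TraceConfig()
    trace_config.on_response_chunk_received.append(on_response_chunk_received)
    session = aiohttp.ClientSession(trace_configs=[trace_config])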
diff --git a/aiohttp/connector.py b/aiohttp/connector.py
index 7c122fead8c..71b34d1da78 100644
--- a/aiohttp/connector.py
+++ b/aiohttp/connector.py
@@ -357,7 +357,7 @@ def closed(self):
         """
         return self._closed
 
-    async def connect(self, req, traces=None):
+    async def connect(self, req, lifecycle=None):
         """Get from pool or create new connection."""
         key = req.connection_key
 
@@ -386,9 +386,8 @@ async def connect(self, req, traces=None):
             waiters = self._waiters[key]
             waiters.append(fut)
 
-            if traces:
-                for trace in traces:
-                    await trace.send_connection_queued_start()
+            if lifecycle:
+                await lifecycle.send_connection_queued_start()
 
             try:
                 await fut
@@ -398,9 +397,8 @@ async def connect(self, req, traces=None):
                 if not waiters:
                     del self._waiters[key]
 
-            if traces:
-                for trace in traces:
-                    await trace.send_connection_queued_end()
+            if lifecycle:
+                await lifecycle.send_connection_queued_end()
 
         proto = self._get(key)
         if proto is None:
@@ -408,14 +406,13 @@ async def connect(self, req, traces=None):
             self._acquired.add(placeholder)
             self._acquired_per_host[key].add(placeholder)
 
-            if traces:
-                for trace in traces:
-                    await trace.send_connection_create_start()
+            if lifecycle:
+                await lifecycle.send_connection_create_start()
 
             try:
                 proto = await self._create_connection(
                     req,
-                    traces=traces
+                    lifecycle=lifecycle
                 )
                 if self._closed:
                     proto.close()
@@ -433,13 +430,11 @@ async def connect(self, req, traces=None):
                 self._acquired.remove(placeholder)
                 self._drop_acquired_per_host(key, placeholder)
 
-            if traces:
-                for trace in traces:
-                    await trace.send_connection_create_end()
+            if lifecycle:
+                await lifecycle.send_connection_create_end()
         else:
-            if traces:
-                for trace in traces:
-                    await trace.send_connection_reuseconn()
+            if lifecycle:
+                await lifecycle.send_connection_reuseconn()
 
         self._acquired.add(proto)
         self._acquired_per_host[key].add(proto)
@@ -531,7 +526,7 @@ def _release(self, key, protocol, *, should_close=False):
             self._cleanup_handle = helpers.weakref_handle(
                 self, '_cleanup', self._keepalive_timeout, self._loop)
 
-    async def _create_connection(self, req, traces=None):
+    async def _create_connection(self, req, lifecycle=None):
         raise NotImplementedError()
 
@@ -651,23 +646,21 @@ def clear_dns_cache(self, host=None, port=None):
         else:
             self._cached_hosts.clear()
 
-    async def _resolve_host(self, host, port, traces=None):
+    async def _resolve_host(self, host, port, lifecycle=None):
         if is_ip_address(host):
             return [{'hostname': host, 'host': host, 'port': port,
                      'family': self._family, 'proto': 0, 'flags': 0}]
 
         if not self._use_dns_cache:
-            if traces:
-                for trace in traces:
-                    await trace.send_dns_resolvehost_start(host)
+            if lifecycle:
+                await lifecycle.send_dns_resolvehost_start(host)
 
             res = (await self._resolver.resolve(
                 host, port, family=self._family))
 
-            if traces:
-                for trace in traces:
-                    await trace.send_dns_resolvehost_end(host)
+            if lifecycle:
+                await lifecycle.send_dns_resolvehost_end(host)
 
             return res
 
@@ -676,37 +669,32 @@ async def _resolve_host(self, host, port, traces=None):
 
         if (key in self._cached_hosts) and \
                 (not self._cached_hosts.expired(key)):
-            if traces:
-                for trace in traces:
-                    await trace.send_dns_cache_hit(host)
+            if lifecycle:
+                await lifecycle.send_dns_cache_hit(host)
             return self._cached_hosts.next_addrs(key)
 
         if key in self._throttle_dns_events:
-            if traces:
-                for trace in traces:
-                    await trace.send_dns_cache_hit(host)
+            if lifecycle:
+                await lifecycle.send_dns_cache_hit(host)
             await self._throttle_dns_events[key].wait()
         else:
-            if traces:
-                for trace in traces:
-                    await trace.send_dns_cache_miss(host)
+            if lifecycle:
+                await lifecycle.send_dns_cache_miss(host)
             self._throttle_dns_events[key] = \
                 EventResultOrError(self._loop)
             try:
-                if traces:
-                    for trace in traces:
-                        await trace.send_dns_resolvehost_start(host)
+                if lifecycle:
+                    await lifecycle.send_dns_resolvehost_start(host)
 
                 addrs = await \
                     asyncio.shield(self._resolver.resolve(host,
                                                           port,
                                                           family=self._family),
                                    loop=self._loop)
-                if traces:
-                    for trace in traces:
-                        await trace.send_dns_resolvehost_end(host)
+                if lifecycle:
+                    await lifecycle.send_dns_resolvehost_end(host)
 
                 self._cached_hosts.add(key, addrs)
                 self._throttle_dns_events[key].set()
@@ -720,7 +708,7 @@ async def _resolve_host(self, host, port, traces=None):
 
         return self._cached_hosts.next_addrs(key)
 
-    async def _create_connection(self, req, traces=None):
+    async def _create_connection(self, req, lifecycle=None):
         """Create connection.
 
         Has same keyword arguments as BaseEventLoop.create_connection.
@@ -728,12 +716,12 @@ async def _create_connection(self, req, traces=None):
         """
         if req.proxy:
             _, proto = await self._create_proxy_connection(
                 req,
-                traces=None
+                lifecycle=lifecycle
             )
         else:
             _, proto = await self._create_direct_connection(
                 req,
-                traces=None
+                lifecycle=lifecycle
             )
 
         return proto
@@ -808,7 +796,7 @@ async def _wrap_create_connection(self, *args,
 
     async def _create_direct_connection(self, req,
                                         *, client_error=ClientConnectorError,
-                                        traces=None):
+                                        lifecycle=None):
         sslcontext = self._get_ssl_context(req)
         fingerprint = self._get_fingerprint(req)
 
         try:
             hosts = await self._resolve_host(
                 req.url.raw_host,
                 req.port,
-                traces=traces)
+                lifecycle=lifecycle)
         except OSError as exc:
             # in case of proxy it is not ClientProxyConnectionError
             # it is problem of resolving proxy ip itself
@@ -854,7 +842,7 @@ async def _create_direct_connection(self, req,
         else:
             raise last_exc
 
-    async def _create_proxy_connection(self, req, traces=None):
+    async def _create_proxy_connection(self, req, lifecycle=None):  # FIXME no callbacks to lifecycle
         headers = {}
         if req.proxy_headers is not None:
             headers = req.proxy_headers
@@ -954,7 +942,7 @@ def path(self):
         """Path to unix socket."""
         return self._path
 
-    async def _create_connection(self, req, traces=None):
+    async def _create_connection(self, req, lifecycle=None):
         try:
             _, proto = await self._loop.create_unix_connection(
                 self._factory, self._path)
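[Note] The connector signals above pair up into phases, which is exactly what the timeout metadata in aiohttp/lifecycle.py (next file) keys on. A sketch of a config (values in seconds) where each field bounds one start/end pair independently:

    from aiohttp.lifecycle import RequestTimeouts

    timeouts = RequestTimeouts(
        pool_queue_timeout=5.0,          # connection_queued_start .. connection_queued_end
        dns_resolution_timeout=1.0,      # dns_resolvehost_start .. dns_resolvehost_end
        connection_create_timeout=10.0,  # connection_create_start .. connection_create_end
    )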
diff --git a/aiohttp/lifecycle.py b/aiohttp/lifecycle.py
new file mode 100644
index 00000000000..ec28e05c610
--- /dev/null
+++ b/aiohttp/lifecycle.py
@@ -0,0 +1,109 @@
+import asyncio
+import time
+import warnings
+from collections import defaultdict
+from math import ceil
+
+import attr
+
+from aiohttp.helpers import TimerContext, TimerNoop
+from aiohttp.tracing import SIGNALS
+
+
+def _tm(start, end):
+    if start not in SIGNALS:
+        raise ValueError("Invalid timeout start signal %s" % start)
+    if end not in SIGNALS:
+        raise ValueError("Invalid timeout end signal %s" % end)
+    return {"start": start, "end": end}
+
+
+@attr.s(frozen=True, slots=True)
+class RequestTimeouts:
+    # XXX what happens between "connection_create_start" and "connection_create_end"
+    # and whether we'll get callbacks is implementation-dependent
+    read_timeout = attr.ib(type=float, default=None)  # TODO
+    connection_create_timeout = attr.ib(type=float, default=None, metadata=_tm("connection_create_start", "connection_create_end"))
+
+    uber_timeout = attr.ib(type=float, default=None, metadata=_tm("request_start", "request_end"))
+    pool_queue_timeout = attr.ib(type=float, default=None, metadata=_tm("connection_queued_start", "connection_queued_end"))
+    dns_resolution_timeout = attr.ib(type=float, default=None, metadata=_tm("dns_resolvehost_start", "dns_resolvehost_end"))
+    socket_connect_timeout = attr.ib(type=float, default=None)  # TODO
+    connection_acquiring_timeout = attr.ib(type=float, default=None)  # TODO
+    new_connection_timeout = attr.ib(type=float, default=None)  # TODO
+    http_header_timeout = attr.ib(type=float, default=None)  # TODO metadata=_tm("request_sent", "response_headers_received")
+    response_body_timeout = attr.ib(type=float, default=None)  # TODO metadata=_tm("response_headers_received", "request_end")
+
+    # To create a timeout config specific to a single request, either
+    # - create a completely new instance to replace the default, or
+    # - use http://www.attrs.org/en/stable/api.html#attr.evolve to overwrite individual defaults
+    # (maybe this should be done through either session.extend_timeout or
+    # session.replace_timeout without directly calling RequestTimeouts)
+
+
+class RequestLifecycle:
+    """Internal class that keeps together the main dependencies used
+    at the moment of sending a signal."""
+
+    def __init__(self, session, loop, trace_configs, trace_request_ctx, timeout_config):
+        self._session = session
+        self._loop = loop
+        self._trace_configs = [
+            (config, config.trace_config_ctx(trace_request_ctx=trace_request_ctx))
+            for config in trace_configs
+        ]
+        self._timeout_config = timeout_config
+
+        self._signal_timestamps = {}
+        # handles of currently running timeouts, keyed by their end signal
+        self._set_timeouts = defaultdict(list)
+        # timeouts configured by the user, keyed by their start signal
+        self._active_timeouts = defaultdict(list)
+
+        # filter timeouts that were actually set by the user
+        if timeout_config:
+            for timeout_field in attr.fields(RequestTimeouts):
+                if timeout_field.metadata and getattr(timeout_config, timeout_field.name):
+                    self._active_timeouts[timeout_field.metadata["start"]].append(
+                        (timeout_field.metadata["end"], getattr(timeout_config, timeout_field.name))
+                    )
+
+        # create a timeout context depending on whether any timeouts were actually set
+        if self._active_timeouts:
+            self.request_timer_context = TimerContext(self._loop)
+        else:
+            self.request_timer_context = TimerNoop()
+
+        # generate send_<signal> methods for all signals, sending on_<signal> to all
+        # trace listeners and keeping track of timeouts
+        for signal, params_class in SIGNALS.items():
+            setattr(self, "send_" + signal, self._send(signal, params_class))
+
+    def _send(self, signal, params_class):
+        async def sender(*args, **kwargs):
+            # record timestamp
+            self._signal_timestamps[signal] = time.time()
+
+            # cancel all running timeouts that end with this signal
+            while self._set_timeouts[signal]:
+                timeout_handle = self._set_timeouts[signal].pop()
+                timeout_handle.cancel()
+
+            # send on_<signal> to all trace listeners
+            params = params_class(*args, **kwargs)
+            await asyncio.gather(*(
+                getattr(trace_config, "on_" + signal).send(self._session, trace_context, params)
+                for trace_config, trace_context in self._trace_configs
+            ))
+
+            # start all timeouts that begin with this signal and register their
+            # handles under the end signal that is supposed to cancel them
+            for end, timeout in self._active_timeouts[signal]:
+                assert isinstance(self.request_timer_context, TimerContext)
+                at = ceil(self._loop.time() + timeout)
+                handle = self._loop.call_at(at, self.request_timer_context.timeout)
+                self._set_timeouts[end].append(handle)
+
+        return sender
+
+    def clear_timeouts(self):
+        for signal, timeout_handles in self._set_timeouts.items():
+            while timeout_handles:
+                timeout_handle = timeout_handles.pop()
+                warnings.warn("Timeout handle %s wasn't cancelled by its end signal %s. "
+                              "There was something wrong with the lifecycle transitions."
+                              % (timeout_handle, signal))
+                timeout_handle.cancel()
diff --git a/aiohttp/tracing.py b/aiohttp/tracing.py
index 165e68cbf9d..92faeef64f4 100644
--- a/aiohttp/tracing.py
+++ b/aiohttp/tracing.py
@@ -335,3 +335,22 @@ async def send_dns_cache_miss(self, host):
             self._trace_config_ctx,
             TraceDnsCacheMissParams(host)
         )
+
+
+SIGNALS = {
+    "request_start": TraceRequestStartParams,  # (method, url, headers)
+    "request_chunk_sent": TraceRequestChunkSentParams,  # (chunk)
+    "response_chunk_received": TraceResponseChunkReceivedParams,  # (chunk)
+    "request_end": TraceRequestEndParams,  # (method, url, headers, response)
+    "request_exception": TraceRequestExceptionParams,  # (method, url, headers, exception)
+    "request_redirect": TraceRequestRedirectParams,  # (method, url, headers, response)
+    "connection_queued_start": TraceConnectionQueuedStartParams,  # ()
+    "connection_queued_end": TraceConnectionQueuedEndParams,  # ()
+    "connection_create_start": TraceConnectionCreateStartParams,  # ()
+    "connection_create_end": TraceConnectionCreateEndParams,  # ()
+    "connection_reuseconn": TraceConnectionReuseconnParams,  # ()
+    "dns_resolvehost_start": TraceDnsResolveHostStartParams,  # (host)
+    "dns_resolvehost_end": TraceDnsResolveHostEndParams,  # (host)
+    "dns_cache_hit": TraceDnsCacheHitParams,  # (host)
+    "dns_cache_miss": TraceDnsCacheMissParams,  # (host)
+}
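[Note] Taken together: RequestLifecycle turns every key in the SIGNALS table into a generated send_<signal> coroutine that records a timestamp, cancels the timeouts ending at that signal, notifies the trace listeners, and arms the timeouts starting at that signal. A self-contained sketch of that behaviour with this patch applied (no real request is made; passing session=None is safe here only because no trace listeners are registered):

    import asyncio

    from aiohttp.lifecycle import RequestLifecycle, RequestTimeouts

    async def demo():
        loop = asyncio.get_event_loop()
        lifecycle = RequestLifecycle(None, loop, [], None,
                                     RequestTimeouts(uber_timeout=60))
        # "request_start" arms uber_timeout; "request_end" cancels it again.
        await lifecycle.send_request_start('GET', 'http://example.com', {})
        await lifecycle.send_request_end('GET', 'http://example.com', {}, None)
        lifecycle.clear_timeouts()  # no warnings: everything was cancelled in order

    asyncio.run(demo())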