From 29eccad84e8200b5c90856c8732da0fdbbcef904 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sat, 26 Oct 2019 12:36:51 +0300 Subject: [PATCH] Backport contextvars support (#4271) --- CHANGES/3380.bugfix | 1 + CHANGES/3557.feature | 1 + aiohttp/_http_parser.pyx | 3 + aiohttp/http_parser.py | 6 ++ aiohttp/web_protocol.py | 120 ++++++++++++++++--------- tests/test_web_log.py | 37 ++++++++ tests/test_web_protocol.py | 12 +-- tests/test_web_server.py | 2 +- tests/test_web_websocket_functional.py | 26 ++++++ 9 files changed, 157 insertions(+), 51 deletions(-) create mode 100644 CHANGES/3380.bugfix create mode 100644 CHANGES/3557.feature diff --git a/CHANGES/3380.bugfix b/CHANGES/3380.bugfix new file mode 100644 index 00000000000..4c66ff0394b --- /dev/null +++ b/CHANGES/3380.bugfix @@ -0,0 +1 @@ +Fix failed websocket handshake leaving connection hanging. diff --git a/CHANGES/3557.feature b/CHANGES/3557.feature new file mode 100644 index 00000000000..9d2b10be0f7 --- /dev/null +++ b/CHANGES/3557.feature @@ -0,0 +1 @@ +Call ``AccessLogger.log`` with the current exception available from sys.exc_info(). diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index b0ee4a18d38..1160c4120f6 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -533,6 +533,9 @@ cdef class HttpParser: else: return messages, False, b'' + def set_upgraded(self, val): + self._upgraded = val + cdef class HttpRequestParser(HttpParser): diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index 9e22d10263a..f12f0796971 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -411,6 +411,12 @@ def parse_headers( return (headers, raw_headers, close_conn, encoding, upgrade, chunked) + def set_upgraded(self, val: bool) -> None: + """Set connection upgraded (to websocket) mode. + :param bool val: new state. + """ + self._upgraded = val + class HttpRequestParser(HttpParser): """Read request status line. Exception .http_exceptions.BadStatusLine diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index a8b49b4c310..8796739644f 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -13,6 +13,7 @@ Awaitable, Callable, Optional, + Tuple, Type, cast, ) @@ -371,6 +372,33 @@ def _process_keepalive(self) -> None: self._keepalive_handle = self._loop.call_later( self.KEEPALIVE_RESCHEDULE_DELAY, self._process_keepalive) + async def _handle_request(self, + request: BaseRequest, + start_time: float, + ) -> Tuple[StreamResponse, bool]: + assert self._request_handler is not None + try: + resp = await self._request_handler(request) + except HTTPException as exc: + resp = Response(status=exc.status, + reason=exc.reason, + text=exc.text, + headers=exc.headers) + reset = await self.finish_response(request, resp, start_time) + except asyncio.CancelledError: + raise + except asyncio.TimeoutError as exc: + self.log_debug('Request handler timed out.', exc_info=exc) + resp = self.handle_error(request, 504) + reset = await self.finish_response(request, resp, start_time) + except Exception as exc: + resp = self.handle_error(request, 500, exc) + reset = await self.finish_response(request, resp, start_time) + else: + reset = await self.finish_response(request, resp, start_time) + + return resp, reset + async def start(self) -> None: """Process incoming request. @@ -403,8 +431,7 @@ async def start(self) -> None: message, payload = self._messages.popleft() - if self.access_log: - now = loop.time() + start = loop.time() manager.requests_count += 1 writer = StreamWriter(self, loop) @@ -413,54 +440,23 @@ async def start(self) -> None: try: # a new task is used for copy context vars (#3406) task = self._loop.create_task( - self._request_handler(request)) + self._handle_request(request, start)) try: - resp = await task - except HTTPException as exc: - resp = exc + resp, reset = await task except (asyncio.CancelledError, ConnectionError): self.log_debug('Ignored premature client disconnection') break - except asyncio.TimeoutError as exc: - self.log_debug('Request handler timed out.', exc_info=exc) - resp = self.handle_error(request, 504) - except Exception as exc: - resp = self.handle_error(request, 500, exc) - else: - # Deprecation warning (See #2415) - if getattr(resp, '__http_exception__', False): - warnings.warn( - "returning HTTPException object is deprecated " - "(#2415) and will be removed, " - "please raise the exception instead", - DeprecationWarning) + # Deprecation warning (See #2415) + if getattr(resp, '__http_exception__', False): + warnings.warn( + "returning HTTPException object is deprecated " + "(#2415) and will be removed, " + "please raise the exception instead", + DeprecationWarning) # Drop the processed task from asyncio.Task.all_tasks() early del task - - if self.debug: - if not isinstance(resp, StreamResponse): - if resp is None: - raise RuntimeError("Missing return " - "statement on request handler") - else: - raise RuntimeError("Web-handler should return " - "a response instance, " - "got {!r}".format(resp)) - try: - prepare_meth = resp.prepare - except AttributeError: - if resp is None: - raise RuntimeError("Missing return " - "statement on request handler") - else: - raise RuntimeError("Web-handler should return " - "a response instance, " - "got {!r}".format(resp)) - try: - await prepare_meth(request) - await resp.write_eof() - except ConnectionError: + if reset: self.log_debug('Ignored premature client disconnection 2') break @@ -469,7 +465,7 @@ async def start(self) -> None: # log access if self.access_log: - self.log_access(request, resp, loop.time() - now) + self.log_access(request, resp, loop.time() - start) # check payload if not payload.is_eof(): @@ -530,6 +526,42 @@ async def start(self) -> None: if self.transport is not None and self._error_handler is None: self.transport.close() + async def finish_response(self, + request: BaseRequest, + resp: StreamResponse, + start_time: float) -> bool: + """ + Prepare the response and write_eof, then log access. This has to + be called within the context of any exception so the access logger + can get exception information. Returns True if the client disconnects + prematurely. + """ + if self._request_parser is not None: + self._request_parser.set_upgraded(False) + self._upgrade = False + if self._message_tail: + self._request_parser.feed_data(self._message_tail) + self._message_tail = b'' + try: + prepare_meth = resp.prepare + except AttributeError: + if resp is None: + raise RuntimeError("Missing return " + "statement on request handler") + else: + raise RuntimeError("Web-handler should return " + "a response instance, " + "got {!r}".format(resp)) + try: + await prepare_meth(request) + await resp.write_eof() + except ConnectionError: + self.log_access(request, resp, start_time) + return True + else: + self.log_access(request, resp, start_time) + return False + def handle_error(self, request: BaseRequest, status: int=500, diff --git a/tests/test_web_log.py b/tests/test_web_log.py index 0f62bc8bd0d..15236cf6b41 100644 --- a/tests/test_web_log.py +++ b/tests/test_web_log.py @@ -5,9 +5,17 @@ import pytest import aiohttp +from aiohttp import web from aiohttp.abc import AbstractAccessLogger +from aiohttp.helpers import PY_37 from aiohttp.web_log import AccessLogger +try: + from contextvars import ContextVar +except ImportError: + ContextVar = None + + IS_PYPY = platform.python_implementation() == 'PyPy' @@ -157,3 +165,32 @@ def log(self, request, response, time): access_logger = Logger(mock_logger, '{request} {response} {time}') access_logger.log('request', 'response', 1) mock_logger.info.assert_called_with('request response 1') + + +@pytest.mark.skipif(not PY_37, + reason="contextvars support is required") +async def test_contextvars_logger(aiohttp_server, aiohttp_client): + VAR = ContextVar('VAR') + + async def handler(request): + return web.Response() + + @web.middleware + async def middleware(request, handler): + VAR.set("uuid") + return await handler(request) + + msg = None + + class Logger(AbstractAccessLogger): + def log(self, request, response, time): + nonlocal msg + msg = 'contextvars: {}'.format(VAR.get()) + + app = web.Application(middlewares=[middleware]) + app.router.add_get('/', handler) + server = await aiohttp_server(app, access_log_class=Logger) + client = await aiohttp_client(server) + resp = await client.get('/') + assert 200 == resp.status + assert msg == 'contextvars: uuid' diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py index 6a97f056436..d343a860b2d 100644 --- a/tests/test_web_protocol.py +++ b/tests/test_web_protocol.py @@ -373,7 +373,7 @@ async def handle(request): b'GET / HTTP/1.1\r\n' b'Host: example.com\r\n' b'Content-Length: 0\r\n\r\n') - await asyncio.sleep(0) + await asyncio.sleep(0.01) # with exception request_handler.side_effect = handle_with_error() @@ -384,7 +384,7 @@ async def handle(request): assert srv._task_handler - await asyncio.sleep(0) + await asyncio.sleep(0.01) await srv._task_handler assert normal_completed @@ -600,7 +600,7 @@ async def test_content_length_0(srv, request_handler) -> None: b'GET / HTTP/1.1\r\n' b'Host: example.org\r\n' b'Content-Length: 0\r\n\r\n') - await asyncio.sleep(0) + await asyncio.sleep(0.01) assert request_handler.called assert request_handler.call_args[0][0].content == streams.EMPTY_PAYLOAD @@ -722,7 +722,7 @@ async def handle1(request): b'GET / HTTP/1.1\r\n' b'Host: example.com\r\n' b'Content-Length: 0\r\n\r\n') - await asyncio.sleep(0) + await asyncio.sleep(0.01) # second @@ -740,7 +740,7 @@ async def handle2(request): b'GET / HTTP/1.1\r\n' b'Host: example.com\r\n' b'Content-Length: 0\r\n\r\n') - await asyncio.sleep(0) + await asyncio.sleep(0.01) assert srv._task_handler is not None @@ -855,4 +855,4 @@ async def handler(request): writer.write(b"x") writer.close() await asyncio.sleep(0.1) - logger.debug.assert_called_with('Ignored premature client disconnection 2') + logger.debug.assert_called_with('Ignored premature client disconnection.') diff --git a/tests/test_web_server.py b/tests/test_web_server.py index 57ffa110557..eabc313db0c 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -89,7 +89,7 @@ async def handler(request): with pytest.raises(client.ClientPayloadError): await resp.read() - logger.debug.assert_called_with('Ignored premature client disconnection ') + logger.debug.assert_called_with('Ignored premature client disconnection') async def test_raw_server_not_http_exception_debug(aiohttp_raw_server, diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 710e3ffb19f..7ad984045d0 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -788,3 +788,29 @@ async def handler(request): ws = await client.ws_connect('/') data = await ws.receive_str() assert data == 'OK' + + +async def test_bug3380(loop, aiohttp_client) -> None: + + async def handle_null(request): + return aiohttp.web.json_response({'err': None}) + + async def ws_handler(request): + return web.Response(status=401) + + app = web.Application() + app.router.add_route('GET', '/ws', ws_handler) + app.router.add_route('GET', '/api/null', handle_null) + + client = await aiohttp_client(app) + + resp = await client.get('/api/null') + assert (await resp.json()) == {'err': None} + resp.close() + + with pytest.raises(aiohttp.WSServerHandshakeError): + await client.ws_connect('/ws') + + resp = await client.get('/api/null', timeout=1) + assert (await resp.json()) == {'err': None} + resp.close()