From 233a7f99ee583be5cdb91931d9a25391b533a9e5 Mon Sep 17 00:00:00 2001 From: Gustavo Carneiro Date: Wed, 16 Oct 2019 16:33:08 +0100 Subject: [PATCH 1/5] Add a Request.wait_for_disconnection() method #2492 as means of allowing request handlers to be notified of premature client disconnections. --- CHANGES/2492.feature | 1 + aiohttp/web_request.py | 24 +++++++++++++++++++++++- tests/test_web_protocol.py | 16 ++++++++++++++-- 3 files changed, 38 insertions(+), 3 deletions(-) create mode 100644 CHANGES/2492.feature diff --git a/CHANGES/2492.feature b/CHANGES/2492.feature new file mode 100644 index 00000000000..5c98dbbbcf2 --- /dev/null +++ b/CHANGES/2492.feature @@ -0,0 +1 @@ +Add a Request.wait_for_disconnection() method, as means of allowing request handlers to be notified of premature client disconnections. diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index ebc16d8ef96..c460b0eedf0 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -17,6 +17,7 @@ Mapping, MutableMapping, Optional, + Set, Tuple, Union, cast, @@ -35,6 +36,7 @@ is_expected_content_type, reify, sentinel, + set_result, ) from .http_parser import RawRequestMessage from .multipart import BodyPartReader, MultipartReader @@ -110,7 +112,9 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin): '_message', '_protocol', '_payload_writer', '_payload', '_headers', '_method', '_version', '_rel_url', '_post', '_read_bytes', '_state', '_cache', '_task', '_client_max_size', '_loop', - '_transport_sslcontext', '_transport_peername') + '_transport_sslcontext', '_transport_peername', + '_disconnection_waiters', + ) def __init__(self, message: RawRequestMessage, payload: StreamReader, protocol: 'RequestHandler', @@ -142,6 +146,7 @@ def __init__(self, message: RawRequestMessage, self._task = task self._client_max_size = client_max_size self._loop = loop + self._disconnection_waiters = set() # type: Set[asyncio.Future[None]] transport = self._protocol.transport assert transport is not None @@ -690,6 +695,23 @@ async def _prepare_hook(self, response: StreamResponse) -> None: def _cancel(self, exc: BaseException) -> None: self._payload.set_exception(exc) + for fut in self._disconnection_waiters: + set_result(fut, None) + + async def wait_for_disconnection(self) -> None: + """Returns when the connection that sent this request closes + + This can be used in handlers as a means to receive a notification of + premature client disconnection. + + .. versionadded:: 4.0 + """ + fut = asyncio.Future() # type: asyncio.Future[None] + self._disconnection_waiters.add(fut) + try: + await fut + finally: + self._disconnection_waiters.remove(fut) class Request(BaseRequest): diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py index 7f99ca2a7e9..12f310caa0d 100644 --- a/tests/test_web_protocol.py +++ b/tests/test_web_protocol.py @@ -839,17 +839,28 @@ async def test_two_data_received_without_waking_up_start_task(srv) -> None: async def test_client_disconnect(aiohttp_server) -> None: + loop = asyncio.get_event_loop() + loop.set_debug(True) + disconnected_notified = False async def handler(request): + + async def disconn(): + nonlocal disconnected_notified + await request.wait_for_disconnection() + disconnected_notified = True + + disconn_task = loop.create_task(disconn()) + buf = b"" with pytest.raises(ConnectionError): while len(buf) < 10: buf += await request.content.read(10) # return with closed transport means premature client disconnection + await asyncio.sleep(0) + disconn_task.cancel() return web.Response() - loop = asyncio.get_event_loop() - loop.set_debug(True) logger = mock.Mock() app = web.Application() app.router.add_route('POST', '/', handler) @@ -871,3 +882,4 @@ async def handler(request): writer.close() await asyncio.sleep(0.1) logger.debug.assert_called_with('Ignored premature client disconnection.') + assert disconnected_notified From 266bacc0dc96bc5aac1ba99b0af6facbd920d992 Mon Sep 17 00:00:00 2001 From: Gustavo Carneiro Date: Thu, 24 Oct 2019 17:47:02 +0100 Subject: [PATCH 2/5] request.wait_for_disconnection(): cancel upon request handling end --- aiohttp/web_protocol.py | 1 + aiohttp/web_request.py | 4 ++++ tests/test_web_protocol.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 33 insertions(+) diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index dd89dd83357..3eb3e4f567b 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -560,6 +560,7 @@ async def finish_response(self, can get exception information. Returns True if the client disconnects prematurely. """ + request._finish() if self._request_parser is not None: self._request_parser.set_upgraded(False) self._upgrade = False diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index c460b0eedf0..6b4d4c26d00 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -698,6 +698,10 @@ def _cancel(self, exc: BaseException) -> None: for fut in self._disconnection_waiters: set_result(fut, None) + def _finish(self) -> None: + for fut in self._disconnection_waiters: + fut.cancel() + async def wait_for_disconnection(self) -> None: """Returns when the connection that sent this request closes diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py index 12f310caa0d..991788b12d5 100644 --- a/tests/test_web_protocol.py +++ b/tests/test_web_protocol.py @@ -883,3 +883,31 @@ async def disconn(): await asyncio.sleep(0.1) logger.debug.assert_called_with('Ignored premature client disconnection.') assert disconnected_notified + + +async def test_wait_for_disconnection_cancel(srv, buf, monkeypatch) -> None: + # srv is aiohttp.web_protocol.RequestHandler + + waiter_tasks = [] + + async def request_waiter(request): + await request.wait_for_disconnection() + + orig_request_factory = srv._request_factory + + def request_factory(*args, **kwargs): + request = orig_request_factory(*args, **kwargs) + loop = asyncio.get_event_loop() + waiter_tasks.append(loop.create_task(request_waiter(request))) + return request + + monkeypatch.setattr(srv, "_request_factory", request_factory) + + srv.data_received( + b'GET / HTTP/1.1\r\n\r\n') + + await asyncio.sleep(0.05) + assert buf.startswith(b'HTTP/1.1 200 OK\r\n') + + assert len(waiter_tasks) == 1 + assert waiter_tasks[0].cancelled() From c2a59299d718240c7cc3855dc4762474208b1796 Mon Sep 17 00:00:00 2001 From: Gustavo Carneiro Date: Fri, 25 Oct 2019 11:29:26 +0100 Subject: [PATCH 3/5] Request.wait_for_disconnection(): move docs away from code --- aiohttp/web_request.py | 7 ------- docs/web_reference.rst | 8 ++++++++ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 6b4d4c26d00..ae5b2083c7f 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -703,13 +703,6 @@ def _finish(self) -> None: fut.cancel() async def wait_for_disconnection(self) -> None: - """Returns when the connection that sent this request closes - - This can be used in handlers as a means to receive a notification of - premature client disconnection. - - .. versionadded:: 4.0 - """ fut = asyncio.Future() # type: asyncio.Future[None] self._disconnection_waiters.add(fut) try: diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 5df154104d7..d32e67ec70c 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -442,6 +442,14 @@ and :ref:`aiohttp-web-signals` handlers. required work will be processed by :mod:`aiohttp.web` internal machinery. + .. comethod:: wait_for_disconnection() + + Returns when the connection that sent this request closes + + This can be used in handlers as a means of receiving a notification of + premature client disconnection. + + .. versionadded:: 4.0 .. class:: Request From aeabbac16b382fff5240ba4079ba013f39d4d747 Mon Sep 17 00:00:00 2001 From: Gustavo Carneiro Date: Fri, 25 Oct 2019 11:37:24 +0100 Subject: [PATCH 4/5] clarify cancellation in docs --- docs/web_reference.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/web_reference.rst b/docs/web_reference.rst index d32e67ec70c..bd6268827d5 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -446,6 +446,10 @@ and :ref:`aiohttp-web-signals` handlers. Returns when the connection that sent this request closes + If there is no client disconnection during request handling, this + coroutine gets cancelled automatically at the end of this request being + handled. + This can be used in handlers as a means of receiving a notification of premature client disconnection. From 831ce96185f8f126dbd9cc45726380fd60fa96f7 Mon Sep 17 00:00:00 2001 From: "Gustavo J. A. M. Carneiro" Date: Fri, 25 Oct 2019 13:22:46 +0100 Subject: [PATCH 5/5] Update aiohttp/web_request.py Co-Authored-By: Andrew Svetlov --- aiohttp/web_request.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index ae5b2083c7f..6f902cc7822 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -703,7 +703,8 @@ def _finish(self) -> None: fut.cancel() async def wait_for_disconnection(self) -> None: - fut = asyncio.Future() # type: asyncio.Future[None] + loop = asyncio.get_event_loop() + fut = loop.create_future() # type: asyncio.Future[None] self._disconnection_waiters.add(fut) try: await fut