diff --git a/CHANGES/2492.feature b/CHANGES/2492.feature new file mode 100644 index 00000000000..5c98dbbbcf2 --- /dev/null +++ b/CHANGES/2492.feature @@ -0,0 +1 @@ +Add a Request.wait_for_disconnection() method, as means of allowing request handlers to be notified of premature client disconnections. diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index dd89dd83357..3eb3e4f567b 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -560,6 +560,7 @@ async def finish_response(self, can get exception information. Returns True if the client disconnects prematurely. """ + request._finish() if self._request_parser is not None: self._request_parser.set_upgraded(False) self._upgrade = False diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index ebc16d8ef96..6f902cc7822 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -17,6 +17,7 @@ Mapping, MutableMapping, Optional, + Set, Tuple, Union, cast, @@ -35,6 +36,7 @@ is_expected_content_type, reify, sentinel, + set_result, ) from .http_parser import RawRequestMessage from .multipart import BodyPartReader, MultipartReader @@ -110,7 +112,9 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin): '_message', '_protocol', '_payload_writer', '_payload', '_headers', '_method', '_version', '_rel_url', '_post', '_read_bytes', '_state', '_cache', '_task', '_client_max_size', '_loop', - '_transport_sslcontext', '_transport_peername') + '_transport_sslcontext', '_transport_peername', + '_disconnection_waiters', + ) def __init__(self, message: RawRequestMessage, payload: StreamReader, protocol: 'RequestHandler', @@ -142,6 +146,7 @@ def __init__(self, message: RawRequestMessage, self._task = task self._client_max_size = client_max_size self._loop = loop + self._disconnection_waiters = set() # type: Set[asyncio.Future[None]] transport = self._protocol.transport assert transport is not None @@ -690,6 +695,21 @@ async def _prepare_hook(self, response: StreamResponse) -> None: def _cancel(self, exc: BaseException) -> None: self._payload.set_exception(exc) + for fut in self._disconnection_waiters: + set_result(fut, None) + + def _finish(self) -> None: + for fut in self._disconnection_waiters: + fut.cancel() + + async def wait_for_disconnection(self) -> None: + loop = asyncio.get_event_loop() + fut = loop.create_future() # type: asyncio.Future[None] + self._disconnection_waiters.add(fut) + try: + await fut + finally: + self._disconnection_waiters.remove(fut) class Request(BaseRequest): diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 5df154104d7..bd6268827d5 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -442,6 +442,18 @@ and :ref:`aiohttp-web-signals` handlers. required work will be processed by :mod:`aiohttp.web` internal machinery. + .. comethod:: wait_for_disconnection() + + Returns when the connection that sent this request closes + + If there is no client disconnection during request handling, this + coroutine gets cancelled automatically at the end of this request being + handled. + + This can be used in handlers as a means of receiving a notification of + premature client disconnection. + + .. versionadded:: 4.0 .. class:: Request diff --git a/tests/test_web_protocol.py b/tests/test_web_protocol.py index 7f99ca2a7e9..991788b12d5 100644 --- a/tests/test_web_protocol.py +++ b/tests/test_web_protocol.py @@ -839,17 +839,28 @@ async def test_two_data_received_without_waking_up_start_task(srv) -> None: async def test_client_disconnect(aiohttp_server) -> None: + loop = asyncio.get_event_loop() + loop.set_debug(True) + disconnected_notified = False async def handler(request): + + async def disconn(): + nonlocal disconnected_notified + await request.wait_for_disconnection() + disconnected_notified = True + + disconn_task = loop.create_task(disconn()) + buf = b"" with pytest.raises(ConnectionError): while len(buf) < 10: buf += await request.content.read(10) # return with closed transport means premature client disconnection + await asyncio.sleep(0) + disconn_task.cancel() return web.Response() - loop = asyncio.get_event_loop() - loop.set_debug(True) logger = mock.Mock() app = web.Application() app.router.add_route('POST', '/', handler) @@ -871,3 +882,32 @@ async def handler(request): writer.close() await asyncio.sleep(0.1) logger.debug.assert_called_with('Ignored premature client disconnection.') + assert disconnected_notified + + +async def test_wait_for_disconnection_cancel(srv, buf, monkeypatch) -> None: + # srv is aiohttp.web_protocol.RequestHandler + + waiter_tasks = [] + + async def request_waiter(request): + await request.wait_for_disconnection() + + orig_request_factory = srv._request_factory + + def request_factory(*args, **kwargs): + request = orig_request_factory(*args, **kwargs) + loop = asyncio.get_event_loop() + waiter_tasks.append(loop.create_task(request_waiter(request))) + return request + + monkeypatch.setattr(srv, "_request_factory", request_factory) + + srv.data_received( + b'GET / HTTP/1.1\r\n\r\n') + + await asyncio.sleep(0.05) + assert buf.startswith(b'HTTP/1.1 200 OK\r\n') + + assert len(waiter_tasks) == 1 + assert waiter_tasks[0].cancelled()