Skip to content

Commit

Permalink
Add a Request.wait_for_disconnection() method #2492
Browse files Browse the repository at this point in the history
as means of allowing request handlers to be notified of premature client
disconnections.
  • Loading branch information
Gustavo Carneiro authored and gjcarneiro committed Oct 19, 2019
1 parent 6236536 commit 233a7f9
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGES/2492.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a Request.wait_for_disconnection() method, as means of allowing request handlers to be notified of premature client disconnections.
24 changes: 23 additions & 1 deletion aiohttp/web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Mapping,
MutableMapping,
Optional,
Set,
Tuple,
Union,
cast,
Expand All @@ -35,6 +36,7 @@
is_expected_content_type,
reify,
sentinel,
set_result,
)
from .http_parser import RawRequestMessage
from .multipart import BodyPartReader, MultipartReader
Expand Down Expand Up @@ -110,7 +112,9 @@ class BaseRequest(MutableMapping[str, Any], HeadersMixin):
'_message', '_protocol', '_payload_writer', '_payload', '_headers',
'_method', '_version', '_rel_url', '_post', '_read_bytes',
'_state', '_cache', '_task', '_client_max_size', '_loop',
'_transport_sslcontext', '_transport_peername')
'_transport_sslcontext', '_transport_peername',
'_disconnection_waiters',
)

def __init__(self, message: RawRequestMessage,
payload: StreamReader, protocol: 'RequestHandler',
Expand Down Expand Up @@ -142,6 +146,7 @@ def __init__(self, message: RawRequestMessage,
self._task = task
self._client_max_size = client_max_size
self._loop = loop
self._disconnection_waiters = set() # type: Set[asyncio.Future[None]]

transport = self._protocol.transport
assert transport is not None
Expand Down Expand Up @@ -690,6 +695,23 @@ async def _prepare_hook(self, response: StreamResponse) -> None:

def _cancel(self, exc: BaseException) -> None:
self._payload.set_exception(exc)
for fut in self._disconnection_waiters:
set_result(fut, None)

async def wait_for_disconnection(self) -> None:
"""Returns when the connection that sent this request closes
This can be used in handlers as a means to receive a notification of
premature client disconnection.
.. versionadded:: 4.0
"""
fut = asyncio.Future() # type: asyncio.Future[None]
self._disconnection_waiters.add(fut)
try:
await fut
finally:
self._disconnection_waiters.remove(fut)


class Request(BaseRequest):
Expand Down
16 changes: 14 additions & 2 deletions tests/test_web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,17 +839,28 @@ async def test_two_data_received_without_waking_up_start_task(srv) -> None:


async def test_client_disconnect(aiohttp_server) -> None:
loop = asyncio.get_event_loop()
loop.set_debug(True)
disconnected_notified = False

async def handler(request):

async def disconn():
nonlocal disconnected_notified
await request.wait_for_disconnection()
disconnected_notified = True

disconn_task = loop.create_task(disconn())

buf = b""
with pytest.raises(ConnectionError):
while len(buf) < 10:
buf += await request.content.read(10)
# return with closed transport means premature client disconnection
await asyncio.sleep(0)
disconn_task.cancel()
return web.Response()

loop = asyncio.get_event_loop()
loop.set_debug(True)
logger = mock.Mock()
app = web.Application()
app.router.add_route('POST', '/', handler)
Expand All @@ -871,3 +882,4 @@ async def handler(request):
writer.close()
await asyncio.sleep(0.1)
logger.debug.assert_called_with('Ignored premature client disconnection.')
assert disconnected_notified

0 comments on commit 233a7f9

Please sign in to comment.