Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3.7] Don't cancel web handler on disconnection (#4080) #4771

Merged
merged 4 commits into from
Oct 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/4080.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Don't cancel web handler on peer disconnection, raise `OSError` on reading/writing instead.
18 changes: 11 additions & 7 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from .base_protocol import BaseProtocol
from .helpers import NO_EXTENSIONS
from .log import ws_logger
from .streams import DataQueue

__all__ = ('WS_CLOSED_MESSAGE', 'WS_CLOSING_MESSAGE', 'WS_KEY',
Expand Down Expand Up @@ -568,8 +567,8 @@ def __init__(self, protocol: BaseProtocol, transport: asyncio.Transport, *,
async def _send_frame(self, message: bytes, opcode: int,
compress: Optional[int]=None) -> None:
"""Send a frame over the websocket with message as its payload."""
if self._closing:
ws_logger.warning('websocket connection is closing.')
if self._closing and not (opcode & WSMsgType.CLOSE):
raise ConnectionResetError('Cannot write to closing transport')

rsv = 0

Expand Down Expand Up @@ -617,21 +616,26 @@ async def _send_frame(self, message: bytes, opcode: int,
mask = mask.to_bytes(4, 'big')
message = bytearray(message)
_websocket_mask(mask, message)
self.transport.write(header + mask + message)
self._write(header + mask + message)
self._output_size += len(header) + len(mask) + len(message)
else:
if len(message) > MSG_SIZE:
self.transport.write(header)
self.transport.write(message)
self._write(header)
self._write(message)
else:
self.transport.write(header + message)
self._write(header + message)

self._output_size += len(header) + len(message)

if self._output_size > self._limit:
self._output_size = 0
await self.protocol._drain_helper()

def _write(self, data: bytes) -> None:
if self.transport is None or self.transport.is_closing():
raise ConnectionResetError('Cannot write to closing transport')
self.transport.write(data)

async def pong(self, message: bytes=b'') -> None:
"""Send pong message."""
if isinstance(message, str):
Expand Down
19 changes: 15 additions & 4 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ class RequestHandler(BaseProtocol):
'_waiter', '_error_handler', '_task_handler',
'_upgrade', '_payload_parser', '_request_parser',
'_reading_paused', 'logger', 'debug', 'access_log',
'access_logger', '_close', '_force_close')
'access_logger', '_close', '_force_close',
'_current_request')

def __init__(self, manager: 'Server', *,
loop: asyncio.AbstractEventLoop,
Expand All @@ -135,6 +136,7 @@ def __init__(self, manager: 'Server', *,

self._request_count = 0
self._keepalive = False
self._current_request = None # type: Optional[BaseRequest]
self._manager = manager # type: Optional[Server]
self._request_handler = manager.request_handler # type: Optional[_RequestHandler] # noqa
self._request_factory = manager.request_factory # type: Optional[_RequestFactory] # noqa
Expand Down Expand Up @@ -202,6 +204,9 @@ async def shutdown(self, timeout: Optional[float]=15.0) -> None:
not self._error_handler.done()):
await self._error_handler

if self._current_request is not None:
self._current_request._cancel(asyncio.CancelledError())

if (self._task_handler is not None and
not self._task_handler.done()):
await self._task_handler
Expand Down Expand Up @@ -241,8 +246,10 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
if self._keepalive_handle is not None:
self._keepalive_handle.cancel()

if self._task_handler is not None:
self._task_handler.cancel()
if self._current_request is not None:
if exc is None:
exc = ConnectionResetError("Connection lost")
self._current_request._cancel(exc)

if self._error_handler is not None:
self._error_handler.cancel()
Expand Down Expand Up @@ -378,7 +385,11 @@ async def _handle_request(self,
) -> Tuple[StreamResponse, bool]:
assert self._request_handler is not None
try:
resp = await self._request_handler(request)
try:
self._current_request = request
resp = await self._request_handler(request)
finally:
self._current_request = None
except HTTPException as exc:
resp = Response(status=exc.status,
reason=exc.reason,
Expand Down
3 changes: 3 additions & 0 deletions aiohttp/web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,9 @@ def __bool__(self) -> bool:
async def _prepare_hook(self, response: StreamResponse) -> None:
return

def _cancel(self, exc: BaseException) -> None:
self._payload.set_exception(exc)


class Request(BaseRequest):

Expand Down
4 changes: 4 additions & 0 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,7 @@ async def __anext__(self) -> WSMessage:
WSMsgType.CLOSED):
raise StopAsyncIteration # NOQA
return msg

def _cancel(self, exc: BaseException) -> None:
if self._reader is not None:
self._reader.set_exception(exc)
103 changes: 11 additions & 92 deletions docs/web_advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,103 +20,22 @@ But in case of custom regular expressions for
*percent encoded*: if you pass Unicode patterns they don't match to
*requoted* path.

Peer disconnection
------------------

Web Handler Cancellation
------------------------

.. warning::

:term:`web-handler` execution could be canceled on every ``await``
if client drops connection without reading entire response's BODY.

The behavior is very different from classic WSGI frameworks like
Flask and Django.

Sometimes it is a desirable behavior: on processing ``GET`` request the
code might fetch data from database or other web resource, the
fetching is potentially slow.

Canceling this fetch is very good: the peer dropped connection
already, there is no reason to waste time and resources (memory etc) by
getting data from DB without any chance to send it back to peer.

But sometimes the cancellation is bad: on a ``POST`` request it is often
necessary to save data to the DB regardless of the peer closing.

Cancellation prevention could be implemented in several ways:

* Applying :func:`asyncio.shield` to coroutine that saves data into DB.
* Spawning a new task for DB saving
* Using aiojobs_ or other third party library.

:func:`asyncio.shield` works pretty well. The only disadvantage is that you
need to split the web handler into exactly two async functions: one
for the handler itself and another for the protected code.

For example the following snippet is not safe::

async def handler(request):
await asyncio.shield(write_to_redis(request))
await asyncio.shield(write_to_postgres(request))
return web.Response(text='OK')

Cancellation might occur just after saving data in Redis;
``write_to_postgres`` will not be called.

Spawning a new task is much worse: there is no place to ``await``
spawned tasks::

async def handler(request):
request.loop.create_task(write_to_redis(request))
return web.Response(text='OK')

In this case errors from ``write_to_redis`` are not awaited, it leads
to many asyncio log messages *Future exception was never retrieved*
and *Task was destroyed but it is pending!*.

Moreover, during the :ref:`aiohttp-web-graceful-shutdown` phase *aiohttp* doesn't
wait for these tasks, so you have a great chance to lose very important
data.

On the other hand, aiojobs_ provides an API for spawning new jobs,
awaiting their results, etc. It stores all scheduled activity in
internal data structures and can terminate them gracefully::

from aiojobs.aiohttp import setup, spawn

async def coro(timeout):
await asyncio.sleep(timeout) # do something in background

async def handler(request):
await spawn(request, coro())
return web.Response()

app = web.Application()
setup(app)
app.router.add_get('/', handler)

All not finished jobs will be terminated on
:attr:`Application.on_cleanup` signal.
When a client peer is gone a subsequent reading or writing raises :exc:`OSError`
or more specific exception like :exc:`ConnectionResetError`.

To prevent cancellation of the whole :term:`web-handler` use
``@atomic`` decorator::
The reasons for disconnection vary; it can be a network issue or an explicit
socket closing on the peer side without reading the whole server response.

from aiojobs.aiohttp import atomic
*aiohttp* handles disconnection properly but you can handle it explicitly, e.g.::

@atomic
async def handler(request):
await write_to_db()
return web.Response()

app = web.Application()
setup(app)
app.router.add_post('/', handler)

It protects the whole ``handler`` async function from cancellation;
``write_to_db`` will never be interrupted.

.. _aiojobs: http://aiojobs.readthedocs.io/en/latest/

try:
text = await request.text()
except OSError:
# disconnected

Passing a coroutine into run_app and Gunicorn
---------------------------------------------
Expand Down
10 changes: 3 additions & 7 deletions tests/test_client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import aiohttp
from aiohttp import client, hdrs
from aiohttp.http import WS_KEY
from aiohttp.log import ws_logger
from aiohttp.streams import EofStream
from aiohttp.test_utils import make_mocked_coro

Expand Down Expand Up @@ -363,7 +362,7 @@ async def test_close_exc2(loop, ws_key, key_data) -> None:
await resp.close()


async def test_send_data_after_close(ws_key, key_data, loop, mocker) -> None:
async def test_send_data_after_close(ws_key, key_data, loop) -> None:
resp = mock.Mock()
resp.status = 101
resp.headers = {
Expand All @@ -381,16 +380,13 @@ async def test_send_data_after_close(ws_key, key_data, loop, mocker) -> None:
'http://test.org')
resp._writer._closing = True

mocker.spy(ws_logger, 'warning')

for meth, args in ((resp.ping, ()),
(resp.pong, ()),
(resp.send_str, ('s',)),
(resp.send_bytes, (b'b',)),
(resp.send_json, ({},))):
await meth(*args)
assert ws_logger.warning.called
ws_logger.warning.reset_mock()
with pytest.raises(ConnectionResetError):
await meth(*args)


async def test_send_data_type_errors(ws_key, key_data, loop) -> None:
Expand Down
6 changes: 5 additions & 1 deletion tests/test_web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,11 @@ async def test_two_data_received_without_waking_up_start_task(srv) -> None:
async def test_client_disconnect(aiohttp_server) -> None:

async def handler(request):
await request.content.read(10)
buf = b""
with pytest.raises(ConnectionError):
while len(buf) < 10:
buf += await request.content.read(10)
# return with closed transport means premature client disconnection
return web.Response()

logger = mock.Mock()
Expand Down
Loading