Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

💅 Propagate error causes via asyncio protocols #8089

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGES/8089.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
The asynchronous internals now set the underlying causes
when assigning exceptions to the future objects
-- by :user:`webknjaz`.
12 changes: 7 additions & 5 deletions aiohttp/_http_parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiD
from yarl import URL as _URL

from aiohttp import hdrs
from aiohttp.helpers import DEBUG
from aiohttp.helpers import DEBUG, set_exception

from .http_exceptions import (
BadHttpMessage,
Expand Down Expand Up @@ -763,11 +763,13 @@ cdef int cb_on_body(cparser.llhttp_t* parser,
cdef bytes body = at[:length]
try:
pyparser._payload.feed_data(body, length)
except BaseException as exc:
except BaseException as underlying_exc:
reraised_exc = underlying_exc
if pyparser._payload_exception is not None:
pyparser._payload.set_exception(pyparser._payload_exception(str(exc)))
else:
pyparser._payload.set_exception(exc)
reraised_exc = pyparser._payload_exception(str(underlying_exc))

set_exception(pyparser._payload, reraised_exc, underlying_exc)

pyparser._payload_error = 1
return -1
else:
Expand Down
7 changes: 6 additions & 1 deletion aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import Optional, cast

from .helpers import set_exception
from .tcp_helpers import tcp_nodelay


Expand Down Expand Up @@ -76,7 +77,11 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
set_exception(
waiter,
ConnectionError("Connection lost"),
exc,
)

async def _drain_helper(self) -> None:
if not self.connected:
Expand Down
74 changes: 55 additions & 19 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@

from .base_protocol import BaseProtocol
from .client_exceptions import (
ClientConnectionError,
ClientOSError,
ClientPayloadError,
ServerDisconnectedError,
SocketTimeoutError,
)
from .helpers import (
_EXC_SENTINEL,
BaseTimerContext,
set_exception,
set_result,
status_code_must_be_empty_body,
)
from .http import HttpResponseParser, RawResponseMessage, WebSocketReader
from .http_exceptions import HttpProcessingError
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader


Expand Down Expand Up @@ -80,41 +83,70 @@ def is_connected(self) -> bool:
def connection_lost(self, exc: Optional[BaseException]) -> None:
self._drop_timeout()

if exc is not None:
set_exception(self.closed, exc)
else:
original_connection_error = exc
reraised_exc = original_connection_error

connection_closed_cleanly = original_connection_error is None

if connection_closed_cleanly:
set_result(self.closed, None)
else:
assert original_connection_error is not None
bdraco marked this conversation as resolved.
Show resolved Hide resolved
set_exception(
webknjaz marked this conversation as resolved.
Show resolved Hide resolved
self.closed,
ClientConnectionError(
f"Connection lost: {original_connection_error !s}",
),
original_connection_error,
)

if self._payload_parser is not None:
with suppress(Exception):
with suppress(Exception): # FIXME: log this somehow?
self._payload_parser.feed_eof()

uncompleted = None
webknjaz marked this conversation as resolved.
Show resolved Hide resolved
if self._parser is not None:
try:
uncompleted = self._parser.feed_eof()
except Exception as e:
except Exception as underlying_exc:
if self._payload is not None:
exc = ClientPayloadError("Response payload is not completed")
exc.__cause__ = e
self._payload.set_exception(exc)
client_payload_exc_msg = (
f"Response payload is not completed: {underlying_exc !r}"
)
if not connection_closed_cleanly:
client_payload_exc_msg = (
f"{client_payload_exc_msg !s}. "
f"{original_connection_error !r}"
)
set_exception(
self._payload,
ClientPayloadError(client_payload_exc_msg),
underlying_exc,
)

if not self.is_eof():
if isinstance(exc, OSError):
exc = ClientOSError(*exc.args)
if exc is None:
exc = ServerDisconnectedError(uncompleted)
if isinstance(original_connection_error, OSError):
reraised_exc = ClientOSError(*original_connection_error.args)
if connection_closed_cleanly:
reraised_exc = ServerDisconnectedError(uncompleted)
# assigns self._should_close to True as side effect,
# we do it anyway below
self.set_exception(exc)
underlying_non_eof_exc = (
_EXC_SENTINEL
if connection_closed_cleanly
else original_connection_error
)
assert underlying_non_eof_exc is not None
assert reraised_exc is not None
webknjaz marked this conversation as resolved.
Show resolved Hide resolved
self.set_exception(reraised_exc, underlying_non_eof_exc)

self._should_close = True
self._parser = None
self._payload = None
self._payload_parser = None
self._reading_paused = False

super().connection_lost(exc)
super().connection_lost(reraised_exc)

def eof_received(self) -> None:
# should call parser.feed_eof() most likely
Expand All @@ -128,10 +160,14 @@ def resume_reading(self) -> None:
super().resume_reading()
self._reschedule_timeout()

def set_exception(self, exc: BaseException) -> None:
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
self._should_close = True
self._drop_timeout()
super().set_exception(exc)
super().set_exception(exc, exc_cause)

def set_parser(self, parser: Any, payload: Any) -> None:
# TODO: actual types are:
Expand Down Expand Up @@ -208,7 +244,7 @@ def _on_read_timeout(self) -> None:
exc = SocketTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
if self._payload is not None:
self._payload.set_exception(exc)
set_exception(self._payload, exc)

def data_received(self, data: bytes) -> None:
self._reschedule_timeout()
Expand All @@ -234,14 +270,14 @@ def data_received(self, data: bytes) -> None:
# parse http messages
try:
messages, upgraded, tail = self._parser.feed_data(data)
except BaseException as exc:
except BaseException as underlying_exc:
Dismissed Show dismissed Hide dismissed
if self.transport is not None:
# connection.release() could be called BEFORE
# data_received(), the transport is already
# closed in this case
self.transport.close()
# should_close is True after the call
self.set_exception(exc)
self.set_exception(HttpProcessingError(), underlying_exc)
Copy link
Member Author

@webknjaz webknjaz Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Dreamsorcerer @bdraco I made this unconditional but turned it into a different error, because this turned out to be the wrong level of abstraction. It seems to be turned into what makes the tests pass in the end.

return

self._upgraded = upgraded
Expand Down
34 changes: 22 additions & 12 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
noop,
parse_mimetype,
reify,
set_exception,
set_result,
)
from .http import (
Expand Down Expand Up @@ -566,20 +567,29 @@ async def write_bytes(

for chunk in self.body:
await writer.write(chunk) # type: ignore[arg-type]
except OSError as exc:
if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
protocol.set_exception(exc)
else:
new_exc = ClientOSError(
exc.errno, "Can not write request body for %s" % self.url
except OSError as underlying_exc:
reraised_exc = underlying_exc

exc_is_not_timeout = underlying_exc.errno is not None or not isinstance(
underlying_exc, asyncio.TimeoutError
)
if exc_is_not_timeout:
reraised_exc = ClientOSError(
underlying_exc.errno,
f"Can not write request body for {self.url !s}",
)
new_exc.__context__ = exc
webknjaz marked this conversation as resolved.
Show resolved Hide resolved
new_exc.__cause__ = exc
protocol.set_exception(new_exc)

set_exception(protocol, reraised_exc, underlying_exc)
except asyncio.CancelledError:
await writer.write_eof()
except Exception as exc:
protocol.set_exception(exc)
except Exception as underlying_exc:
set_exception(
protocol,
ClientConnectionError(
f"Failed to send bytes into the underlying connection {conn !s}",
),
underlying_exc,
)
else:
await writer.write_eof()
protocol.start_timeout()
Expand Down Expand Up @@ -1019,7 +1029,7 @@ def _notify_content(self) -> None:
content = self.content
# content can be None here, but the types are cheated elsewhere.
if content and content.exception() is None: # type: ignore[truthy-bool]
content.set_exception(ClientConnectionError("Connection closed"))
set_exception(content, ClientConnectionError("Connection closed"))
self._released = True

async def wait_for_close(self) -> None:
Expand Down
36 changes: 33 additions & 3 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,9 +797,39 @@ def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
fut.set_result(result)


def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None:
if not fut.done():
fut.set_exception(exc)
_EXC_SENTINEL = BaseException()


class ErrorableProtocol(Protocol):
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = ...,
) -> None:
... # pragma: no cover
Dismissed Show dismissed Hide dismissed


def set_exception(
fut: "asyncio.Future[_T] | ErrorableProtocol",
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
"""Set future exception.

If the future is marked as complete, this function is a no-op.

:param exc_cause: An exception that is a direct cause of ``exc``.
Only set if provided.
"""
if asyncio.isfuture(fut) and fut.done():
return

exc_is_sentinel = exc_cause is _EXC_SENTINEL
exc_causes_itself = exc is exc_cause
if not exc_is_sentinel and not exc_causes_itself:
exc.__cause__ = exc_cause
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved

fut.set_exception(exc)


@functools.total_ordering
Expand Down
27 changes: 18 additions & 9 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
from .base_protocol import BaseProtocol
from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor
from .helpers import (
_EXC_SENTINEL,
DEBUG,
NO_EXTENSIONS,
BaseTimerContext,
method_must_be_empty_body,
set_exception,
status_code_must_be_empty_body,
)
from .http_exceptions import (
Expand Down Expand Up @@ -439,13 +441,16 @@ def get_content_length() -> Optional[int]:
assert self._payload_parser is not None
try:
eof, data = self._payload_parser.feed_data(data[start_pos:], SEP)
except BaseException as exc:
except BaseException as underlying_exc:
Dismissed Show dismissed Hide dismissed
reraised_exc = underlying_exc
if self.payload_exception is not None:
self._payload_parser.payload.set_exception(
self.payload_exception(str(exc))
)
else:
self._payload_parser.payload.set_exception(exc)
reraised_exc = self.payload_exception(str(underlying_exc))

set_exception(
self._payload_parser.payload,
reraised_exc,
underlying_exc,
)

eof = True
data = b""
Expand Down Expand Up @@ -826,7 +831,7 @@ def feed_data(
exc = TransferEncodingError(
chunk[:pos].decode("ascii", "surrogateescape")
)
self.payload.set_exception(exc)
set_exception(self.payload, exc)
raise exc
size = int(bytes(size_b), 16)

Expand Down Expand Up @@ -929,8 +934,12 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
else:
self.decompressor = ZLibDecompressor(encoding=encoding)

def set_exception(self, exc: BaseException) -> None:
self.out.set_exception(exc)
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
set_exception(self.out, exc, exc_cause)

def feed_data(self, chunk: bytes, size: int) -> None:
if not size:
Expand Down
4 changes: 2 additions & 2 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from .base_protocol import BaseProtocol
from .compression_utils import ZLibCompressor, ZLibDecompressor
from .helpers import NO_EXTENSIONS
from .helpers import NO_EXTENSIONS, set_exception
from .streams import DataQueue

__all__ = (
Expand Down Expand Up @@ -305,7 +305,7 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
return self._feed_data(data)
except Exception as exc:
self._exc = exc
self.queue.set_exception(exc)
set_exception(self.queue, exc)
return True, b""

def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
Expand Down
Loading
Loading