diff --git a/CHANGES/8089.bugfix.rst b/CHANGES/8089.bugfix.rst new file mode 100644 index 00000000000..7f47448478d --- /dev/null +++ b/CHANGES/8089.bugfix.rst @@ -0,0 +1,3 @@ +The asynchronous internals now set the underlying causes +when assigning exceptions to the future objects +-- by :user:`webknjaz`. diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index 3f28fbdab43..7ea9b32ca55 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -19,7 +19,7 @@ from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiD from yarl import URL as _URL from aiohttp import hdrs -from aiohttp.helpers import DEBUG +from aiohttp.helpers import DEBUG, set_exception from .http_exceptions import ( BadHttpMessage, @@ -763,11 +763,13 @@ cdef int cb_on_body(cparser.llhttp_t* parser, cdef bytes body = at[:length] try: pyparser._payload.feed_data(body, length) - except BaseException as exc: + except BaseException as underlying_exc: + reraised_exc = underlying_exc if pyparser._payload_exception is not None: - pyparser._payload.set_exception(pyparser._payload_exception(str(exc))) - else: - pyparser._payload.set_exception(exc) + reraised_exc = pyparser._payload_exception(str(underlying_exc)) + + set_exception(pyparser._payload, reraised_exc, underlying_exc) + pyparser._payload_error = 1 return -1 else: diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py index 4c9f0a752e3..dc1f24f99cd 100644 --- a/aiohttp/base_protocol.py +++ b/aiohttp/base_protocol.py @@ -1,6 +1,7 @@ import asyncio from typing import Optional, cast +from .helpers import set_exception from .tcp_helpers import tcp_nodelay @@ -76,7 +77,11 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: if exc is None: waiter.set_result(None) else: - waiter.set_exception(exc) + set_exception( + waiter, + ConnectionError("Connection lost"), + exc, + ) async def _drain_helper(self) -> None: if not self.connected: diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index ca99808080d..723f5aae5f4 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -9,8 +9,14 @@ ServerDisconnectedError, ServerTimeoutError, ) -from .helpers import BaseTimerContext, status_code_must_be_empty_body +from .helpers import ( + _EXC_SENTINEL, + BaseTimerContext, + set_exception, + status_code_must_be_empty_body, +) from .http import HttpResponseParser, RawResponseMessage +from .http_exceptions import HttpProcessingError from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader @@ -73,28 +79,50 @@ def is_connected(self) -> bool: def connection_lost(self, exc: Optional[BaseException]) -> None: self._drop_timeout() + original_connection_error = exc + reraised_exc = original_connection_error + + connection_closed_cleanly = original_connection_error is None + if self._payload_parser is not None: - with suppress(Exception): + with suppress(Exception): # FIXME: log this somehow? self._payload_parser.feed_eof() uncompleted = None if self._parser is not None: try: uncompleted = self._parser.feed_eof() - except Exception as e: + except Exception as underlying_exc: if self._payload is not None: - exc = ClientPayloadError("Response payload is not completed") - exc.__cause__ = e - self._payload.set_exception(exc) + client_payload_exc_msg = ( + f"Response payload is not completed: {underlying_exc !r}" + ) + if not connection_closed_cleanly: + client_payload_exc_msg = ( + f"{client_payload_exc_msg !s}. " + f"{original_connection_error !r}" + ) + set_exception( + self._payload, + ClientPayloadError(client_payload_exc_msg), + underlying_exc, + ) if not self.is_eof(): - if isinstance(exc, OSError): - exc = ClientOSError(*exc.args) - if exc is None: - exc = ServerDisconnectedError(uncompleted) + if isinstance(original_connection_error, OSError): + reraised_exc = ClientOSError(*original_connection_error.args) + if connection_closed_cleanly: + reraised_exc = ServerDisconnectedError(uncompleted) # assigns self._should_close to True as side effect, # we do it anyway below - self.set_exception(exc) + underlying_non_eof_exc = ( + _EXC_SENTINEL + if connection_closed_cleanly + else original_connection_error + ) + assert underlying_non_eof_exc is not None + assert reraised_exc is not None + self.set_exception(reraised_exc, underlying_non_eof_exc) self._should_close = True self._parser = None @@ -102,7 +130,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: self._payload_parser = None self._reading_paused = False - super().connection_lost(exc) + super().connection_lost(reraised_exc) def eof_received(self) -> None: # should call parser.feed_eof() most likely @@ -116,10 +144,14 @@ def resume_reading(self) -> None: super().resume_reading() self._reschedule_timeout() - def set_exception(self, exc: BaseException) -> None: + def set_exception( + self, + exc: BaseException, + exc_cause: BaseException = _EXC_SENTINEL, + ) -> None: self._should_close = True self._drop_timeout() - super().set_exception(exc) + super().set_exception(exc, exc_cause) def set_parser(self, parser: Any, payload: Any) -> None: # TODO: actual types are: @@ -196,7 +228,7 @@ def _on_read_timeout(self) -> None: exc = ServerTimeoutError("Timeout on reading data from socket") self.set_exception(exc) if self._payload is not None: - self._payload.set_exception(exc) + set_exception(self._payload, exc) def data_received(self, data: bytes) -> None: self._reschedule_timeout() @@ -222,14 +254,14 @@ def data_received(self, data: bytes) -> None: # parse http messages try: messages, upgraded, tail = self._parser.feed_data(data) - except BaseException as exc: + except BaseException as underlying_exc: if self.transport is not None: # connection.release() could be called BEFORE # data_received(), the transport is already # closed in this case self.transport.close() # should_close is True after the call - self.set_exception(exc) + self.set_exception(HttpProcessingError(), underlying_exc) return self._upgraded = upgraded diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index e0de951a33a..afe719da16e 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -50,6 +50,7 @@ netrc_from_env, noop, reify, + set_exception, set_result, ) from .http import ( @@ -630,20 +631,29 @@ async def write_bytes( for chunk in self.body: await writer.write(chunk) # type: ignore[arg-type] - except OSError as exc: - if exc.errno is None and isinstance(exc, asyncio.TimeoutError): - protocol.set_exception(exc) - else: - new_exc = ClientOSError( - exc.errno, "Can not write request body for %s" % self.url + except OSError as underlying_exc: + reraised_exc = underlying_exc + + exc_is_not_timeout = underlying_exc.errno is not None or not isinstance( + underlying_exc, asyncio.TimeoutError + ) + if exc_is_not_timeout: + reraised_exc = ClientOSError( + underlying_exc.errno, + f"Can not write request body for {self.url !s}", ) - new_exc.__context__ = exc - new_exc.__cause__ = exc - protocol.set_exception(new_exc) + + set_exception(protocol, reraised_exc, underlying_exc) except asyncio.CancelledError: await writer.write_eof() - except Exception as exc: - protocol.set_exception(exc) + except Exception as underlying_exc: + set_exception( + protocol, + ClientConnectionError( + f"Failed to send bytes into the underlying connection {conn !s}", + ), + underlying_exc, + ) else: await writer.write_eof() protocol.start_timeout() @@ -1086,7 +1096,7 @@ def _cleanup_writer(self) -> None: def _notify_content(self) -> None: content = self.content if content and content.exception() is None: - content.set_exception(ClientConnectionError("Connection closed")) + set_exception(content, ClientConnectionError("Connection closed")) self._released = True async def wait_for_close(self) -> None: diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index a5c762ed795..284033b7a04 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -810,9 +810,39 @@ def set_result(fut: "asyncio.Future[_T]", result: _T) -> None: fut.set_result(result) -def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None: - if not fut.done(): - fut.set_exception(exc) +_EXC_SENTINEL = BaseException() + + +class ErrorableProtocol(Protocol): + def set_exception( + self, + exc: BaseException, + exc_cause: BaseException = ..., + ) -> None: + ... # pragma: no cover + + +def set_exception( + fut: "asyncio.Future[_T] | ErrorableProtocol", + exc: BaseException, + exc_cause: BaseException = _EXC_SENTINEL, +) -> None: + """Set future exception. + + If the future is marked as complete, this function is a no-op. + + :param exc_cause: An exception that is a direct cause of ``exc``. + Only set if provided. + """ + if asyncio.isfuture(fut) and fut.done(): + return + + exc_is_sentinel = exc_cause is _EXC_SENTINEL + exc_causes_itself = exc is exc_cause + if not exc_is_sentinel and not exc_causes_itself: + exc.__cause__ = exc_cause + + fut.set_exception(exc) @functools.total_ordering diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index 1877f558308..1301f025810 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -28,10 +28,12 @@ from .base_protocol import BaseProtocol from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor from .helpers import ( + _EXC_SENTINEL, DEBUG, NO_EXTENSIONS, BaseTimerContext, method_must_be_empty_body, + set_exception, status_code_must_be_empty_body, ) from .http_exceptions import ( @@ -446,13 +448,16 @@ def get_content_length() -> Optional[int]: assert self._payload_parser is not None try: eof, data = self._payload_parser.feed_data(data[start_pos:], SEP) - except BaseException as exc: + except BaseException as underlying_exc: + reraised_exc = underlying_exc if self.payload_exception is not None: - self._payload_parser.payload.set_exception( - self.payload_exception(str(exc)) - ) - else: - self._payload_parser.payload.set_exception(exc) + reraised_exc = self.payload_exception(str(underlying_exc)) + + set_exception( + self._payload_parser.payload, + reraised_exc, + underlying_exc, + ) eof = True data = b"" @@ -834,7 +839,7 @@ def feed_data( exc = TransferEncodingError( chunk[:pos].decode("ascii", "surrogateescape") ) - self.payload.set_exception(exc) + set_exception(self.payload, exc) raise exc size = int(bytes(size_b), 16) @@ -939,8 +944,12 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None: else: self.decompressor = ZLibDecompressor(encoding=encoding) - def set_exception(self, exc: BaseException) -> None: - self.out.set_exception(exc) + def set_exception( + self, + exc: BaseException, + exc_cause: BaseException = _EXC_SENTINEL, + ) -> None: + set_exception(self.out, exc, exc_cause) def feed_data(self, chunk: bytes, size: int) -> None: if not size: diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index b63453f99e5..39f2e4a5c15 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -25,7 +25,7 @@ from .base_protocol import BaseProtocol from .compression_utils import ZLibCompressor, ZLibDecompressor -from .helpers import NO_EXTENSIONS +from .helpers import NO_EXTENSIONS, set_exception from .streams import DataQueue __all__ = ( @@ -314,7 +314,7 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]: return self._feed_data(data) except Exception as exc: self._exc = exc - self.queue.set_exception(exc) + set_exception(self.queue, exc) return True, b"" def _feed_data(self, data: bytes) -> Tuple[bool, bytes]: diff --git a/aiohttp/streams.py b/aiohttp/streams.py index 3e4c355b5cb..b9b9c3fd96f 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -14,7 +14,13 @@ ) from .base_protocol import BaseProtocol -from .helpers import BaseTimerContext, TimerNoop, set_exception, set_result +from .helpers import ( + _EXC_SENTINEL, + BaseTimerContext, + TimerNoop, + set_exception, + set_result, +) from .log import internal_logger __all__ = ( @@ -146,19 +152,23 @@ def get_read_buffer_limits(self) -> Tuple[int, int]: def exception(self) -> Optional[BaseException]: return self._exception - def set_exception(self, exc: BaseException) -> None: + def set_exception( + self, + exc: BaseException, + exc_cause: BaseException = _EXC_SENTINEL, + ) -> None: self._exception = exc self._eof_callbacks.clear() waiter = self._waiter if waiter is not None: self._waiter = None - set_exception(waiter, exc) + set_exception(waiter, exc, exc_cause) waiter = self._eof_waiter if waiter is not None: self._eof_waiter = None - set_exception(waiter, exc) + set_exception(waiter, exc, exc_cause) def on_eof(self, callback: Callable[[], None]) -> None: if self._eof: @@ -513,7 +523,11 @@ def __repr__(self) -> str: def exception(self) -> Optional[BaseException]: return None - def set_exception(self, exc: BaseException) -> None: + def set_exception( + self, + exc: BaseException, + exc_cause: BaseException = _EXC_SENTINEL, + ) -> None: pass def on_eof(self, callback: Callable[[], None]) -> None: @@ -588,14 +602,18 @@ def at_eof(self) -> bool: def exception(self) -> Optional[BaseException]: return self._exception - def set_exception(self, exc: BaseException) -> None: + def set_exception( + self, + exc: BaseException, + exc_cause: BaseException = _EXC_SENTINEL, + ) -> None: self._eof = True self._exception = exc waiter = self._waiter if waiter is not None: self._waiter = None - set_exception(waiter, exc) + set_exception(waiter, exc, exc_cause) def feed_data(self, data: _T, size: int = 0) -> None: self._size += size diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index ec5856a0a22..f083b13eb0f 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -26,7 +26,7 @@ from .abc import AbstractAccessLogger, AbstractStreamWriter from .base_protocol import BaseProtocol -from .helpers import ceil_timeout +from .helpers import ceil_timeout, set_exception from .http import ( HttpProcessingError, HttpRequestParser, @@ -565,7 +565,7 @@ async def start(self) -> None: self.log_debug("Uncompleted request.") self.close() - payload.set_exception(PayloadAccessError()) + set_exception(payload, PayloadAccessError()) except asyncio.CancelledError: self.log_debug("Ignored premature client disconnection ") diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 61fc831b032..781713e5985 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -48,6 +48,7 @@ parse_http_date, reify, sentinel, + set_exception, ) from .http_parser import RawRequestMessage from .http_writer import HttpVersion @@ -814,7 +815,7 @@ async def _prepare_hook(self, response: StreamResponse) -> None: return def _cancel(self, exc: BaseException) -> None: - self._payload.set_exception(exc) + set_exception(self._payload, exc) class Request(BaseRequest): diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 783377716f5..d20a26ca470 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -11,7 +11,7 @@ from . import hdrs from .abc import AbstractStreamWriter -from .helpers import call_later, set_result +from .helpers import call_later, set_exception, set_result from .http import ( WS_CLOSED_MESSAGE, WS_CLOSING_MESSAGE, @@ -526,4 +526,4 @@ async def __anext__(self) -> WSMessage: def _cancel(self, exc: BaseException) -> None: if self._reader is not None: - self._reader.set_exception(exc) + set_exception(self._reader, exc) diff --git a/tests/test_base_protocol.py b/tests/test_base_protocol.py index b26011095e9..72c8c7c6b63 100644 --- a/tests/test_base_protocol.py +++ b/tests/test_base_protocol.py @@ -186,9 +186,9 @@ async def test_lost_drain_waited_exception() -> None: assert pr._drain_waiter is not None exc = RuntimeError() pr.connection_lost(exc) - with pytest.raises(RuntimeError) as cm: + with pytest.raises(ConnectionError, match=r"^Connection lost$") as cm: await t - assert cm.value is exc + assert cm.value.__cause__ is exc assert pr._drain_waiter is None diff --git a/tests/test_client_request.py b/tests/test_client_request.py index c54e1828e34..6084f685405 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -14,6 +14,7 @@ import aiohttp from aiohttp import BaseConnector, hdrs, helpers, payload +from aiohttp.client_exceptions import ClientConnectionError from aiohttp.client_reqrep import ( ClientRequest, ClientResponse, @@ -1096,9 +1097,8 @@ async def throw_exc(): # assert connection.close.called assert conn.protocol.set_exception.called outer_exc = conn.protocol.set_exception.call_args[0][0] - assert isinstance(outer_exc, ValueError) - assert inner_exc is outer_exc - assert inner_exc is outer_exc + assert isinstance(outer_exc, ClientConnectionError) + assert outer_exc.__cause__ is inner_exc await req.close() diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index 3fb0ab77d98..a37a08632d7 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -280,6 +280,7 @@ def test_parse_headers_longline(parser: Any) -> None: header_name = b"Test" + invalid_unicode_byte + b"Header" + b"A" * 8192 text = b"GET /test HTTP/1.1\r\n" + header_name + b": test\r\n" + b"\r\n" + b"\r\n" with pytest.raises((http_exceptions.LineTooLong, http_exceptions.BadHttpMessage)): + # FIXME: `LineTooLong` doesn't seem to actually be happening parser.feed_data(text)