diff --git a/CHANGES/7815.bugfix b/CHANGES/7815.bugfix new file mode 100644 index 00000000000..269c2680d0b --- /dev/null +++ b/CHANGES/7815.bugfix @@ -0,0 +1 @@ +Fixed an issue where the client could go into an infinite loop. -- by :user:`Dreamsorcerer` diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index f509b55e5ad..8e9e080585a 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -56,7 +56,13 @@ reify, set_result, ) -from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter +from .http import ( + SERVER_SOFTWARE, + HttpVersion, + HttpVersion10, + HttpVersion11, + StreamWriter, +) from .log import client_logger from .streams import StreamReader from .typedefs import ( @@ -178,7 +184,7 @@ class ClientRequest: auth = None response = None - _writer = None # async task for streaming data + __writer = None # async task for streaming data _continue = None # waiter future for '100 Continue' response # N.B. @@ -265,6 +271,21 @@ def __init__( traces = [] self._traces = traces + def __reset_writer(self, _: object = None) -> None: + self.__writer = None + + @property + def _writer(self) -> Optional["asyncio.Task[None]"]: + return self.__writer + + @_writer.setter + def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None: + if self.__writer is not None: + self.__writer.remove_done_callback(self.__reset_writer) + self.__writer = writer + if writer is not None: + writer.add_done_callback(self.__reset_writer) + def is_ssl(self) -> bool: return self.url.scheme in ("https", "wss") @@ -563,8 +584,6 @@ async def write_bytes( else: await writer.write_eof() protocol.start_timeout() - finally: - self._writer = None async def send(self, conn: "Connection") -> "ClientResponse": # Specify request target: @@ -649,16 +668,14 @@ async def send(self, conn: "Connection") -> "ClientResponse": async def close(self) -> None: if self._writer is not None: - try: - with contextlib.suppress(asyncio.CancelledError): - await self._writer - finally: - self._writer = None + with contextlib.suppress(asyncio.CancelledError): + await self._writer def terminate(self) -> None: if self._writer is not None: if not self.loop.is_closed(): self._writer.cancel() + self._writer.remove_done_callback(self.__reset_writer) self._writer = None async def _on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None: @@ -677,9 +694,9 @@ class ClientResponse(HeadersMixin): # but will be set by the start() method. # As the end user will likely never see the None values, we cheat the types below. # from the Status-Line of the response - version = None # HTTP-Version + version: Optional[HttpVersion] = None # HTTP-Version status: int = None # type: ignore[assignment] # Status-Code - reason = None # Reason-Phrase + reason: Optional[str] = None # Reason-Phrase content: StreamReader = None # type: ignore[assignment] # Payload stream _headers: CIMultiDictProxy[str] = None # type: ignore[assignment] @@ -691,6 +708,7 @@ class ClientResponse(HeadersMixin): # post-init stage allows to not change ctor signature _closed = True # to allow __del__ for non-initialized properly response _released = False + __writer = None def __init__( self, @@ -737,6 +755,21 @@ def __init__( if loop.get_debug(): self._source_traceback = traceback.extract_stack(sys._getframe(1)) + def __reset_writer(self, _: object = None) -> None: + self.__writer = None + + @property + def _writer(self) -> Optional["asyncio.Task[None]"]: + return self.__writer + + @_writer.setter + def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None: + if self.__writer is not None: + self.__writer.remove_done_callback(self.__reset_writer) + self.__writer = writer + if writer is not None: + writer.add_done_callback(self.__reset_writer) + @reify def url(self) -> URL: return self._url @@ -797,7 +830,7 @@ def __repr__(self) -> str: "ascii", "backslashreplace" ).decode("ascii") else: - ascii_encodable_reason = self.reason + ascii_encodable_reason = "None" print( "".format( ascii_encodable_url, self.status, ascii_encodable_reason @@ -978,18 +1011,12 @@ def _release_connection(self) -> None: async def _wait_released(self) -> None: if self._writer is not None: - try: - await self._writer - finally: - self._writer = None + await self._writer self._release_connection() def _cleanup_writer(self) -> None: if self._writer is not None: - if self._writer.done(): - self._writer = None - else: - self._writer.cancel() + self._writer.cancel() self._session = None def _notify_content(self) -> None: @@ -1001,10 +1028,7 @@ def _notify_content(self) -> None: async def wait_for_close(self) -> None: if self._writer is not None: - try: - await self._writer - finally: - self._writer = None + await self._writer self.release() async def read(self) -> bytes: diff --git a/tests/test_client_response.py b/tests/test_client_response.py index 64161ac5941..1289f8caa6d 100644 --- a/tests/test_client_response.py +++ b/tests/test_client_response.py @@ -4,7 +4,7 @@ import gc import sys from json import JSONDecodeError -from typing import Any +from typing import Any, Callable from unittest import mock import pytest @@ -22,6 +22,9 @@ class WriterMock(mock.AsyncMock): def __await__(self) -> None: return self().__await__() + def add_done_callback(self, cb: Callable[[], None]) -> None: + cb() + def done(self) -> bool: return True diff --git a/tests/test_proxy.py b/tests/test_proxy.py index 58396eeb65f..eaa13fc3510 100644 --- a/tests/test_proxy.py +++ b/tests/test_proxy.py @@ -199,7 +199,7 @@ def test_proxy_server_hostname_default(self, ClientRequestMock) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -261,7 +261,7 @@ def test_proxy_server_hostname_override(self, ClientRequestMock) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -323,7 +323,7 @@ def test_https_connect(self, ClientRequestMock: Any) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -383,7 +383,7 @@ def test_https_connect_certificate_error(self, ClientRequestMock: Any) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -437,7 +437,7 @@ def test_https_connect_ssl_error(self, ClientRequestMock: Any) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -493,7 +493,7 @@ def test_https_connect_http_proxy_error(self, ClientRequestMock: Any) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -552,7 +552,7 @@ def test_https_connect_resp_start_error(self, ClientRequestMock: Any) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -663,7 +663,7 @@ def test_https_connect_pass_ssl_context(self, ClientRequestMock: Any) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -734,7 +734,7 @@ def test_https_auth(self, ClientRequestMock: Any) -> None: "get", URL("http://proxy.example.com"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[],