From b069515d7804d977e16e38f55addcf4411f2de06 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 20 Nov 2024 17:17:26 -0600 Subject: [PATCH] [PR #10003/78d1be5 backport][3.10] Fix client connection header not reflecting connector `force_close` value (#10008) --- CHANGES/10003.bugfix.rst | 1 + aiohttp/client_reqrep.py | 32 +++----------- tests/test_benchmarks_client_request.py | 6 +++ tests/test_client_request.py | 58 +++++++++---------------- tests/test_web_functional.py | 5 +-- 5 files changed, 37 insertions(+), 65 deletions(-) create mode 100644 CHANGES/10003.bugfix.rst diff --git a/CHANGES/10003.bugfix.rst b/CHANGES/10003.bugfix.rst new file mode 100644 index 00000000000..69aa554591d --- /dev/null +++ b/CHANGES/10003.bugfix.rst @@ -0,0 +1 @@ +Fixed the HTTP client not considering the connector's ``force_close`` value when setting the ``Connection`` header -- by :user:`bdraco`. diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 91605f0e83d..b847ec8a261 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -639,18 +639,6 @@ def update_proxy( proxy_headers = CIMultiDict(proxy_headers) self.proxy_headers = proxy_headers - def keep_alive(self) -> bool: - if self.version < HttpVersion10: - # keep alive not supported at all - return False - if self.version == HttpVersion10: - # no headers means we close for Http 1.0 - return self.headers.get(hdrs.CONNECTION) == "keep-alive" - elif self.headers.get(hdrs.CONNECTION) == "close": - return False - - return True - async def write_bytes( self, writer: AbstractStreamWriter, conn: "Connection" ) -> None: @@ -751,21 +739,15 @@ async def send(self, conn: "Connection") -> "ClientResponse": ): self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream" - # set the connection header - connection = self.headers.get(hdrs.CONNECTION) - if not connection: - if self.keep_alive(): - if self.version == HttpVersion10: - connection = "keep-alive" - else: - if self.version == HttpVersion11: - connection = "close" - - if connection is not None: - self.headers[hdrs.CONNECTION] = connection + v = self.version + if hdrs.CONNECTION not in self.headers: + if conn._connector.force_close: + if v == HttpVersion11: + self.headers[hdrs.CONNECTION] = "close" + elif v == HttpVersion10: + self.headers[hdrs.CONNECTION] = "keep-alive" # status + headers - v = self.version status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}" await writer.write_headers(status_line, self.headers) coro = self.write_bytes(writer, conn) diff --git a/tests/test_benchmarks_client_request.py b/tests/test_benchmarks_client_request.py index 3f132d04d14..2913aa0eb91 100644 --- a/tests/test_benchmarks_client_request.py +++ b/tests/test_benchmarks_client_request.py @@ -95,10 +95,16 @@ async def _drain_helper(self) -> None: def start_timeout(self) -> None: """Swallow start_timeout.""" + class MockConnector: + + def __init__(self) -> None: + self.force_close = False + class MockConnection: def __init__(self) -> None: self.transport = None self.protocol = MockProtocol() + self._connector = MockConnector() conn = MockConnection() diff --git a/tests/test_client_request.py b/tests/test_client_request.py index bf2fd4b7bc0..abb59641d0c 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -23,7 +23,7 @@ _gen_default_accept_encoding, _merge_ssl_params, ) -from aiohttp.http import HttpVersion +from aiohttp.http import HttpVersion10, HttpVersion11 from aiohttp.test_utils import make_mocked_coro @@ -140,30 +140,6 @@ def test_version_err(make_request) -> None: make_request("get", "http://python.org/", version="1.c") -def test_keep_alive(make_request) -> None: - req = make_request("get", "http://python.org/", version=(0, 9)) - assert not req.keep_alive() - - req = make_request("get", "http://python.org/", version=(1, 0)) - assert not req.keep_alive() - - req = make_request( - "get", - "http://python.org/", - version=(1, 0), - headers={"connection": "keep-alive"}, - ) - assert req.keep_alive() - - req = make_request("get", "http://python.org/", version=(1, 1)) - assert req.keep_alive() - - req = make_request( - "get", "http://python.org/", version=(1, 1), headers={"connection": "close"} - ) - assert not req.keep_alive() - - def test_host_port_default_http(make_request) -> None: req = make_request("get", "http://python.org/") assert req.host == "python.org" @@ -634,32 +610,40 @@ def test_gen_netloc_no_port(make_request) -> None: ) -async def test_connection_header(loop, conn) -> None: +async def test_connection_header( + loop: asyncio.AbstractEventLoop, conn: mock.Mock +) -> None: req = ClientRequest("get", URL("http://python.org"), loop=loop) - req.keep_alive = mock.Mock() req.headers.clear() - req.keep_alive.return_value = True - req.version = HttpVersion(1, 1) + req.version = HttpVersion11 req.headers.clear() - await req.send(conn) + with mock.patch.object(conn._connector, "force_close", False): + await req.send(conn) assert req.headers.get("CONNECTION") is None - req.version = HttpVersion(1, 0) + req.version = HttpVersion10 req.headers.clear() - await req.send(conn) + with mock.patch.object(conn._connector, "force_close", False): + await req.send(conn) assert req.headers.get("CONNECTION") == "keep-alive" - req.keep_alive.return_value = False - req.version = HttpVersion(1, 1) + req.version = HttpVersion11 req.headers.clear() - await req.send(conn) + with mock.patch.object(conn._connector, "force_close", True): + await req.send(conn) assert req.headers.get("CONNECTION") == "close" - await req.close() + req.version = HttpVersion10 + req.headers.clear() + with mock.patch.object(conn._connector, "force_close", True): + await req.send(conn) + assert not req.headers.get("CONNECTION") -async def test_no_content_length(loop, conn) -> None: +async def test_no_content_length( + loop: asyncio.AbstractEventLoop, conn: mock.Mock +) -> None: req = ClientRequest("get", URL("http://python.org"), loop=loop) resp = await req.send(conn) assert req.headers.get("CONTENT-LENGTH") is None diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index eadb43b1ecb..c61740dc61e 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -696,9 +696,8 @@ async def handler(request): await resp.release() -@pytest.mark.xfail -async def test_http10_keep_alive_default(aiohttp_client) -> None: - async def handler(request): +async def test_http10_keep_alive_default(aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application()