diff --git a/CHANGES/10003.bugfix.rst b/CHANGES/10003.bugfix.rst new file mode 100644 index 00000000000..69aa554591d --- /dev/null +++ b/CHANGES/10003.bugfix.rst @@ -0,0 +1 @@ +Fixed the HTTP client not considering the connector's ``force_close`` value when setting the ``Connection`` header -- by :user:`bdraco`. diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 267b509b0e6..a0fa093d92e 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -634,15 +634,6 @@ def update_proxy( proxy_headers = CIMultiDict(proxy_headers) self.proxy_headers = proxy_headers - def keep_alive(self) -> bool: - if self.version >= HttpVersion11: - return self.headers.get(hdrs.CONNECTION) != "close" - if self.version == HttpVersion10: - # no headers means we close for Http 1.0 - return self.headers.get(hdrs.CONNECTION) == "keep-alive" - # keep alive not supported at all - return False - async def write_bytes( self, writer: AbstractStreamWriter, conn: "Connection" ) -> None: @@ -737,21 +728,15 @@ async def send(self, conn: "Connection") -> "ClientResponse": ): self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream" - # set the connection header - connection = self.headers.get(hdrs.CONNECTION) - if not connection: - if self.keep_alive(): - if self.version == HttpVersion10: - connection = "keep-alive" - else: - if self.version == HttpVersion11: - connection = "close" - - if connection is not None: - self.headers[hdrs.CONNECTION] = connection + v = self.version + if hdrs.CONNECTION not in self.headers: + if conn._connector.force_close: + if v == HttpVersion11: + self.headers[hdrs.CONNECTION] = "close" + elif v == HttpVersion10: + self.headers[hdrs.CONNECTION] = "keep-alive" # status + headers - v = self.version status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}" await writer.write_headers(status_line, self.headers) task: Optional["asyncio.Task[None]"] diff --git a/tests/test_benchmarks_client_request.py b/tests/test_benchmarks_client_request.py index 0cdf1f2d776..65667995185 100644 --- a/tests/test_benchmarks_client_request.py +++ b/tests/test_benchmarks_client_request.py @@ -100,10 +100,16 @@ async def _drain_helper(self) -> None: def start_timeout(self) -> None: """Swallow start_timeout.""" + class MockConnector: + + def __init__(self) -> None: + self.force_close = False + class MockConnection: def __init__(self) -> None: self.transport = None self.protocol = MockProtocol() + self._connector = MockConnector() conn = MockConnection() diff --git a/tests/test_client_request.py b/tests/test_client_request.py index 20ccf6c03d1..324eddf7f6e 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -23,7 +23,7 @@ _gen_default_accept_encoding, _merge_ssl_params, ) -from aiohttp.http import HttpVersion +from aiohttp.http import HttpVersion10, HttpVersion11 from aiohttp.test_utils import make_mocked_coro @@ -141,30 +141,6 @@ def test_version_err(make_request) -> None: make_request("get", "http://python.org/", version="1.c") -def test_keep_alive(make_request) -> None: - req = make_request("get", "http://python.org/", version=(0, 9)) - assert not req.keep_alive() - - req = make_request("get", "http://python.org/", version=(1, 0)) - assert not req.keep_alive() - - req = make_request( - "get", - "http://python.org/", - version=(1, 0), - headers={"connection": "keep-alive"}, - ) - assert req.keep_alive() - - req = make_request("get", "http://python.org/", version=(1, 1)) - assert req.keep_alive() - - req = make_request( - "get", "http://python.org/", version=(1, 1), headers={"connection": "close"} - ) - assert not req.keep_alive() - - def test_host_port_default_http(make_request) -> None: req = make_request("get", "http://python.org/") assert req.host == "python.org" @@ -628,32 +604,40 @@ def test_gen_netloc_no_port(make_request) -> None: ) -async def test_connection_header(loop, conn) -> None: +async def test_connection_header( + loop: asyncio.AbstractEventLoop, conn: mock.Mock +) -> None: req = ClientRequest("get", URL("http://python.org"), loop=loop) - req.keep_alive = mock.Mock() req.headers.clear() - req.keep_alive.return_value = True - req.version = HttpVersion(1, 1) + req.version = HttpVersion11 req.headers.clear() - await req.send(conn) + with mock.patch.object(conn._connector, "force_close", False): + await req.send(conn) assert req.headers.get("CONNECTION") is None - req.version = HttpVersion(1, 0) + req.version = HttpVersion10 req.headers.clear() - await req.send(conn) + with mock.patch.object(conn._connector, "force_close", False): + await req.send(conn) assert req.headers.get("CONNECTION") == "keep-alive" - req.keep_alive.return_value = False - req.version = HttpVersion(1, 1) + req.version = HttpVersion11 req.headers.clear() - await req.send(conn) + with mock.patch.object(conn._connector, "force_close", True): + await req.send(conn) assert req.headers.get("CONNECTION") == "close" - await req.close() + req.version = HttpVersion10 + req.headers.clear() + with mock.patch.object(conn._connector, "force_close", True): + await req.send(conn) + assert not req.headers.get("CONNECTION") -async def test_no_content_length(loop, conn) -> None: +async def test_no_content_length( + loop: asyncio.AbstractEventLoop, conn: mock.Mock +) -> None: req = ClientRequest("get", URL("http://python.org"), loop=loop) resp = await req.send(conn) assert req.headers.get("CONTENT-LENGTH") is None diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index e46a23c5857..a3a990141a1 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -696,9 +696,8 @@ async def handler(request): await resp.release() -@pytest.mark.xfail -async def test_http10_keep_alive_default(aiohttp_client) -> None: - async def handler(request): +async def test_http10_keep_alive_default(aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> web.Response: return web.Response() app = web.Application()