Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix client connection header not reflecting connector force_close value #10003

Merged
merged 13 commits into from
Nov 20, 2024
1 change: 1 addition & 0 deletions CHANGES/10003.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed the HTTP client not considering the connector's ``force_close`` value when setting the ``Connection`` header -- by :user:`bdraco`.
29 changes: 7 additions & 22 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,15 +572,6 @@ def update_proxy(
proxy_headers = CIMultiDict(proxy_headers)
self.proxy_headers = proxy_headers

def keep_alive(self) -> bool:
if self.version >= HttpVersion11:
return self.headers.get(hdrs.CONNECTION) != "close"
if self.version == HttpVersion10:
# no headers means we close for Http 1.0
return self.headers.get(hdrs.CONNECTION) == "keep-alive"
# keep alive not supported at all
return False

async def write_bytes(
self, writer: AbstractStreamWriter, conn: "Connection"
) -> None:
Expand Down Expand Up @@ -678,21 +669,15 @@ async def send(self, conn: "Connection") -> "ClientResponse":
):
self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream"

# set the connection header
connection = self.headers.get(hdrs.CONNECTION)
if not connection:
if self.keep_alive():
if self.version == HttpVersion10:
connection = "keep-alive"
else:
if self.version == HttpVersion11:
connection = "close"

if connection is not None:
self.headers[hdrs.CONNECTION] = connection
v = self.version
if hdrs.CONNECTION not in self.headers:
if conn._connector.force_close:
if v == HttpVersion11:
self.headers[hdrs.CONNECTION] = "close"
elif v == HttpVersion10:
self.headers[hdrs.CONNECTION] = "keep-alive"

# status + headers
v = self.version
status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}"
await writer.write_headers(status_line, self.headers)
task: Optional["asyncio.Task[None]"]
Expand Down
6 changes: 6 additions & 0 deletions tests/test_benchmarks_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,16 @@ async def _drain_helper(self) -> None:
def start_timeout(self) -> None:
"""Swallow start_timeout."""

class MockConnector:

def __init__(self) -> None:
self.force_close = False

class MockConnection:
def __init__(self) -> None:
self.transport = None
self.protocol = MockProtocol()
self._connector = MockConnector()

conn = MockConnection()

Expand Down
58 changes: 20 additions & 38 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
_gen_default_accept_encoding,
)
from aiohttp.connector import Connection
from aiohttp.http import HttpVersion
from aiohttp.http import HttpVersion10, HttpVersion11
from aiohttp.test_utils import make_mocked_coro
from aiohttp.typedefs import LooseCookies

Expand Down Expand Up @@ -156,30 +156,6 @@ def test_version_err(make_request: _RequestMaker) -> None:
make_request("get", "http://python.org/", version="1.c")


def test_keep_alive(make_request: _RequestMaker) -> None:
req = make_request("get", "http://python.org/", version=(0, 9))
assert not req.keep_alive()

req = make_request("get", "http://python.org/", version=(1, 0))
assert not req.keep_alive()

req = make_request(
"get",
"http://python.org/",
version=(1, 0),
headers={"connection": "keep-alive"},
)
assert req.keep_alive()

req = make_request("get", "http://python.org/", version=(1, 1))
assert req.keep_alive()

req = make_request(
"get", "http://python.org/", version=(1, 1), headers={"connection": "close"}
)
assert not req.keep_alive()


def test_host_port_default_http(make_request: _RequestMaker) -> None:
req = make_request("get", "http://python.org/")
assert req.host == "python.org"
Expand Down Expand Up @@ -624,25 +600,31 @@ async def test_connection_header(
loop: asyncio.AbstractEventLoop, conn: mock.Mock
) -> None:
req = ClientRequest("get", URL("http://python.org"), loop=loop)
with mock.patch.object(req, "keep_alive") as m:
req.headers.clear()
req.headers.clear()

req.version = HttpVersion11
req.headers.clear()
with mock.patch.object(conn._connector, "force_close", False):
await req.send(conn)
assert req.headers.get("CONNECTION") is None

m.return_value = True
req.version = HttpVersion(1, 1)
req.headers.clear()
req.version = HttpVersion10
req.headers.clear()
with mock.patch.object(conn._connector, "force_close", False):
await req.send(conn)
assert req.headers.get("CONNECTION") is None
assert req.headers.get("CONNECTION") == "keep-alive"

req.version = HttpVersion(1, 0)
req.headers.clear()
req.version = HttpVersion11
req.headers.clear()
with mock.patch.object(conn._connector, "force_close", True):
await req.send(conn)
assert req.headers.get("CONNECTION") == "keep-alive"
assert req.headers.get("CONNECTION") == "close"

m.return_value = False
req.version = HttpVersion(1, 1)
req.headers.clear()
req.version = HttpVersion10
req.headers.clear()
with mock.patch.object(conn._connector, "force_close", True):
await req.send(conn)
assert req.headers.get("CONNECTION") == "close"
assert not req.headers.get("CONNECTION")


async def test_no_content_length(
Expand Down
1 change: 0 additions & 1 deletion tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,6 @@ async def handler(request: web.Request) -> web.Response:
resp.release()


@pytest.mark.xfail
async def test_http10_keep_alive_default(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
return web.Response()
Expand Down
Loading