Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Ensure str header values in connection.py #142

Merged
merged 2 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/urllib3/_async/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,13 @@ async def request(
value = str(content_length)
if enforce_charset_transparency and header.lower() == "content-type":
value_lower = value.lower()
# even if not "officially" supported
# some may send values as bytes, and we have to
# cast "temporarily" the value
# this case is already covered in the parent class.
if isinstance(value_lower, bytes):
value_lower = value_lower.decode()
value = value.decode()
if "charset" not in value_lower:
value = value.strip("; ")
value = f"{value}; charset=utf-8"
Expand Down
7 changes: 7 additions & 0 deletions src/urllib3/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,13 @@ def request(
value = str(content_length)
if enforce_charset_transparency and header.lower() == "content-type":
value_lower = value.lower()
# even if not "officially" supported
# some may send values as bytes, and we have to
# cast "temporarily" the value
# this case is already covered in the parent class.
if isinstance(value_lower, bytes):
value_lower = value_lower.decode()
value = value.decode()
if "charset" not in value_lower:
value = value.strip("; ")
value = f"{value}; charset=utf-8"
Expand Down
36 changes: 36 additions & 0 deletions test/with_dummyserver/test_socketlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2435,6 +2435,42 @@ def socket_handler(listener: socket.socket) -> None:

assert b"Content-Type: application/json; charset=utf-8\r\n" in sent_bytes

def test_partial_overrule_bytes_content_type(self) -> None:
buffer = bytearray()

def socket_handler(listener: socket.socket) -> None:
nonlocal buffer
sock = listener.accept()[0]
sock.settimeout(0)

start = time.time()
while time.time() - start < (LONG_TIMEOUT / 2):
try:
buffer += sock.recv(65536)
except OSError:
continue

sock.sendall(
b"HTTP/1.1 200 OK\r\n"
b"Server: example.com\r\n"
b"Content-Length: 0\r\n\r\n"
)
sock.close()

self._start_server(socket_handler)

with HTTPConnectionPool(
self.host, self.port, timeout=LONG_TIMEOUT, retries=False
) as pool:
resp = pool.request(
"POST", "/", body="{}", headers={"Content-Type": b"application/json"} # type: ignore[dict-item]
)
assert resp.status == 200

sent_bytes = bytes(buffer)

assert b"Content-Type: application/json; charset=utf-8\r\n" in sent_bytes

def test_no_overrule_str_content_type(self) -> None:
buffer = bytearray()

Expand Down
24 changes: 24 additions & 0 deletions test/with_traefik/asynchronous/test_send_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,30 @@ async def test_overrule_unicode_content_length(self) -> None:
assert "Content-Length" in (await resp.json())["headers"]
assert (await resp.json())["headers"]["Content-Length"][0] == "4"

async def test_overrule_unicode_content_length_with_bytes_content_type(
self,
) -> None:
async with AsyncHTTPSConnectionPool(
self.host,
self.https_port,
ca_certs=self.ca_authority,
resolver=self.test_async_resolver,
) as p:
resp = await p.request(
"POST",
"/post",
body="🚀",
headers={"Content-Length": "1", "Content-Type": b"plain/text"}, # type: ignore[dict-item]
)

assert resp.status == 200
assert "Content-Length" in (await resp.json())["headers"]
assert "Content-Type" in (await resp.json())["headers"]
assert (await resp.json())["headers"]["Content-Type"][
0
] == "plain/text; charset=utf-8"
assert (await resp.json())["headers"]["Content-Length"][0] == "4"

@pytest.mark.parametrize(
"method",
[
Expand Down
24 changes: 24 additions & 0 deletions test/with_traefik/test_send_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,30 @@ def test_overrule_unicode_content_length(self) -> None:
assert "Content-Length" in resp.json()["headers"]
assert resp.json()["headers"]["Content-Length"][0] == "4"

def test_overrule_unicode_content_length_with_bytes_content_type(
self,
) -> None:
with HTTPSConnectionPool(
self.host,
self.https_port,
ca_certs=self.ca_authority,
resolver=self.test_resolver,
) as p:
resp = p.request(
"POST",
"/post",
body="🚀",
headers={"Content-Length": "1", "Content-Type": b"plain/text"}, # type: ignore[dict-item]
)

assert resp.status == 200
assert "Content-Length" in resp.json()["headers"]
assert "Content-Type" in resp.json()["headers"]
assert (
resp.json()["headers"]["Content-Type"][0] == "plain/text; charset=utf-8"
)
assert resp.json()["headers"]["Content-Length"][0] == "4"

@pytest.mark.parametrize(
"method",
[
Expand Down
Loading