diff --git a/changelog.d/17894.misc b/changelog.d/17894.misc new file mode 100644 index 00000000000..dc1a7577abf --- /dev/null +++ b/changelog.d/17894.misc @@ -0,0 +1 @@ +Remove usage of internal header encoding API. diff --git a/synapse/http/proxy.py b/synapse/http/proxy.py index 97aa429e7d4..5cd990b0d07 100644 --- a/synapse/http/proxy.py +++ b/synapse/http/proxy.py @@ -51,25 +51,17 @@ # "Hop-by-hop" headers (as opposed to "end-to-end" headers) as defined by RFC2616 # section 13.5.1 and referenced in RFC9110 section 7.6.1. These are meant to only be # consumed by the immediate recipient and not be forwarded on. -HOP_BY_HOP_HEADERS = { - "Connection", - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Authorization", - "TE", - "Trailers", - "Transfer-Encoding", - "Upgrade", +HOP_BY_HOP_HEADERS_LOWERCASE = { + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", } - -if hasattr(Headers, "_canonicalNameCaps"): - # Twisted < 24.7.0rc1 - _canonicalHeaderName = Headers()._canonicalNameCaps # type: ignore[attr-defined] -else: - # Twisted >= 24.7.0rc1 - # But note that `_encodeName` still exists on prior versions, - # it just encodes differently - _canonicalHeaderName = Headers()._encodeName +assert all(header.lower() == header for header in HOP_BY_HOP_HEADERS_LOWERCASE) def parse_connection_header_value( @@ -92,12 +84,12 @@ def parse_connection_header_value( Returns: The set of header names that should not be copied over from the remote response. - The keys are capitalized in canonical capitalization. + The keys are lowercased. """ extra_headers_to_remove: Set[str] = set() if connection_header_value: extra_headers_to_remove = { - _canonicalHeaderName(connection_option.strip()).decode("ascii") + connection_option.decode("ascii").strip().lower() for connection_option in connection_header_value.split(b",") } @@ -194,7 +186,7 @@ def _send_response( # The `Connection` header also defines which headers should not be copied over. connection_header = response_headers.getRawHeaders(b"connection") - extra_headers_to_remove = parse_connection_header_value( + extra_headers_to_remove_lowercase = parse_connection_header_value( connection_header[0] if connection_header else None ) @@ -202,10 +194,10 @@ def _send_response( for k, v in response_headers.getAllRawHeaders(): # Do not copy over any hop-by-hop headers. These are meant to only be # consumed by the immediate recipient and not be forwarded on. - header_key = k.decode("ascii") + header_key_lowercase = k.decode("ascii").lower() if ( - header_key in HOP_BY_HOP_HEADERS - or header_key in extra_headers_to_remove + header_key_lowercase in HOP_BY_HOP_HEADERS_LOWERCASE + or header_key_lowercase in extra_headers_to_remove_lowercase ): continue diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py index 6588695e373..e34df54e13c 100644 --- a/tests/http/test_matrixfederationclient.py +++ b/tests/http/test_matrixfederationclient.py @@ -903,12 +903,19 @@ def test_proxy_requests_and_discards_hop_by_hop_headers(self) -> None: headers=Headers( { "Content-Type": ["application/json"], - "Connection": ["close, X-Foo, X-Bar"], + "X-Test": ["test"], + # Define some hop-by-hop headers (try with varying casing to + # make sure we still match-up the headers) + "Connection": ["close, X-fOo, X-Bar, X-baz"], # Should be removed because it's defined in the `Connection` header "X-Foo": ["foo"], "X-Bar": ["bar"], + # (not in canonical case) + "x-baZ": ["baz"], # Should be removed because it's a hop-by-hop header "Proxy-Authorization": "abcdef", + # Should be removed because it's a hop-by-hop header (not in canonical case) + "transfer-EnCoDiNg": "abcdef", } ), ) @@ -938,9 +945,17 @@ def test_proxy_requests_and_discards_hop_by_hop_headers(self) -> None: header_names = set(headers.keys()) # Make sure the response does not include the hop-by-hop headers - self.assertNotIn(b"X-Foo", header_names) - self.assertNotIn(b"X-Bar", header_names) - self.assertNotIn(b"Proxy-Authorization", header_names) + self.assertIncludes( + header_names, + { + b"Content-Type", + b"X-Test", + # Default headers from Twisted + b"Date", + b"Server", + }, + exact=True, + ) # Make sure the response is as expected back on the main worker self.assertEqual(res, {"foo": "bar"}) diff --git a/tests/http/test_proxy.py b/tests/http/test_proxy.py index 58952704948..7110dcf9f94 100644 --- a/tests/http/test_proxy.py +++ b/tests/http/test_proxy.py @@ -22,27 +22,42 @@ from parameterized import parameterized -from synapse.http.proxy import parse_connection_header_value +from synapse.http.proxy import ( + HOP_BY_HOP_HEADERS_LOWERCASE, + parse_connection_header_value, +) from tests.unittest import TestCase +def mix_case(s: str) -> str: + """ + Mix up the case of each character in the string (upper or lower case) + """ + return "".join(c.upper() if i % 2 == 0 else c.lower() for i, c in enumerate(s)) + + class ProxyTests(TestCase): @parameterized.expand( [ - [b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}], + [b"close, X-Foo, X-Bar", {"close", "x-foo", "x-bar"}], # No whitespace - [b"close,X-Foo,X-Bar", {"Close", "X-Foo", "X-Bar"}], + [b"close,X-Foo,X-Bar", {"close", "x-foo", "x-bar"}], # More whitespace - [b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}], + [b"close, X-Foo, X-Bar", {"close", "x-foo", "x-bar"}], # "close" directive in not the first position - [b"X-Foo, X-Bar, close", {"X-Foo", "X-Bar", "Close"}], + [b"X-Foo, X-Bar, close", {"x-foo", "x-bar", "close"}], # Normalizes header capitalization - [b"keep-alive, x-fOo, x-bAr", {"Keep-Alive", "X-Foo", "X-Bar"}], + [b"keep-alive, x-fOo, x-bAr", {"keep-alive", "x-foo", "x-bar"}], # Handles header names with whitespace [ b"keep-alive, x foo, x bar", - {"Keep-Alive", "X foo", "X bar"}, + {"keep-alive", "x foo", "x bar"}, + ], + # Make sure we handle all of the hop-by-hop headers + [ + mix_case(", ".join(HOP_BY_HOP_HEADERS_LOWERCASE)).encode("ascii"), + HOP_BY_HOP_HEADERS_LOWERCASE, ], ] ) @@ -54,7 +69,8 @@ def test_parse_connection_header_value( """ Tests that the connection header value is parsed correctly """ - self.assertEqual( + self.assertIncludes( expected_extra_headers_to_remove, parse_connection_header_value(connection_header_value), + exact=True, )