diff --git a/changelog.d/12415.misc b/changelog.d/12415.misc new file mode 100644 index 000000000000..87a5bae5724a --- /dev/null +++ b/changelog.d/12415.misc @@ -0,0 +1 @@ +Improve type hints related to HTTP query parameters. \ No newline at end of file diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 467275b98c3c..6a59cb4b713e 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -56,6 +56,7 @@ from synapse.events import EventBase, builder from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.transport.client import SendJoinResponse +from synapse.http.types import QueryParams from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache @@ -154,7 +155,7 @@ async def make_query( self, destination: str, query_type: str, - args: dict, + args: QueryParams, retry_on_dns_fail: bool = False, ignore_backoff: bool = False, ) -> JsonDict: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 66a41e45fc2d..01dc5ca94f99 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -44,6 +44,7 @@ from synapse.events import EventBase, make_event_from_dict from synapse.federation.units import Transaction from synapse.http.matrixfederationclient import ByteParser +from synapse.http.types import QueryParams from synapse.types import JsonDict logger = logging.getLogger(__name__) @@ -255,7 +256,7 @@ async def make_query( self, destination: str, query_type: str, - args: dict, + args: QueryParams, retry_on_dns_fail: bool, ignore_backoff: bool = False, prefix: str = FEDERATION_V1_PREFIX, @@ -503,7 +504,7 @@ async def get_public_rooms( else: path = _create_v1_path("/publicRooms") - args: Dict[str, Any] = { + args: Dict[str, Union[str, Iterable[str]]] = { "include_all_networks": "true" if include_all_networks else "false" } if third_party_instance_id: diff --git a/synapse/http/client.py b/synapse/http/client.py index c01d2326cf33..8310fb466ac5 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -22,7 +22,6 @@ BinaryIO, Callable, Dict, - Iterable, List, Mapping, Optional, @@ -72,6 +71,7 @@ from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri from synapse.http.proxyagent import ProxyAgent +from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.types import ISynapseReactor @@ -97,10 +97,6 @@ # the entries can either be Lists or bytes. RawHeaderValue = Sequence[Union[str, bytes]] -# the type of the query params, to be passed into `urlencode` -QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]] -QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]] - def check_against_blacklist( ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet @@ -911,7 +907,7 @@ def read_body_with_max_size( return d -def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes: +def encode_query_args(args: Optional[QueryParams]) -> bytes: """ Encodes a map of query arguments to bytes which can be appended to a URL. @@ -924,13 +920,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by if args is None: return b"" - encoded_args = {} - for k, vs in args.items(): - if isinstance(vs, str): - vs = [vs] - encoded_args[k] = [v.encode("utf8") for v in vs] - - query_str = urllib.parse.urlencode(encoded_args, True) + query_str = urllib.parse.urlencode(args, True) return query_str.encode("utf8") diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 6b98d865f5bb..5097b3ca5796 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -67,6 +67,7 @@ read_body_with_max_size, ) from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent +from synapse.http.types import QueryParams from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag, start_active_span, tags @@ -98,10 +99,6 @@ _next_id = 1 - -QueryArgs = Dict[str, Union[str, List[str]]] - - T = TypeVar("T") @@ -144,7 +141,7 @@ class MatrixFederationRequest: """A callback to generate the JSON. """ - query: Optional[dict] = None + query: Optional[QueryParams] = None """Query arguments. """ @@ -165,10 +162,7 @@ def __attrs_post_init__(self) -> None: destination_bytes = self.destination.encode("ascii") path_bytes = self.path.encode("ascii") - if self.query: - query_bytes = encode_query_args(self.query) - else: - query_bytes = b"" + query_bytes = encode_query_args(self.query) # The object is frozen so we can pre-compute this. uri = urllib.parse.urlunparse( @@ -485,10 +479,7 @@ async def _send_request( method_bytes = request.method.encode("ascii") destination_bytes = request.destination.encode("ascii") path_bytes = request.path.encode("ascii") - if request.query: - query_bytes = encode_query_args(request.query) - else: - query_bytes = b"" + query_bytes = encode_query_args(request.query) scope = start_active_span( "outgoing-federation-request", @@ -746,7 +737,7 @@ async def put_json( self, destination: str, path: str, - args: Optional[QueryArgs] = None, + args: Optional[QueryParams] = None, data: Optional[JsonDict] = None, json_data_callback: Optional[Callable[[], JsonDict]] = None, long_retries: bool = False, @@ -764,7 +755,7 @@ async def put_json( self, destination: str, path: str, - args: Optional[QueryArgs] = None, + args: Optional[QueryParams] = None, data: Optional[JsonDict] = None, json_data_callback: Optional[Callable[[], JsonDict]] = None, long_retries: bool = False, @@ -781,7 +772,7 @@ async def put_json( self, destination: str, path: str, - args: Optional[QueryArgs] = None, + args: Optional[QueryParams] = None, data: Optional[JsonDict] = None, json_data_callback: Optional[Callable[[], JsonDict]] = None, long_retries: bool = False, @@ -891,7 +882,7 @@ async def post_json( long_retries: bool = False, timeout: Optional[int] = None, ignore_backoff: bool = False, - args: Optional[QueryArgs] = None, + args: Optional[QueryParams] = None, ) -> Union[JsonDict, list]: """Sends the specified json data using POST @@ -961,7 +952,7 @@ async def get_json( self, destination: str, path: str, - args: Optional[QueryArgs] = None, + args: Optional[QueryParams] = None, retry_on_dns_fail: bool = True, timeout: Optional[int] = None, ignore_backoff: bool = False, @@ -976,7 +967,7 @@ async def get_json( self, destination: str, path: str, - args: Optional[QueryArgs] = ..., + args: Optional[QueryParams] = ..., retry_on_dns_fail: bool = ..., timeout: Optional[int] = ..., ignore_backoff: bool = ..., @@ -990,7 +981,7 @@ async def get_json( self, destination: str, path: str, - args: Optional[QueryArgs] = None, + args: Optional[QueryParams] = None, retry_on_dns_fail: bool = True, timeout: Optional[int] = None, ignore_backoff: bool = False, @@ -1085,7 +1076,7 @@ async def delete_json( long_retries: bool = False, timeout: Optional[int] = None, ignore_backoff: bool = False, - args: Optional[QueryArgs] = None, + args: Optional[QueryParams] = None, ) -> Union[JsonDict, list]: """Send a DELETE request to the remote expecting some json response @@ -1150,7 +1141,7 @@ async def get_file( destination: str, path: str, output_stream, - args: Optional[QueryArgs] = None, + args: Optional[QueryParams] = None, retry_on_dns_fail: bool = True, max_size: Optional[int] = None, ignore_backoff: bool = False, diff --git a/synapse/http/types.py b/synapse/http/types.py new file mode 100644 index 000000000000..11fe232d77cc --- /dev/null +++ b/synapse/http/types.py @@ -0,0 +1,21 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Iterable, Mapping, Union + +# the type of the query params, to be passed into `urlencode` with `doseq=True`. +QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]] +QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]] + +__all__ = ["QueryParams"]