Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Unify HTTP query parameter type hints #12415

Merged
merged 10 commits into from
Apr 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12415.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints related to HTTP query parameters.
3 changes: 2 additions & 1 deletion synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.transport.client import SendJoinResponse
from synapse.http.types import QueryParams
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
Expand Down Expand Up @@ -154,7 +155,7 @@ async def make_query(
self,
destination: str,
query_type: str,
args: dict,
args: QueryParams,
retry_on_dns_fail: bool = False,
ignore_backoff: bool = False,
) -> JsonDict:
Expand Down
5 changes: 3 additions & 2 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser
from synapse.http.types import QueryParams
from synapse.types import JsonDict

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -255,7 +256,7 @@ async def make_query(
self,
destination: str,
query_type: str,
args: dict,
args: QueryParams,
retry_on_dns_fail: bool,
ignore_backoff: bool = False,
prefix: str = FEDERATION_V1_PREFIX,
Expand Down Expand Up @@ -503,7 +504,7 @@ async def get_public_rooms(
else:
path = _create_v1_path("/publicRooms")

args: Dict[str, Any] = {
args: Dict[str, Union[str, Iterable[str]]] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to use MutableMapping for QueryParams (or have a MutableQueryParams)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I prefer it keeping the (immutable) Mapping. The non-mutable version expresses that functions taking a QueryParams won't mutate those parameters. It's inconvenient here if you conditionally build up a parameter dictionary, but that's rare: we almost always build a parameter dictionary args as a dictionary literal and pass it immediately to an HTTP function (without further mutating args).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems fine! 👍

"include_all_networks": "true" if include_all_networks else "false"
}
if third_party_instance_id:
Expand Down
16 changes: 3 additions & 13 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
BinaryIO,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Expand Down Expand Up @@ -72,6 +71,7 @@
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
from synapse.http.proxyagent import ProxyAgent
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor
Expand All @@ -97,10 +97,6 @@
# the entries can either be Lists or bytes.
RawHeaderValue = Sequence[Union[str, bytes]]

# the type of the query params, to be passed into `urlencode`
QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]


def check_against_blacklist(
ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
Expand Down Expand Up @@ -911,7 +907,7 @@ def read_body_with_max_size(
return d


def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
def encode_query_args(args: Optional[QueryParams]) -> bytes:
"""
Encodes a map of query arguments to bytes which can be appended to a URL.

Expand All @@ -924,13 +920,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
if args is None:
return b""

encoded_args = {}
for k, vs in args.items():
if isinstance(vs, str):
vs = [vs]
encoded_args[k] = [v.encode("utf8") for v in vs]

query_str = urllib.parse.urlencode(encoded_args, True)
query_str = urllib.parse.urlencode(args, True)

return query_str.encode("utf8")

Expand Down
35 changes: 13 additions & 22 deletions synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
read_body_with_max_size,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.types import QueryParams
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
Expand Down Expand Up @@ -98,10 +99,6 @@

_next_id = 1


QueryArgs = Dict[str, Union[str, List[str]]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We pretty much just had two implementations of the same thing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct: one from #8806 and #8372. I think yourself and Rich basically ended up at the same solution.



T = TypeVar("T")


Expand Down Expand Up @@ -144,7 +141,7 @@ class MatrixFederationRequest:
"""A callback to generate the JSON.
"""

query: Optional[dict] = None
query: Optional[QueryParams] = None
"""Query arguments.
"""

Expand All @@ -165,10 +162,7 @@ def __attrs_post_init__(self) -> None:

destination_bytes = self.destination.encode("ascii")
path_bytes = self.path.encode("ascii")
if self.query:
query_bytes = encode_query_args(self.query)
else:
query_bytes = b""
Comment on lines -168 to -171
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this also work properly if you pass an empty dictionary into encode_query_args?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hadn't thought of this, but fortunately it will work:

$ python
Python 3.10.4 (main, Mar 25 2022, 00:00:00) [GCC 11.2.1 20220127 (Red Hat 11.2.1-9)] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import urllib.parse
>>> urllib.parse.urlencode({}, True)
""

query_bytes = encode_query_args(self.query)

# The object is frozen so we can pre-compute this.
uri = urllib.parse.urlunparse(
Expand Down Expand Up @@ -485,10 +479,7 @@ async def _send_request(
method_bytes = request.method.encode("ascii")
destination_bytes = request.destination.encode("ascii")
path_bytes = request.path.encode("ascii")
if request.query:
query_bytes = encode_query_args(request.query)
else:
query_bytes = b""
query_bytes = encode_query_args(request.query)

scope = start_active_span(
"outgoing-federation-request",
Expand Down Expand Up @@ -746,7 +737,7 @@ async def put_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
data: Optional[JsonDict] = None,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
long_retries: bool = False,
Expand All @@ -764,7 +755,7 @@ async def put_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
data: Optional[JsonDict] = None,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
long_retries: bool = False,
Expand All @@ -781,7 +772,7 @@ async def put_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
data: Optional[JsonDict] = None,
json_data_callback: Optional[Callable[[], JsonDict]] = None,
long_retries: bool = False,
Expand Down Expand Up @@ -891,7 +882,7 @@ async def post_json(
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]:
"""Sends the specified json data using POST

Expand Down Expand Up @@ -961,7 +952,7 @@ async def get_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
Expand All @@ -976,7 +967,7 @@ async def get_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = ...,
args: Optional[QueryParams] = ...,
retry_on_dns_fail: bool = ...,
timeout: Optional[int] = ...,
ignore_backoff: bool = ...,
Expand All @@ -990,7 +981,7 @@ async def get_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
Expand Down Expand Up @@ -1085,7 +1076,7 @@ async def delete_json(
long_retries: bool = False,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
) -> Union[JsonDict, list]:
"""Send a DELETE request to the remote expecting some json response

Expand Down Expand Up @@ -1150,7 +1141,7 @@ async def get_file(
destination: str,
path: str,
output_stream,
args: Optional[QueryArgs] = None,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
ignore_backoff: bool = False,
Expand Down
21 changes: 21 additions & 0 deletions synapse/http/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Iterable, Mapping, Union
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved

# the type of the query params, to be passed into `urlencode` with `doseq=True`.
QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]

__all__ = ["QueryParams"]