Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Typecheck federation_client.py
Browse files Browse the repository at this point in the history
  • Loading branch information
David Robertson committed Apr 8, 2022
1 parent 6f33250 commit 3d03b4e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
4 changes: 3 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ files =
# https://docs.python.org/3/library/re.html#re.X
exclude = (?x)
^(
|scripts-dev/federation_client.py
|scripts-dev/release.py

|synapse/storage/databases/__init__.py
Expand Down Expand Up @@ -314,6 +313,9 @@ ignore_missing_imports = True
[mypy-signedjson.*]
ignore_missing_imports = True

[mypy-srvlookup.*]
ignore_missing_imports = True

[mypy-treq.*]
ignore_missing_imports = True

Expand Down
27 changes: 17 additions & 10 deletions scripts-dev/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import base64
import json
import sys
from typing import Any, Optional
from typing import Any, Dict, Optional, Tuple
from urllib import parse as urlparse

import requests
Expand All @@ -47,13 +47,14 @@
import srvlookup
import yaml
from requests.adapters import HTTPAdapter
from urllib3 import HTTPConnectionPool

# uncomment the following to enable debug logging of http requests
# from httplib import HTTPConnection
# HTTPConnection.debuglevel = 1


def encode_base64(input_bytes):
def encode_base64(input_bytes: bytes) -> str:
"""Encode bytes as a base64 string without any padding."""

input_len = len(input_bytes)
Expand All @@ -63,7 +64,7 @@ def encode_base64(input_bytes):
return output_string


def encode_canonical_json(value):
def encode_canonical_json(value: object) -> bytes:
return json.dumps(
value,
# Encode code-points outside of ASCII as UTF-8 rather than \u escapes
Expand Down Expand Up @@ -125,7 +126,7 @@ def request(

for key, sig in signed_json["signatures"][origin_name].items():
header = 'X-Matrix origin=%s,key="%s",sig="%s"' % (origin_name, key, sig)
authorization_headers.append(header.encode("ascii"))
authorization_headers.append(header)
print("Authorization: %s" % header, file=sys.stderr)

dest = "matrix://%s%s" % (destination, path)
Expand All @@ -134,7 +135,10 @@ def request(
s = requests.Session()
s.mount("matrix://", MatrixConnectionAdapter())

headers = {"Host": destination, "Authorization": authorization_headers[0]}
headers: Dict[str, str] = {
"Host": destination,
"Authorization": authorization_headers[0],
}

if method == "POST":
headers["Content-Type"] = "application/json"
Expand All @@ -149,7 +153,7 @@ def request(
)


def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="Signs and sends a federation request to a matrix homeserver"
)
Expand Down Expand Up @@ -207,6 +211,7 @@ def main():
if not args.server_name or not args.signing_key:
read_args_from_config(args)

assert isinstance(args.signing_key, str)
algorithm, version, key_base64 = args.signing_key.split()
key = signedjson.key.decode_signing_key_base64(algorithm, version, key_base64)

Expand All @@ -228,7 +233,7 @@ def main():
print("")


def read_args_from_config(args):
def read_args_from_config(args: argparse.Namespace) -> None:
with open(args.config, "r") as fh:
config = yaml.safe_load(fh)

Expand All @@ -245,7 +250,7 @@ def read_args_from_config(args):

class MatrixConnectionAdapter(HTTPAdapter):
@staticmethod
def lookup(s, skip_well_known=False):
def lookup(s: str, skip_well_known: bool = False) -> Tuple[str, int]:
if s[-1] == "]":
# ipv6 literal (with no port)
return s, 8448
Expand All @@ -271,7 +276,7 @@ def lookup(s, skip_well_known=False):
return s, 8448

@staticmethod
def get_well_known(server_name):
def get_well_known(server_name: str) -> Optional[str]:
uri = "https://%s/.well-known/matrix/server" % (server_name,)
print("fetching %s" % (uri,), file=sys.stderr)

Expand All @@ -294,7 +299,9 @@ def get_well_known(server_name):
print("Invalid response from %s: %s" % (uri, e), file=sys.stderr)
return None

def get_connection(self, url, proxies=None):
def get_connection(
self, url: str, proxies: Optional[Dict[str, str]] = None
) -> HTTPConnectionPool:
parsed = urlparse.urlparse(url)

(host, port) = self.lookup(parsed.netloc)
Expand Down

0 comments on commit 3d03b4e

Please sign in to comment.