Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add type hints to various handlers. (#9223)
Browse files Browse the repository at this point in the history
With this change all handlers except the e2e_* ones have
type hints enabled.
  • Loading branch information
clokep authored Jan 26, 2021
1 parent 26837d5 commit 1baab20
Show file tree
Hide file tree
Showing 14 changed files with 205 additions and 138 deletions.
1 change: 1 addition & 0 deletions changelog.d/9223.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to handlers code.
14 changes: 14 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ files =
synapse/handlers/_base.py,
synapse/handlers/account_data.py,
synapse/handlers/account_validity.py,
synapse/handlers/acme.py,
synapse/handlers/acme_issuing_service.py,
synapse/handlers/admin.py,
synapse/handlers/appservice.py,
synapse/handlers/auth.py,
Expand All @@ -36,6 +38,7 @@ files =
synapse/handlers/directory.py,
synapse/handlers/events.py,
synapse/handlers/federation.py,
synapse/handlers/groups_local.py,
synapse/handlers/identity.py,
synapse/handlers/initial_sync.py,
synapse/handlers/message.py,
Expand All @@ -52,8 +55,13 @@ files =
synapse/handlers/room_member.py,
synapse/handlers/room_member_worker.py,
synapse/handlers/saml_handler.py,
synapse/handlers/search.py,
synapse/handlers/set_password.py,
synapse/handlers/sso.py,
synapse/handlers/state_deltas.py,
synapse/handlers/stats.py,
synapse/handlers/sync.py,
synapse/handlers/typing.py,
synapse/handlers/user_directory.py,
synapse/handlers/ui_auth,
synapse/http/client.py,
Expand Down Expand Up @@ -194,3 +202,9 @@ ignore_missing_imports = True

[mypy-hiredis]
ignore_missing_imports = True

[mypy-josepy.*]
ignore_missing_imports = True

[mypy-txacme.*]
ignore_missing_imports = True
12 changes: 7 additions & 5 deletions synapse/handlers/acme.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING

import twisted
import twisted.internet.error
Expand All @@ -22,6 +23,9 @@

from synapse.app import check_bind_error

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)

ACME_REGISTER_FAIL_ERROR = """
Expand All @@ -35,12 +39,12 @@


class AcmeHandler:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.reactor = hs.get_reactor()
self._acme_domain = hs.config.acme_domain

async def start_listening(self):
async def start_listening(self) -> None:
from synapse.handlers import acme_issuing_service

# Configure logging for txacme, if you need to debug
Expand Down Expand Up @@ -85,7 +89,7 @@ async def start_listening(self):
logger.error(ACME_REGISTER_FAIL_ERROR)
raise

async def provision_certificate(self):
async def provision_certificate(self) -> None:

logger.warning("Reprovisioning %s", self._acme_domain)

Expand All @@ -110,5 +114,3 @@ async def provision_certificate(self):
except Exception:
logger.exception("Failed saving!")
raise

return True
27 changes: 19 additions & 8 deletions synapse/handlers/acme_issuing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
imported conditionally.
"""
import logging
from typing import Dict, Iterable, List

import attr
import pem
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from josepy import JWKRSA
Expand All @@ -36,20 +38,27 @@
from zope.interface import implementer

from twisted.internet import defer
from twisted.internet.interfaces import IReactorTCP
from twisted.python.filepath import FilePath
from twisted.python.url import URL
from twisted.web.resource import IResource

logger = logging.getLogger(__name__)


def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
def create_issuing_service(
reactor: IReactorTCP,
acme_url: str,
account_key_file: str,
well_known_resource: IResource,
) -> AcmeIssuingService:
"""Create an ACME issuing service, and attach it to a web Resource
Args:
reactor: twisted reactor
acme_url (str): URL to use to request certificates
account_key_file (str): where to store the account key
well_known_resource (twisted.web.IResource): web resource for .well-known.
acme_url: URL to use to request certificates
account_key_file: where to store the account key
well_known_resource: web resource for .well-known.
we will attach a child resource for "acme-challenge".
Returns:
Expand Down Expand Up @@ -83,18 +92,20 @@ class ErsatzStore:
A store that only stores in memory.
"""

certs = attr.ib(default=attr.Factory(dict))
certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))

def store(self, server_name, pem_objects):
def store(
self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
) -> defer.Deferred:
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
return defer.succeed(None)


def load_or_create_client_key(key_file):
def load_or_create_client_key(key_file: str) -> JWKRSA:
"""Load the ACME account key from a file, creating it if it does not exist.
Args:
key_file (str): name of the file to use as the account key
key_file: name of the file to use as the account key
"""
# this is based on txacme.endpoint.load_or_create_client_key, but doesn't
# hardcode the 'client.key' filename
Expand Down
83 changes: 42 additions & 41 deletions synapse/handlers/groups_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Set

from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import GroupID, get_domain_from_id
from synapse.types import GroupID, JsonDict, get_domain_from_id

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,7 +60,7 @@ async def f(self, group_id, *args, **kwargs):


class GroupsLocalWorkerHandler:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.room_list_handler = hs.get_room_list_handler()
Expand Down Expand Up @@ -84,7 +88,9 @@ def __init__(self, hs):
get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles")

async def get_group_summary(self, group_id, requester_user_id):
async def get_group_summary(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get the group summary for a group.
If the group is remote we check that the users have valid attestations.
Expand Down Expand Up @@ -137,14 +143,15 @@ async def get_group_summary(self, group_id, requester_user_id):

return res

async def get_users_in_group(self, group_id, requester_user_id):
async def get_users_in_group(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get users in a group
"""
if self.is_mine_id(group_id):
res = await self.groups_server_handler.get_users_in_group(
return await self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
)
return res

group_server_name = get_domain_from_id(group_id)

Expand Down Expand Up @@ -178,11 +185,11 @@ async def get_users_in_group(self, group_id, requester_user_id):

return res

async def get_joined_groups(self, user_id):
async def get_joined_groups(self, user_id: str) -> JsonDict:
group_ids = await self.store.get_joined_groups(user_id)
return {"groups": group_ids}

async def get_publicised_groups_for_user(self, user_id):
async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
if self.hs.is_mine_id(user_id):
result = await self.store.get_publicised_groups_for_user(user_id)

Expand All @@ -206,8 +213,10 @@ async def get_publicised_groups_for_user(self, user_id):
# TODO: Verify attestations
return {"groups": result}

async def bulk_get_publicised_groups(self, user_ids, proxy=True):
destinations = {}
async def bulk_get_publicised_groups(
self, user_ids: Iterable[str], proxy: bool = True
) -> JsonDict:
destinations = {} # type: Dict[str, Set[str]]
local_users = set()

for user_id in user_ids:
Expand All @@ -220,7 +229,7 @@ async def bulk_get_publicised_groups(self, user_ids, proxy=True):
raise SynapseError(400, "Some user_ids are not local")

results = {}
failed_results = []
failed_results = [] # type: List[str]
for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(
Expand All @@ -242,7 +251,7 @@ async def bulk_get_publicised_groups(self, user_ids, proxy=True):


class GroupsLocalHandler(GroupsLocalWorkerHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

# Ensure attestations get renewed
Expand Down Expand Up @@ -271,7 +280,9 @@ def __init__(self, hs):

set_group_join_policy = _create_rerouter("set_group_join_policy")

async def create_group(self, group_id, user_id, content):
async def create_group(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Create a group
"""

Expand All @@ -284,27 +295,7 @@ async def create_group(self, group_id, user_id, content):
local_attestation = None
remote_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation

content["user_profile"] = await self.profile_handler.get_profile(user_id)

try:
res = await self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content
)
except HttpResponseException as e:
raise e.to_synapse_error()
except RequestSendFailed:
raise SynapseError(502, "Failed to contact group server")

remote_attestation = res["attestation"]
await self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
server_name=get_domain_from_id(group_id),
)
raise SynapseError(400, "Unable to create remote groups")

is_publicised = content.get("publicise", False)
token = await self.store.register_user_group_membership(
Expand All @@ -320,7 +311,9 @@ async def create_group(self, group_id, user_id, content):

return res

async def join_group(self, group_id, user_id, content):
async def join_group(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Request to join a group
"""
if self.is_mine_id(group_id):
Expand Down Expand Up @@ -365,7 +358,9 @@ async def join_group(self, group_id, user_id, content):

return {}

async def accept_invite(self, group_id, user_id, content):
async def accept_invite(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
Expand Down Expand Up @@ -410,7 +405,9 @@ async def accept_invite(self, group_id, user_id, content):

return {}

async def invite(self, group_id, user_id, requester_user_id, config):
async def invite(
self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
) -> JsonDict:
"""Invite a user to a group
"""
content = {"requester_user_id": requester_user_id, "config": config}
Expand All @@ -434,7 +431,9 @@ async def invite(self, group_id, user_id, requester_user_id, config):

return res

async def on_invite(self, group_id, user_id, content):
async def on_invite(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""One of our users were invited to a group
"""
# TODO: Support auto join and rejection
Expand Down Expand Up @@ -465,8 +464,8 @@ async def on_invite(self, group_id, user_id, content):
return {"state": "invite", "user_profile": user_profile}

async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
"""Remove a user from a group
"""
if user_id == requester_user_id:
Expand Down Expand Up @@ -499,7 +498,9 @@ async def remove_user_from_group(

return res

async def user_removed_from_group(self, group_id, user_id, content):
async def user_removed_from_group(
self, group_id: str, user_id: str, content: JsonDict
) -> None:
"""One of our users was removed/kicked from a group
"""
# TODO: Check if user in group
Expand Down
Loading

0 comments on commit 1baab20

Please sign in to comment.