Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints for state. #8140

Merged
merged 10 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8140.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `synapse.state`.
47 changes: 47 additions & 0 deletions stubs/frozendict.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Stub for frozendict.

from typing import (
Any,
Hashable,
Iterable,
Iterator,
Mapping,
overload,
Tuple,
TypeVar,
)

_KT = TypeVar("_KT", bound=Hashable) # Key type.
_VT = TypeVar("_VT") # Value type.

class frozendict(Mapping[_KT, _VT]):
@overload
def __init__(self, **kwargs: _VT) -> None: ...
@overload
def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
@overload
def __init__(
self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
) -> None: ...
def __getitem__(self, key: _KT) -> _VT: ...
def __contains__(self, key: Any) -> bool: ...
def copy(self, **add_or_replace: Any) -> frozendict: ...
def __iter__(self) -> Iterator[_KT]: ...
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
4 changes: 2 additions & 2 deletions synapse/federation/sender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,10 @@ async def send_read_receipt(self, receipt: ReadReceipt) -> None:
room_id = receipt.room_id

# Work out which remote servers should be poked and poke them.
domains = await self.state.get_current_hosts_in_room(room_id)
domains_set = await self.state.get_current_hosts_in_room(room_id)
domains = [
d
for d in domains
for d in domains_set
if d != self.server_name
and self._federation_shard_config.should_handle(self._instance_name, d)
]
Expand Down
10 changes: 6 additions & 4 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2134,10 +2134,10 @@ async def _check_for_soft_fail(
)
state_sets = list(state_sets.values())
state_sets.append(state)
current_state_ids = await self.state_handler.resolve_events(
current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
current_state_ids = {k: e.event_id for k, e in current_states.items()}
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
Expand All @@ -2149,9 +2149,11 @@ async def _check_for_soft_fail(

# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]

auth_events_map = await self.store.get_events(current_state_ids)
auth_events_map = await self.store.get_events(current_state_ids_list)
current_auth_events = {
(e.type, e.state_key): e for e in auth_events_map.values()
}
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
Expand Down Expand Up @@ -1318,7 +1318,7 @@ async def get_interested_parties(

async def get_interested_remotes(
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
) -> List[Tuple[List[str], List[UserPresenceState]]]:
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.

Expand All @@ -1334,7 +1334,7 @@ async def get_interested_remotes(
each tuple the list of UserPresenceState should be sent to each
destination
"""
hosts_and_states = []
hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]]

# First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote
Expand Down
20 changes: 12 additions & 8 deletions synapse/handlers/room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import abc
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union

from unpaddedbase64 import encode_base64

Expand All @@ -31,7 +31,15 @@
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
from synapse.types import (
Collection,
JsonDict,
Requester,
RoomAlias,
RoomID,
StateMap,
UserID,
)
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room

Expand Down Expand Up @@ -704,9 +712,7 @@ async def send_membership_event(
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)

async def _can_guest_join(
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
"""
Returns whether a guest can join a room based on its current state.
"""
Expand Down Expand Up @@ -909,9 +915,7 @@ async def _make_and_store_3pid_invite(
)
return stream_id

async def _is_host_in_room(
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
Expand Down
Loading