Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints for state. #8140

Merged
merged 10 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8140.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `synapse.state`.
47 changes: 47 additions & 0 deletions stubs/frozendict.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Stub for frozendict.

from typing import (
Any,
Hashable,
Iterable,
Iterator,
Mapping,
overload,
Tuple,
TypeVar,
)

_KT = TypeVar("_KT", bound=Hashable) # Key type.
_VT = TypeVar("_VT") # Value type.

class frozendict(Mapping[_KT, _VT]):
@overload
def __init__(self, **kwargs: _VT) -> None: ...
@overload
def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
@overload
def __init__(
self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
) -> None: ...
def __getitem__(self, key: _KT) -> _VT: ...
def __contains__(self, key: Any) -> bool: ...
def copy(self, **add_or_replace: Any) -> frozendict: ...
def __iter__(self) -> Iterator[_KT]: ...
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
4 changes: 2 additions & 2 deletions synapse/federation/sender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,10 @@ async def send_read_receipt(self, receipt: ReadReceipt) -> None:
room_id = receipt.room_id

# Work out which remote servers should be poked and poke them.
domains = await self.state.get_current_hosts_in_room(room_id)
domains_set = await self.state.get_current_hosts_in_room(room_id)
domains = [
d
for d in domains
for d in domains_set
if d != self.server_name
and self._federation_shard_config.should_handle(self._instance_name, d)
]
Expand Down
10 changes: 6 additions & 4 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2134,10 +2134,10 @@ async def _check_for_soft_fail(
)
state_sets = list(state_sets.values())
state_sets.append(state)
current_state_ids = await self.state_handler.resolve_events(
current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
current_state_ids = {k: e.event_id for k, e in current_states.items()}
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
Expand All @@ -2149,9 +2149,11 @@ async def _check_for_soft_fail(

# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]

auth_events_map = await self.store.get_events(current_state_ids)
auth_events_map = await self.store.get_events(current_state_ids_list)
current_auth_events = {
(e.type, e.state_key): e for e in auth_events_map.values()
}
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,7 @@ async def get_interested_parties(

async def get_interested_remotes(
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
) -> List[Tuple[List[str], List[UserPresenceState]]]:
) -> List[Tuple[Iterable[str], List[UserPresenceState]]]:
clokep marked this conversation as resolved.
Show resolved Hide resolved
"""Given a list of presence states figure out which remote servers
should be sent which.

Expand All @@ -1334,7 +1334,7 @@ async def get_interested_remotes(
each tuple the list of UserPresenceState should be sent to each
destination
"""
hosts_and_states = []
hosts_and_states = [] # type: List[Tuple[Iterable[str], List[UserPresenceState]]]

# First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote
Expand Down
Loading