Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Type annotations for test_v2 #12985

Merged
merged 8 commits into from
Jun 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12985.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type annotations to `tests.state.test_v2`.
4 changes: 3 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ exclude = (?x)
|tests/rest/media/v1/test_media_storage.py
|tests/server.py
|tests/server_notices/test_resource_limits_server_notices.py
|tests/state/test_v2.py
|tests/test_metrics.py
|tests/test_server.py
|tests/test_state.py
Expand Down Expand Up @@ -115,6 +114,9 @@ disallow_untyped_defs = False
[mypy-tests.handlers.test_user_directory]
disallow_untyped_defs = True

[mypy-tests.state.test_profile]
disallow_untyped_defs = True

[mypy-tests.storage.test_profile]
disallow_untyped_defs = True

Expand Down
57 changes: 42 additions & 15 deletions synapse/state/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,73 @@
import logging
from typing import (
Any,
Awaitable,
Callable,
Collection,
Dict,
Generator,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
overload,
)

from typing_extensions import Literal
from typing_extensions import Literal, Protocol

import synapse.state
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
from synapse.util import Clock

logger = logging.getLogger(__name__)


class Clock(Protocol):
# This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests.
# We only ever sleep(0) though, so that other async functions can make forward
# progress without waiting for stateres to complete.
def sleep(self, duration_ms: float) -> Awaitable[None]:
...


class StateResolutionStore(Protocol):
# This is usually synapse.state.StateResolutionStore, but it's replaced with a
# TestStateResolutionStore in tests.
def get_events(
self, event_ids: Collection[str], allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
...

def get_auth_chain_difference(
self, room_id: str, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]:
...


# We want to await to the reactor occasionally during state res when dealing
# with large data sets, so that we don't exhaust the reactor. This is done by
# awaiting to reactor during loops every N iterations.
_AWAIT_AFTER_ITERATIONS = 100


__all__ = [
"resolve_events_with_store",
]


async def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
) -> StateMap[str]:
"""Resolves the state using the v2 state resolution algorithm

Expand Down Expand Up @@ -194,7 +221,7 @@ async def _get_power_level_for_sender(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
) -> int:
"""Return the power level of the sender of the given event according to
their auth events.
Expand Down Expand Up @@ -243,9 +270,9 @@ async def _get_power_level_for_sender(

async def _get_auth_chain_difference(
room_id: str,
state_sets: Sequence[StateMap[str]],
state_sets: Sequence[Mapping[Any, str]],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains.
Expand Down Expand Up @@ -406,7 +433,7 @@ async def _add_event_and_auth_chain_to_graph(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
auth_diff: Set[str],
) -> None:
"""Helper function for _reverse_topological_power_sort that add the event
Expand Down Expand Up @@ -440,7 +467,7 @@ async def _reverse_topological_power_sort(
room_id: str,
event_ids: Iterable[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
auth_diff: Set[str],
) -> List[str]:
"""Returns a list of the event_ids sorted by reverse topological ordering,
Expand Down Expand Up @@ -501,7 +528,7 @@ async def _iterative_auth_checks(
event_ids: List[str],
base_state: StateMap[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
) -> MutableStateMap[str]:
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
Expand Down Expand Up @@ -570,7 +597,7 @@ async def _mainline_sort(
event_ids: List[str],
resolved_power_event_id: Optional[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
) -> List[str]:
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
Expand Down Expand Up @@ -639,7 +666,7 @@ async def _get_mainline_depth_for_event(
event: EventBase,
mainline_map: Dict[str, int],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
) -> int:
"""Get the mainline depths for the given event based on the mainline map

Expand Down Expand Up @@ -683,7 +710,7 @@ async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
allow_none: Literal[False] = False,
) -> EventBase:
...
Expand All @@ -694,7 +721,7 @@ async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
allow_none: Literal[True],
) -> Optional[EventBase]:
...
Expand All @@ -704,7 +731,7 @@ async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
state_res_store: StateResolutionStore,
allow_none: bool = False,
) -> Optional[EventBase]:
"""Helper function to look up event in event_map, falling back to looking
Expand Down
Loading