From 2d490c68bdf7c125539fb07296c2bcb4b3f6aaa6 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 12 Dec 2022 16:49:39 +0000 Subject: [PATCH] Use StateFilter --- synapse/config/api.py | 53 ++------------- .../storage/databases/main/events_worker.py | 38 ++++------- tests/config/test_api.py | 67 +++++++++---------- 3 files changed, 50 insertions(+), 108 deletions(-) diff --git a/synapse/config/api.py b/synapse/config/api.py index a29decd6007e..27d50d118f3f 100644 --- a/synapse/config/api.py +++ b/synapse/config/api.py @@ -13,71 +13,30 @@ # limitations under the License. import logging -from typing import Any, Container, Dict, Iterable, Mapping, Optional, Set, Tuple, Type - -import attr +from typing import Any, Iterable, Optional, Tuple from synapse.api.constants import EventTypes from synapse.config._base import Config, ConfigError from synapse.config._util import validate_config from synapse.types import JsonDict +from synapse.types.state import StateFilter logger = logging.getLogger(__name__) -@attr.s(auto_attribs=True) -class StateKeyFilter(Container[str]): - """A simpler version of StateFilter which ignores event types. - - Represents an optional constraint that state_keys must belong to a given set of - strings called `options`. An empty set of `options` means that there are no - restrictions. - """ - - options: Set[str] - - @classmethod - def any(cls: Type["StateKeyFilter"]) -> "StateKeyFilter": - return cls(set()) - - @classmethod - def only(cls: Type["StateKeyFilter"], state_key: str) -> "StateKeyFilter": - return cls({state_key}) - - def __contains__(self, state_key: object) -> bool: - return not self.options or state_key in self.options - - def add(self, state_key: Optional[str]) -> None: - if state_key is None: - self.options = set() - elif self.options: - self.options.add(state_key) - - class ApiConfig(Config): section = "api" - room_prejoin_state: Mapping[str, StateKeyFilter] + room_prejoin_state: StateFilter track_puppetted_users_ips: bool def read_config(self, config: JsonDict, **kwargs: Any) -> None: validate_config(_MAIN_SCHEMA, config, ()) - self.room_prejoin_state = self._build_prejoin_state(config) + self.room_prejoin_state = StateFilter.from_types( + self._get_prejoin_state_entries(config) + ) self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False) - def _build_prejoin_state(self, config: JsonDict) -> Dict[str, StateKeyFilter]: - prejoin_events = {} - for event_type, state_key in self._get_prejoin_state_entries(config): - if event_type not in prejoin_events: - if state_key is None: - filter = StateKeyFilter.any() - else: - filter = StateKeyFilter.only(state_key) - prejoin_events[event_type] = filter - else: - prejoin_events[event_type].add(state_key) - return prejoin_events - def _get_prejoin_state_entries( self, config: JsonDict ) -> Iterable[Tuple[str, Optional[str]]]: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 7f6ecfef127c..a36549e7c24f 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -16,6 +16,7 @@ import threading import weakref from enum import Enum, auto +from itertools import chain from typing import ( TYPE_CHECKING, Any, @@ -23,7 +24,6 @@ Dict, Iterable, List, - Mapping, MutableMapping, Optional, Set, @@ -46,7 +46,6 @@ RoomVersion, RoomVersions, ) -from synapse.config.api import StateKeyFilter from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.events.utils import prune_event @@ -77,6 +76,7 @@ ) from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id +from synapse.types.state import StateFilter from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred, delay_cancellation from synapse.util.caches.descriptors import cached, cachedList @@ -880,7 +880,7 @@ def _get_events_from_local_cache( async def get_stripped_room_state_from_event_context( self, context: EventContext, - state_keys_to_include: Mapping[str, StateKeyFilter], + state_keys_to_include: StateFilter, membership_user_id: Optional[str] = None, ) -> List[JsonDict]: """ @@ -902,31 +902,21 @@ async def get_stripped_room_state_from_event_context( Returns: A list of dictionaries, each representing a stripped state event from the room. """ - current_state_ids = await context.get_current_state_ids() + if membership_user_id: + types = chain( + state_keys_to_include.to_types(), + [(EventTypes.Member, membership_user_id)], + ) + filter = StateFilter.from_types(types) + else: + filter = state_keys_to_include + selected_state_ids = await context.get_current_state_ids(filter) # We know this event is not an outlier, so this must be # non-None. - assert current_state_ids is not None - - def should_include(t: str, s: str) -> bool: - if t in state_keys_to_include and s in state_keys_to_include[t]: - return True - if ( - membership_user_id - and t == EventTypes.Member - and s == membership_user_id - ): - return True - return False - - # The state to include - state_to_include_ids = [ - e_id - for (event_type, state_key), e_id in current_state_ids.items() - if should_include(event_type, state_key) - ] + assert selected_state_ids is not None - state_to_include = await self.get_events(state_to_include_ids) + state_to_include = await self.get_events(selected_state_ids.values()) return [ { diff --git a/tests/config/test_api.py b/tests/config/test_api.py index 8c65d2f58be4..6773c9a2773a 100644 --- a/tests/config/test_api.py +++ b/tests/config/test_api.py @@ -3,40 +3,21 @@ import yaml from synapse.config import ConfigError -from synapse.config.api import ApiConfig, StateKeyFilter - -DEFAULT_PREJOIN_STATE = { - "m.room.join_rules": StateKeyFilter.only(""), - "m.room.canonical_alias": StateKeyFilter.only(""), - "m.room.avatar": StateKeyFilter.only(""), - "m.room.encryption": StateKeyFilter.only(""), - "m.room.name": StateKeyFilter.only(""), - "m.room.create": StateKeyFilter.only(""), - "m.room.topic": StateKeyFilter.only(""), +from synapse.config.api import ApiConfig +from synapse.types.state import StateFilter + +DEFAULT_PREJOIN_STATE_PAIRS = { + ("m.room.join_rules", ""), + ("m.room.canonical_alias", ""), + ("m.room.avatar", ""), + ("m.room.encryption", ""), + ("m.room.name", ""), + ("m.room.create", ""), + ("m.room.topic", ""), } class TestRoomPrejoinState(StdlibTestCase): - def test_state_key_filter(self) -> None: - """Sanity check the StateKeyFilter class.""" - s = StateKeyFilter.only("foo") - self.assertIn("foo", s) - self.assertNotIn("bar", s) - self.assertNotIn("baz", s) - s.add("bar") - self.assertIn("foo", s) - self.assertIn("bar", s) - self.assertNotIn("baz", s) - - s = StateKeyFilter.any() - self.assertIn("foo", s) - self.assertIn("bar", s) - self.assertIn("baz", s) - s.add("bar") - self.assertIn("foo", s) - self.assertIn("bar", s) - self.assertIn("baz", s) - def read_config(self, source: str) -> ApiConfig: config = ApiConfig() config.read_config(yaml.safe_load(source)) @@ -44,7 +25,10 @@ def read_config(self, source: str) -> ApiConfig: def test_no_prejoin_state(self) -> None: config = self.read_config("foo: bar") - self.assertEqual(config.room_prejoin_state, DEFAULT_PREJOIN_STATE) + self.assertFalse(config.room_prejoin_state.has_wildcards()) + self.assertEqual( + set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS + ) def test_disable_default_event_types(self) -> None: config = self.read_config( @@ -53,7 +37,7 @@ def test_disable_default_event_types(self) -> None: disable_default_event_types: true """ ) - self.assertEqual(config.room_prejoin_state, {}) + self.assertEqual(config.room_prejoin_state, StateFilter.none()) def test_event_without_state_key(self) -> None: config = self.read_config( @@ -64,7 +48,8 @@ def test_event_without_state_key(self) -> None: - foo """ ) - self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()}) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) def test_event_with_specific_state_key(self) -> None: config = self.read_config( @@ -75,7 +60,11 @@ def test_event_with_specific_state_key(self) -> None: - [foo, bar] """ ) - self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.only("bar")}) + self.assertFalse(config.room_prejoin_state.has_wildcards()) + self.assertEqual( + set(config.room_prejoin_state.concrete_types()), + {("foo", "bar")}, + ) def test_repeated_event_with_specific_state_key(self) -> None: config = self.read_config( @@ -87,8 +76,10 @@ def test_repeated_event_with_specific_state_key(self) -> None: - [foo, baz] """ ) + self.assertFalse(config.room_prejoin_state.has_wildcards()) self.assertEqual( - config.room_prejoin_state, {"foo": StateKeyFilter({"bar", "baz"})} + set(config.room_prejoin_state.concrete_types()), + {("foo", "bar"), ("foo", "baz")}, ) def test_no_specific_state_key_overrides_specific_state_key(self) -> None: @@ -101,7 +92,8 @@ def test_no_specific_state_key_overrides_specific_state_key(self) -> None: - foo """ ) - self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()}) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) config = self.read_config( """ @@ -112,7 +104,8 @@ def test_no_specific_state_key_overrides_specific_state_key(self) -> None: - [foo, bar] """ ) - self.assertEqual(config.room_prejoin_state, {"foo": StateKeyFilter.any()}) + self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"]) + self.assertEqual(config.room_prejoin_state.concrete_types(), []) def test_bad_event_type_entry_raises(self) -> None: with self.assertRaises(ConfigError):