Speed up updating state in large rooms (matrix-org#15971)
This should speed up updating state in rooms with lots of state.
erikjohnston authored and Fizzadar committed Aug 29, 2023
1 parent 24c7f3c commit b037e3f
Showing 5 changed files with 141 additions and 131 deletions.
1 change: 1 addition & 0 deletions changelog.d/15971.misc
@@ -0,0 +1 @@
Speed up updating state in large rooms.
9 changes: 4 additions & 5 deletions synapse/handlers/message.py
@@ -1577,12 +1577,11 @@ async def cache_joined_hosts_for_events(
        if state_entry.state_group in self._external_cache_joined_hosts_updates:
            return

-        state = await state_entry.get_state(
-            self._storage_controllers.state, StateFilter.all()
-        )
        with opentracing.start_active_span("get_joined_hosts"):
-            joined_hosts = await self.store.get_joined_hosts(
-                event.room_id, state, state_entry
+            joined_hosts = (
+                await self._storage_controllers.state.get_joined_hosts(
+                    event.room_id, state_entry
+                )
            )

        # Note that the expiry times must be larger than the expiry time in
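The saving at this call site (and at the matching one in synapse/state/__init__.py below) is that the room's full state map is no longer materialised with StateFilter.all() before computing joined hosts; the controller now fetches only membership events, and only when its caches cannot answer. A minimal sketch of the filter difference, assuming Synapse's StateFilter API exactly as it appears in this diff:

# Illustrative only; not part of the commit.
from synapse.api.constants import EventTypes
from synapse.types.state import StateFilter

# What the old call sites asked for: every state event in the room.
full_state = StateFilter.all()

# What the new controller method asks for on its slow path: just the
# m.room.member events, which is all the joined-hosts calculation needs.
members_only = StateFilter.from_types([(EventTypes.Member, None)])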
3 changes: 1 addition & 2 deletions synapse/state/__init__.py
@@ -268,8 +268,7 @@ async def get_hosts_in_room_at_events(
            The hosts in the room at the given events
        """
        entry = await self.resolve_state_groups_for_events(room_id, event_ids)
-        state = await entry.get_state(self._state_storage_controller, StateFilter.all())
-        return await self.store.get_joined_hosts(room_id, state, entry)
+        return await self._state_storage_controller.get_joined_hosts(room_id, entry)

    @trace
    @tag_args
137 changes: 135 additions & 2 deletions synapse/storage/controllers/state.py
@@ -12,36 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from itertools import chain
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Callable,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Mapping,
Optional,
Tuple,
Union,
)

-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.logging.opentracing import tag_args, trace
from synapse.storage.roommember import ProfileInfo
from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
PartialStateEventsTracker,
)
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateMap, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached
from synapse.util.cancellation import cancellable
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.state import _StateCacheEntry
from synapse.storage.databases import Databases


logger = logging.getLogger(__name__)


@@ -52,10 +61,15 @@ class StateStorageController:

def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
self._clock = hs.get_clock()
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)

# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
# at a time. Keyed by room_id.
self._joined_host_linearizer = Linearizer("_JoinedHostsCache")

def notify_event_un_partial_stated(self, event_id: str) -> None:
self._partial_state_events_tracker.notify_un_partial_stated(event_id)

@@ -627,3 +641,122 @@ async def get_users_in_room_with_profiles(
await self._partial_state_room_tracker.await_full_state(room_id)

return await self.stores.main.get_users_in_room_with_profiles(room_id)

async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()

assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
room_id, state_group, state_entry=state_entry
)
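An aside on the object() sentinel above, not part of the commit: @cached(num_args=2) keys the cache on the first two arguments, (room_id, state_group), so calls whose state entry has not yet been assigned a state group must never share a key, and a freshly created object() is unequal to every other object. A standalone sketch of the idea, using a hypothetical results cache:

# Illustrative only; `results` and `lookup` are hypothetical, not Synapse code.
results = {}

def lookup(room_id, state_group):
    if not state_group:
        # A brand-new sentinel per call means "no state group yet" lookups can
        # never hit a previously cached entry: object() != object().
        state_group = object()
    key = (room_id, state_group)
    if key not in results:
        results[key] = f"joined hosts for {room_id}"  # placeholder for real work
    return results[key]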

@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
self,
room_id: str,
state_group: Union[object, int],
state_entry: "_StateCacheEntry",
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
# it. However, it's important that it's never None, since two
# current_states with a state_group of None are likely to be different.
#
# The `state_group` must match the `state_entry.state_group` (if not None).
assert state_group is not None
assert state_entry.state_group is None or state_entry.state_group == state_group

# We use a secondary cache of previous work to allow us to build up the
# joined hosts for the given state group based on previous state groups.
#
# We cache one object per room containing the results of the last state
# group we got joined hosts for. The idea is that generally
# `get_joined_hosts` is called with the "current" state group for the
# room, and so consecutive calls will be for consecutive state groups
# which point to the previous state group.
cache = await self.stores.main._get_joined_hosts_cache(room_id)

# If the state group in the cache matches, we already have the data we need.
if state_entry.state_group == cache.state_group:
return frozenset(cache.hosts_to_joined_users)

# Since we'll mutate the cache we need to lock.
async with self._joined_host_linearizer.queue(room_id):
if state_entry.state_group == cache.state_group:
# Same state group, so nothing to do. We've already checked for
# this above, but the cache may have changed while waiting on
# the lock.
pass
elif state_entry.prev_group == cache.state_group:
# The cached work is for the previous state group, so we work out
# the delta.
assert state_entry.delta_ids is not None
for (typ, state_key), event_id in state_entry.delta_ids.items():
if typ != EventTypes.Member:
continue

host = intern_string(get_domain_from_id(state_key))
user_id = state_key
known_joins = cache.hosts_to_joined_users.setdefault(host, set())

event = await self.stores.main.get_event(event_id)
if event.membership == Membership.JOIN:
known_joins.add(user_id)
else:
known_joins.discard(user_id)

if not known_joins:
cache.hosts_to_joined_users.pop(host, None)
else:
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
#
# We need to fetch all hosts joined to the room according to `state` by
# inspecting all join memberships in `state`. However, if the `state` is
# relatively recent then many of its events are likely to be held in
# the current state of the room, which is easily available and likely
# cached.
#
# We therefore compute the set of `state` events not in the
# current state and only fetch those.
current_memberships = (
await self.stores.main._get_approximate_current_memberships_in_room(
room_id
)
)
unknown_state_events = {}
joined_users_in_current_state = []

state = await state_entry.get_state(
self, StateFilter.from_types([(EventTypes.Member, None)])
)

for (type, state_key), event_id in state.items():
if event_id not in current_memberships:
unknown_state_events[type, state_key] = event_id
elif current_memberships[event_id] == Membership.JOIN:
joined_users_in_current_state.append(state_key)

joined_user_ids = await self.stores.main.get_joined_user_ids_from_state(
room_id, unknown_state_events
)

cache.hosts_to_joined_users = {}
for user_id in chain(joined_user_ids, joined_users_in_current_state):
host = intern_string(get_domain_from_id(user_id))
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)

if state_entry.state_group:
cache.state_group = state_entry.state_group
else:
cache.state_group = object()

return frozenset(cache.hosts_to_joined_users)
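Taken together, the comments above describe a two-level strategy: the @cached result per (room_id, state_group), plus one mutable per-room cache holding the joined hosts for the last state group handled, so that a request whose prev_group matches it only has to apply the membership delta; the per-room Linearizer ensures a single coroutine mutates that cache at a time. A simplified, standalone sketch of the delta path, with hypothetical types and a delta that maps straight to membership values rather than event IDs:

# Illustrative only; simplified stand-ins for the real cache and helpers.
from dataclasses import dataclass, field
from typing import Dict, FrozenSet, Optional, Set, Tuple


@dataclass
class JoinedHostsCache:
    # Joined hosts/users for the last state group we handled.
    state_group: Optional[int] = None
    hosts_to_joined_users: Dict[str, Set[str]] = field(default_factory=dict)


def domain_of(user_id: str) -> str:
    # Stand-in for synapse.types.get_domain_from_id.
    return user_id.split(":", 1)[1]


def joined_hosts_from_delta(
    cache: JoinedHostsCache,
    state_group: int,
    prev_group: Optional[int],
    delta: Dict[Tuple[str, str], str],  # (event type, state key) -> membership
) -> FrozenSet[str]:
    if cache.state_group == state_group:
        # Fast path: the cache already answers this exact state group.
        return frozenset(cache.hosts_to_joined_users)

    if prev_group != cache.state_group:
        # In the real code this falls back to the "first principles" slow path.
        raise NotImplementedError("recompute from the room's membership state")

    # Incremental path: walk only the membership changes since prev_group.
    for (etype, user_id), membership in delta.items():
        if etype != "m.room.member":
            continue
        host = domain_of(user_id)
        joined = cache.hosts_to_joined_users.setdefault(host, set())
        if membership == "join":
            joined.add(user_id)
        else:
            joined.discard(user_id)
        if not joined:
            cache.hosts_to_joined_users.pop(host, None)

    cache.state_group = state_group
    return frozenset(cache.hosts_to_joined_users)


# Example: one user on a new server joins between state groups 1 and 2.
cache = JoinedHostsCache(1, {"example.org": {"@a:example.org"}})
hosts = joined_hosts_from_delta(
    cache, 2, prev_group=1, delta={("m.room.member", "@b:other.org"): "join"}
)
assert hosts == frozenset({"example.org", "other.org"})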
122 changes: 0 additions & 122 deletions synapse/storage/databases/main/roommember.py
@@ -14,7 +14,6 @@
# limitations under the License.
import logging
import re
from itertools import chain
from typing import (
TYPE_CHECKING,
AbstractSet,
@@ -58,15 +57,12 @@
StrCollection,
get_domain_from_id,
)
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.state import _StateCacheEntry

logger = logging.getLogger(__name__)

@@ -95,10 +91,6 @@ def __init__(
):
super().__init__(database, db_conn, hs)

# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
# at a time. Keyed by room_id.
self._joined_host_linearizer = Linearizer("_JoinedHostsCache")

self._server_notices_mxid = hs.config.servernotices.server_notices_mxid

if (
@@ -1080,120 +1072,6 @@ def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
)

async def get_joined_hosts(
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
state_group: Union[object, int] = state_entry.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()

assert state_group is not None
with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts(
room_id, state_group, state, state_entry=state_entry
)

@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
self,
room_id: str,
state_group: Union[object, int],
state: StateMap[str],
state_entry: "_StateCacheEntry",
) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on
# it. However, its important that its never None, since two
# current_state's with a state_group of None are likely to be different.
#
# The `state_group` must match the `state_entry.state_group` (if not None).
assert state_group is not None
assert state_entry.state_group is None or state_entry.state_group == state_group

# We use a secondary cache of previous work to allow us to build up the
# joined hosts for the given state group based on previous state groups.
#
# We cache one object per room containing the results of the last state
# group we got joined hosts for. The idea is that generally
# `get_joined_hosts` is called with the "current" state group for the
# room, and so consecutive calls will be for consecutive state groups
# which point to the previous state group.
cache = await self._get_joined_hosts_cache(room_id)

# If the state group in the cache matches, we already have the data we need.
if state_entry.state_group == cache.state_group:
return frozenset(cache.hosts_to_joined_users)

# Since we'll mutate the cache we need to lock.
async with self._joined_host_linearizer.queue(room_id):
if state_entry.state_group == cache.state_group:
# Same state group, so nothing to do. We've already checked for
# this above, but the cache may have changed while waiting on
# the lock.
pass
elif state_entry.prev_group == cache.state_group:
# The cached work is for the previous state group, so we work out
# the delta.
assert state_entry.delta_ids is not None
for (typ, state_key), event_id in state_entry.delta_ids.items():
if typ != EventTypes.Member:
continue

host = intern_string(get_domain_from_id(state_key))
user_id = state_key
known_joins = cache.hosts_to_joined_users.setdefault(host, set())

event = await self.get_event(event_id)
if event.membership == Membership.JOIN:
known_joins.add(user_id)
else:
known_joins.discard(user_id)

if not known_joins:
cache.hosts_to_joined_users.pop(host, None)
else:
# The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles.
#
# We need to fetch all hosts joined to the room according to `state` by
# inspecting all join memberships in `state`. However, if the `state` is
# relatively recent then many of its events are likely to be held in
# the current state of the room, which is easily available and likely
# cached.
#
# We therefore compute the set of `state` events not in the
# current state and only fetch those.
current_memberships = (
await self._get_approximate_current_memberships_in_room(room_id)
)
unknown_state_events = {}
joined_users_in_current_state = []

for (type, state_key), event_id in state.items():
if event_id not in current_memberships:
unknown_state_events[type, state_key] = event_id
elif current_memberships[event_id] == Membership.JOIN:
joined_users_in_current_state.append(state_key)

joined_user_ids = await self.get_joined_user_ids_from_state(
room_id, unknown_state_events
)

cache.hosts_to_joined_users = {}
for user_id in chain(joined_user_ids, joined_users_in_current_state):
host = intern_string(get_domain_from_id(user_id))
cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)

if state_entry.state_group:
cache.state_group = state_entry.state_group
else:
cache.state_group = object()

return frozenset(cache.hosts_to_joined_users)

async def _get_approximate_current_memberships_in_room(
self, room_id: str
) -> Mapping[str, Optional[str]]:
