Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Pull out less state when handling gaps mk2 (#12852)
Browse files Browse the repository at this point in the history
  • Loading branch information
erikjohnston authored May 26, 2022
1 parent 1b33847 commit b83bc5f
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 127 deletions.
1 change: 1 addition & 0 deletions changelog.d/12852.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Pull out less state when handling gaps in room DAG.
178 changes: 84 additions & 94 deletions synapse/handlers/federation_event.py

Large diffs are not rendered by default.

40 changes: 37 additions & 3 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,14 @@
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.types import (
MutableStateMap,
Requester,
RoomAlias,
StreamToken,
UserID,
create_requester,
)
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
Expand Down Expand Up @@ -1022,8 +1029,35 @@ async def create_new_client_event(
#
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
old_state = await self.store.get_events_as_list(state_event_ids)
context = await self.state.compute_event_context(event, old_state=old_state)
metadata = await self.store.get_metadata_for_events(state_event_ids)

state_map_for_event: MutableStateMap[str] = {}
for state_id in state_event_ids:
data = metadata.get(state_id)
if data is None:
# We're trying to persist a new historical batch of events
# with the given state, e.g. via
# `RoomBatchSendEventRestServlet`. The state can be inferred
# by Synapse or set directly by the client.
#
# Either way, we should have persisted all the state before
# getting here.
raise Exception(
f"State event {state_id} not found in DB,"
" Synapse should have persisted it before using it."
)

if data.state_key is None:
raise Exception(
f"Trying to set non-state event {state_id} as state"
)

state_map_for_event[(data.event_type, data.state_key)] = state_id

context = await self.state.compute_event_context(
event,
state_ids_before_event=state_map_for_event,
)
else:
context = await self.state.compute_event_context(event)

Expand Down
22 changes: 10 additions & 12 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ async def get_hosts_in_room_at_events(
async def compute_event_context(
self,
event: EventBase,
old_state: Optional[Iterable[EventBase]] = None,
state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: bool = False,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.
Expand All @@ -273,26 +273,24 @@ async def compute_event_context(
Args:
event:
old_state: The state at the event if it can't be
calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
partial_state: True if `old_state` is partial and omits non-critical
membership events
state_ids_before_event: The event ids of the state before the event if
it can't be calculated from existing events. This is normally
only specified when receiving an event from federation where we
don't have the prev events, e.g. when backfilling.
partial_state: True if `state_ids_before_event` is partial and omits
non-critical membership events
Returns:
The event context.
"""

assert not event.internal_metadata.is_outlier()

#
# first of all, figure out the state before the event
# first of all, figure out the state before the event, unless we
# already have it.
#
if old_state:
if state_ids_before_event:
# if we're given the state before the event, then we use that
state_ids_before_event: StateMap[str] = {
(s.type, s.state_key): s.event_id for s in old_state
}
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
Expand Down
59 changes: 59 additions & 0 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple

import attr

from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
Expand All @@ -26,13 +28,15 @@
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand All @@ -43,6 +47,15 @@
MAX_STATE_DELTA_HOPS = 100


@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventMetadata:
"""Returned by `get_metadata_for_events`"""

room_id: str
event_type: str
state_key: Optional[str]


def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not v:
Expand Down Expand Up @@ -133,6 +146,52 @@ def get_room_version_id_txn(self, txn: LoggingTransaction, room_id: str) -> str:

return room_version

async def get_metadata_for_events(
self, event_ids: Collection[str]
) -> Dict[str, EventMetadata]:
"""Get some metadata (room_id, type, state_key) for the given events.
This method is a faster alternative than fetching the full events from
the DB, and should be used when the full event is not needed.
Returns metadata for rejected and redacted events. Events that have not
been persisted are omitted from the returned dict.
"""

def get_metadata_for_events_txn(
txn: LoggingTransaction,
batch_ids: Collection[str],
) -> Dict[str, EventMetadata]:
clause, args = make_in_list_sql_clause(
self.database_engine, "e.event_id", batch_ids
)

sql = f"""
SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e
LEFT JOIN state_events USING (event_id)
WHERE {clause}
"""

txn.execute(sql, args)
return {
event_id: EventMetadata(
room_id=room_id, event_type=event_type, state_key=state_key
)
for event_id, room_id, event_type, state_key in txn
}

result_map: Dict[str, EventMetadata] = {}
for batch_ids in batch_iter(event_ids, 1000):
result_map.update(
await self.db_pool.runInteraction(
"get_metadata_for_events",
get_metadata_for_events_txn,
batch_ids=batch_ids,
)
)

return result_map

async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
Expand Down
6 changes: 5 additions & 1 deletion tests/handlers/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,11 @@ def test_backfill_with_many_backward_extremities(self) -> None:
# federation handler wanting to backfill the fake event.
self.get_success(
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME, event, state=current_state
self.OTHER_SERVER_NAME,
event,
state_ids={
(e.type, e.state_key): e.event_id for e in current_state
},
)
)

Expand Down
43 changes: 28 additions & 15 deletions tests/storage/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def prepare(self, reactor, clock, homeserver):
def persist_event(self, event, state=None):
"""Persist the event, with optional state"""
context = self.get_success(
self.state.compute_event_context(event, old_state=state)
self.state.compute_event_context(event, state_ids_before_event=state)
)
self.get_success(self.persistence.persist_event(event, context))

Expand Down Expand Up @@ -103,9 +103,11 @@ def test_prune_gap(self):
RoomVersions.V6,
)

state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)

self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)

# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
Expand Down Expand Up @@ -135,13 +137,14 @@ def test_do_not_prune_gap_if_state_different(self):
# setting. The state resolution across the old and new event will then
# include it, and so the resolved state won't match the new state.
state_before_gap = dict(
self.get_success(self.state.get_current_state(self.room_id))
self.get_success(self.state.get_current_state_ids(self.room_id))
)
state_before_gap.pop(("m.room.history_visibility", ""))

context = self.get_success(
self.state.compute_event_context(
remote_event_2, old_state=state_before_gap.values()
remote_event_2,
state_ids_before_event=state_before_gap,
)
)

Expand Down Expand Up @@ -177,9 +180,11 @@ def test_prune_gap_if_old(self):
RoomVersions.V6,
)

state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)

self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)

# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
Expand Down Expand Up @@ -207,9 +212,11 @@ def test_do_not_prune_gap_if_other_server(self):
RoomVersions.V6,
)

state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)

self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)

# Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
Expand Down Expand Up @@ -247,9 +254,11 @@ def test_prune_gap_if_dummy_remote(self):
RoomVersions.V6,
)

state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)

self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)

# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
Expand Down Expand Up @@ -289,9 +298,11 @@ def test_prune_gap_if_dummy_local(self):
RoomVersions.V6,
)

state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)

self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)

# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
Expand Down Expand Up @@ -323,9 +334,11 @@ def test_do_not_prune_gap_if_not_dummy(self):
RoomVersions.V6,
)

state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.state.get_current_state_ids(self.room_id)
)

self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)

# Check the new extremity is just the new remote event.
self.assert_extremities([local_message_event_id, remote_event_2.event_id])
Expand Down
14 changes: 12 additions & 2 deletions tests/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,12 @@ def test_annotate_with_old_message(self):
]

context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
)
)

prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
Expand All @@ -467,7 +472,12 @@ def test_annotate_with_old_state(self):
]

context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
)
)

prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
Expand Down

0 comments on commit b83bc5f

Please sign in to comment.