Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add include_event_in_state to _get_state_for_room #6521

Merged
merged 2 commits into from
Dec 11, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/6521.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor some code in the event authentication path for clarity.
50 changes: 28 additions & 22 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,22 +378,10 @@ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
(
remote_state,
got_auth_chain,
) = await self._get_state_for_room(origin, room_id, p)

# we want the state *after* p; _get_state_for_room returns the
# state *before* p.
remote_event = await self.federation_client.get_pdu(
[origin], p, room_version, outlier=True
) = await self._get_state_for_room(
origin, room_id, p, include_event_in_state=True
)

if remote_event is None:
raise Exception(
"Unable to get missing prev_event %s" % (p,)
)

if remote_event.is_state():
remote_state.append(remote_event)

# XXX hrm I'm not convinced that duplicate events will compare
# for equality, so I'm not sure this does what the author
# hoped.
Expand Down Expand Up @@ -579,20 +567,25 @@ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
else:
raise

@log_function
async def _get_state_for_room(
self, destination: str, room_id: str, event_id: str
self,
destination: str,
room_id: str,
event_id: str,
include_event_in_state: bool = False,
) -> Tuple[List[EventBase], List[EventBase]]:
"""Requests all of the room state at a given event from a remote homeserver.

Args:
destination:: The remote homeserver to query for the state.
destination: The remote homeserver to query for the state.
room_id: The id of the room we're interested in.
event_id: The id of the event we want the state at.
include_event_in_state: if true, the event itself will be included in the
returned state event list.

Returns:
A list of events in the state, and a list of events in the auth chain
for the given event.
A list of events in the state, possibly including the event itself, and
a list of events in the auth chain for the given event.
"""
(
state_event_ids,
Expand All @@ -602,6 +595,10 @@ async def _get_state_for_room(
)

desired_events = set(state_event_ids + auth_event_ids)

if include_event_in_state:
desired_events.add(event_id)

event_map = await self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
Expand All @@ -614,12 +611,21 @@ async def _get_state_for_room(
failed_to_fetch,
)

pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
remote_state = [
event_map[e_id] for e_id in state_event_ids if e_id in event_map
]

if include_event_in_state:
remote_event = event_map.get(event_id)
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
if remote_event.is_state():
remote_state.append(remote_event)

auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth)

return pdus, auth_chain
return remote_state, auth_chain

async def _get_events_from_store_or_dest(
self, destination: str, room_id: str, event_ids: Iterable[str]
Expand Down