Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Return an immutable value from get_latest_event_ids_in_room. #16326

Merged
merged 3 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16326.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
2 changes: 1 addition & 1 deletion synapse/events/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def is_state(self) -> bool:

async def build(
self,
prev_event_ids: StrCollection,
prev_event_ids: List[str],
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#16301 incorrectly changes Collection[str] -> StrSequence, but even Collection[str] is wrong. This needs to be JSON serializable, which pretty much means a List or Tuple.

auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
) -> EventBase:
Expand Down
8 changes: 3 additions & 5 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,12 +723,11 @@ async def _get_missing_events_for_pdu(
if not prevs - seen:
return

latest_list = await self._store.get_latest_event_ids_in_room(room_id)
latest_frozen = await self._store.get_latest_event_ids_in_room(room_id)

# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
latest = set(latest_list)
latest |= seen
latest = seen | latest_frozen

logger.info(
"Requesting missing events between %s and %s",
Expand Down Expand Up @@ -1976,8 +1975,7 @@ async def _check_for_soft_fail(
# partial and full state and may not be accurate.
return

extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids_list)
extrem_ids = await self._store.get_latest_event_ids_in_room(event.room_id)
prev_event_ids = set(event.prev_event_ids())

if extrem_ids == prev_event_ids:
Expand Down
9 changes: 4 additions & 5 deletions synapse/storage/controllers/persist_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections import deque
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Awaitable,
Callable,
Expand Down Expand Up @@ -618,7 +619,7 @@ async def _persist_event_batch(
)

for room_id, ev_ctx_rm in events_by_room.items():
latest_event_ids = set(
latest_event_ids = (
await self.main_store.get_latest_event_ids_in_room(room_id)
)
new_latest_event_ids = await self._calculate_new_extremities(
Expand Down Expand Up @@ -740,7 +741,7 @@ async def _calculate_new_extremities(
self,
room_id: str,
event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: Collection[str],
latest_event_ids: AbstractSet[str],
) -> Set[str]:
"""Calculates the new forward extremities for a room given events to
persist.
Expand All @@ -758,8 +759,6 @@ async def _calculate_new_extremities(
and not event.internal_metadata.is_soft_failed()
]

latest_event_ids = set(latest_event_ids)

# start with the existing forward extremities
result = set(latest_event_ids)

Expand Down Expand Up @@ -798,7 +797,7 @@ async def _get_new_state_after_events(
self,
room_id: str,
events_context: List[Tuple[EventBase, EventContext]],
old_latest_event_ids: Set[str],
old_latest_event_ids: AbstractSet[str],
new_latest_event_ids: Set[str],
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]:
"""Calculate the current state dict after adding some new events to
Expand Down
8 changes: 5 additions & 3 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Expand Down Expand Up @@ -47,7 +48,7 @@
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict, StrCollection, StrSequence
from synapse.types import JsonDict, StrCollection
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
Expand Down Expand Up @@ -1179,13 +1180,14 @@ def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
)

@cached(max_entries=5000, iterable=True)
async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
return await self.db_pool.simple_select_onecol(
async def get_latest_event_ids_in_room(self, room_id: str) -> FrozenSet[str]:
event_ids = await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
desc="get_latest_event_ids_in_room",
)
return frozenset(event_ids)

async def get_min_depth(self, room_id: str) -> Optional[int]:
"""For the given room, get the minimum depth we have seen for it."""
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async def _persist_events_and_state_updates(

for room_id, latest_event_ids in new_forward_extremities.items():
self.store.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids)
(room_id,), frozenset(latest_event_ids)
)

async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,7 +1858,7 @@ def _add_new_user(self, room_id: str, user_id: str) -> None:
)

event = self.get_success(
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
)

self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))
Expand Down
4 changes: 2 additions & 2 deletions tests/replication/storage/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def tearDown(self) -> None:
def test_get_latest_event_ids_in_room(self) -> None:
create = self.persist(type="m.room.create", key="", creator=USER_ID)
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
self.check("get_latest_event_ids_in_room", (ROOM_ID,), {create.event_id})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this works because if s: Set then s == frozenset(s)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

>>> {1, 2, 3} == frozenset({1, 2, 3})
True

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can compare equality of frozenset and set, yes. 👍


join = self.persist(
type="m.room.member",
Expand All @@ -99,7 +99,7 @@ def test_get_latest_event_ids_in_room(self) -> None:
prev_events=[(create.event_id, {})],
)
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
self.check("get_latest_event_ids_in_room", (ROOM_ID,), {join.event_id})

def test_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
Expand Down
10 changes: 5 additions & 5 deletions tests/replication/tcp/streams/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Optional, Sequence
from typing import Any, List, Optional

from twisted.test.proto_helpers import MemoryReactor

Expand Down Expand Up @@ -139,7 +139,7 @@ def test_update_function_huge_state_change(self) -> None:
)

# this is the point in the DAG where we make a fork
fork_point: Sequence[str] = self.get_success(
fork_point = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)

Expand Down Expand Up @@ -294,7 +294,7 @@ def test_update_function_state_row_limit(self) -> None:
)

# this is the point in the DAG where we make a fork
fork_point: Sequence[str] = self.get_success(
fork_point = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)

Expand All @@ -316,14 +316,14 @@ def test_update_function_state_row_limit(self) -> None:
self.test_handler.received_rdata_rows.clear()

# now roll back all that state by de-modding the users
prev_events = fork_point
prev_events = list(fork_point)
pl_events = []
for u in user_ids:
pls["users"][u] = 0
e = self.get_success(
inject_event(
self.hs,
prev_event_ids=list(prev_events),
prev_event_ids=prev_events,
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
Expand Down
2 changes: 1 addition & 1 deletion tests/replication/test_federation_sender_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def create_room_with_remote_server(

builder = factory.for_room_version(room_version, event_dict)
join_event = self.get_success(
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC there is no defined iteration order for a set. I wondered if this might introduce some non-determinism or similar problem... but it should be fine: the order of the prev events list doesn't have any significance.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The order shouldn't matter here is my understanding.

)

self.get_success(federation.on_send_membership_event(remote_server, join_event))
Expand Down
14 changes: 7 additions & 7 deletions tests/storage/test_cleanup_extrems.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_soft_failed_extremities_handled_correctly(self) -> None:
self.store.get_latest_event_ids_in_room(self.room_id)
)

self.assertEqual(latest_event_ids, [event_id_4])
self.assertEqual(latest_event_ids, {event_id_4})

def test_basic_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
Expand All @@ -147,15 +147,15 @@ def test_basic_cleanup(self) -> None:
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
self.assertEqual(latest_event_ids, {event_id_a, event_id_b})

# Run the background update and check it did the right thing
self.run_background_update()

latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertEqual(latest_event_ids, [event_id_b])
self.assertEqual(latest_event_ids, {event_id_b})

def test_chain_of_fail_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
Expand Down Expand Up @@ -185,15 +185,15 @@ def test_chain_of_fail_cleanup(self) -> None:
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
self.assertEqual(latest_event_ids, {event_id_a, event_id_b})

# Run the background update and check it did the right thing
self.run_background_update()

latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertEqual(latest_event_ids, [event_id_b])
self.assertEqual(latest_event_ids, {event_id_b})

def test_forked_graph_cleanup(self) -> None:
r"""Test that extremities are correctly calculated in the presence of
Expand Down Expand Up @@ -240,15 +240,15 @@ def test_forked_graph_cleanup(self) -> None:
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c})
self.assertEqual(latest_event_ids, {event_id_a, event_id_b, event_id_c})

# Run the background update and check it did the right thing
self.run_background_update()

latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c})
self.assertEqual(latest_event_ids, {event_id_b, event_id_c})


class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
Expand Down
26 changes: 17 additions & 9 deletions tests/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,15 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main

# Figure out what the most recent event is
most_recent = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)[0]
most_recent = next(
iter(
self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(
self.room_id
)
)
)
)

join_event = make_event_from_dict(
{
Expand Down Expand Up @@ -100,8 +106,8 @@ async def _check_sigs_and_hash_for_pulled_events_and_fetch(

# Make sure we actually joined the room
self.assertEqual(
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
"$join:test.serv",
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)),
{"$join:test.serv"},
)

def test_cant_hide_direct_ancestors(self) -> None:
Expand All @@ -127,9 +133,11 @@ async def post_json(
self.http_client.post_json = post_json

# Figure out what the most recent event is
most_recent = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)[0]
most_recent = next(
iter(
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
)
)

# Now lie about an event
lying_event = make_event_from_dict(
Expand Down Expand Up @@ -165,7 +173,7 @@ async def post_json(

# Make sure the invalid event isn't there
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
self.assertEqual(extrem[0], "$join:test.serv")
self.assertEqual(extrem, {"$join:test.serv"})

def test_retry_device_list_resync(self) -> None:
"""Tests that device lists are marked as stale if they couldn't be synced, and
Expand Down
Loading