diff --git a/changelog.d/12753.misc b/changelog.d/12753.misc new file mode 100644 index 000000000000..e793d08e5e3f --- /dev/null +++ b/changelog.d/12753.misc @@ -0,0 +1 @@ +Add some type hints to datastore. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 45668974b363..4fa020b8764d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -27,7 +27,6 @@ exclude = (?x) |synapse/storage/databases/__init__.py |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py - |synapse/storage/databases/main/event_federation.py |synapse/storage/schema/ |tests/api/test_auth.py diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index 29de7e5bed10..fbfd7484065c 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -53,6 +53,7 @@ async def inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int: # We want to use the successor event depth so they appear after `prev_event` because # it has a larger `depth` but before the successor event because the `stream_ordering` # is negative before the successor event. + assert most_recent_prev_event_id is not None successor_event_ids = await self.store.get_successor_events( most_recent_prev_event_id ) @@ -139,6 +140,7 @@ async def get_most_recent_full_state_ids_from_event_id_list( _, ) = await self.store.get_max_depth_of(event_ids) # mapping from (type, state_key) -> state_event_id + assert most_recent_event_id is not None prev_state_map = await self.state_store.get_state_ids_for_event( most_recent_event_id ) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 471022470843..dcfe8caf473a 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -14,7 +14,17 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + cast, +) import attr from prometheus_client import Counter, Gauge @@ -33,7 +43,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.types import Cursor +from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache @@ -135,7 +145,7 @@ async def get_auth_chain_ids( # Check if we have indexed the room so we can use the chain cover # algorithm. - room = await self.get_room(room_id) + room = await self.get_room(room_id) # type: ignore[attr-defined] if room["has_auth_chain_index"]: try: return await self.db_pool.runInteraction( @@ -158,7 +168,11 @@ async def get_auth_chain_ids( ) def _get_auth_chain_ids_using_cover_index_txn( - self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool + self, + txn: LoggingTransaction, + room_id: str, + event_ids: Collection[str], + include_given: bool, ) -> Set[str]: """Calculates the auth chain IDs using the chain index.""" @@ -215,9 +229,9 @@ def _get_auth_chain_ids_using_cover_index_txn( chains: Dict[int, int] = {} # Add all linked chains reachable from initial set of chains. - for batch in batch_iter(event_chains, 1000): + for batch2 in batch_iter(event_chains, 1000): clause, args = make_in_list_sql_clause( - txn.database_engine, "origin_chain_id", batch + txn.database_engine, "origin_chain_id", batch2 ) txn.execute(sql % (clause,), args) @@ -297,7 +311,7 @@ def _get_auth_chain_ids_txn( front = set(event_ids) while front: - new_front = set() + new_front: Set[str] = set() for chunk in batch_iter(front, 100): # Pull the auth events either from the cache or DB. to_fetch: List[str] = [] # Event IDs to fetch from DB @@ -316,7 +330,7 @@ def _get_auth_chain_ids_txn( # Note we need to batch up the results by event ID before # adding to the cache. - to_cache = {} + to_cache: Dict[str, List[Tuple[str, int]]] = {} for event_id, auth_event_id, auth_event_depth in txn: to_cache.setdefault(event_id, []).append( (auth_event_id, auth_event_depth) @@ -349,7 +363,7 @@ async def get_auth_chain_difference( # Check if we have indexed the room so we can use the chain cover # algorithm. - room = await self.get_room(room_id) + room = await self.get_room(room_id) # type: ignore[attr-defined] if room["has_auth_chain_index"]: try: return await self.db_pool.runInteraction( @@ -370,7 +384,7 @@ async def get_auth_chain_difference( ) def _get_auth_chain_difference_using_cover_index_txn( - self, txn: Cursor, room_id: str, state_sets: List[Set[str]] + self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]] ) -> Set[str]: """Calculates the auth chain difference using the chain index. @@ -444,9 +458,9 @@ def _get_auth_chain_difference_using_cover_index_txn( # (We need to take a copy of `seen_chains` as we want to mutate it in # the loop) - for batch in batch_iter(set(seen_chains), 1000): + for batch2 in batch_iter(set(seen_chains), 1000): clause, args = make_in_list_sql_clause( - txn.database_engine, "origin_chain_id", batch + txn.database_engine, "origin_chain_id", batch2 ) txn.execute(sql % (clause,), args) @@ -529,7 +543,7 @@ def _get_auth_chain_difference_using_cover_index_txn( return result def _get_auth_chain_difference_txn( - self, txn, state_sets: List[Set[str]] + self, txn: LoggingTransaction, state_sets: List[Set[str]] ) -> Set[str]: """Calculates the auth chain difference using a breadth first search. @@ -602,7 +616,7 @@ def _get_auth_chain_difference_txn( # I think building a temporary list with fetchall is more efficient than # just `search.extend(txn)`, but this is unconfirmed - search.extend(txn.fetchall()) + search.extend(cast(List[Tuple[int, str]], txn.fetchall())) # sort by depth search.sort() @@ -645,7 +659,7 @@ def _get_auth_chain_difference_txn( # We parse the results and add the to the `found` set and the # cache (note we need to batch up the results by event ID before # adding to the cache). - to_cache = {} + to_cache: Dict[str, List[Tuple[str, int]]] = {} for event_id, auth_event_id, auth_event_depth in txn: to_cache.setdefault(event_id, []).append( (auth_event_id, auth_event_depth) @@ -696,7 +710,7 @@ def _get_auth_chain_difference_txn( return {eid for eid, n in event_to_missing_sets.items() if n} async def get_oldest_event_ids_with_depth_in_room( - self, room_id + self, room_id: str ) -> List[Tuple[str, int]]: """Gets the oldest events(backwards extremities) in the room along with the aproximate depth. @@ -713,7 +727,9 @@ async def get_oldest_event_ids_with_depth_in_room( List of (event_id, depth) tuples """ - def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): + def get_oldest_event_ids_with_depth_in_room_txn( + txn: LoggingTransaction, room_id: str + ) -> List[Tuple[str, int]]: # Assemble a dictionary with event_id -> depth for the oldest events # we know of in the room. Backwards extremeties are the oldest # events we know of in the room but we only know of them because @@ -743,7 +759,7 @@ def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): txn.execute(sql, (room_id, False)) - return txn.fetchall() + return cast(List[Tuple[str, int]], txn.fetchall()) return await self.db_pool.runInteraction( "get_oldest_event_ids_with_depth_in_room", @@ -752,7 +768,7 @@ def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id): ) async def get_insertion_event_backward_extremities_in_room( - self, room_id + self, room_id: str ) -> List[Tuple[str, int]]: """Get the insertion events we know about that we haven't backfilled yet. @@ -768,7 +784,9 @@ async def get_insertion_event_backward_extremities_in_room( List of (event_id, depth) tuples """ - def get_insertion_event_backward_extremities_in_room_txn(txn, room_id): + def get_insertion_event_backward_extremities_in_room_txn( + txn: LoggingTransaction, room_id: str + ) -> List[Tuple[str, int]]: sql = """ SELECT b.event_id, MAX(e.depth) FROM insertion_events as i /* We only want insertion events that are also marked as backwards extremities */ @@ -780,7 +798,7 @@ def get_insertion_event_backward_extremities_in_room_txn(txn, room_id): """ txn.execute(sql, (room_id,)) - return txn.fetchall() + return cast(List[Tuple[str, int]], txn.fetchall()) return await self.db_pool.runInteraction( "get_insertion_event_backward_extremities_in_room", @@ -788,7 +806,7 @@ def get_insertion_event_backward_extremities_in_room_txn(txn, room_id): room_id, ) - async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: + async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]: """Returns the event ID and depth for the event that has the max depth from a set of event IDs Args: @@ -817,7 +835,7 @@ async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: return max_depth_event_id, current_max_depth - async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[str, int]: + async def get_min_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]: """Returns the event ID and depth for the event that has the min depth from a set of event IDs Args: @@ -865,7 +883,9 @@ async def get_prev_events_for_room(self, room_id: str) -> List[str]: "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id ) - def _get_prev_events_for_room_txn(self, txn, room_id: str): + def _get_prev_events_for_room_txn( + self, txn: LoggingTransaction, room_id: str + ) -> List[str]: # we just use the 10 newest events. Older events will become # prev_events of future events. @@ -896,7 +916,7 @@ async def get_rooms_with_many_extremities( sorted by extremity count. """ - def _get_rooms_with_many_extremities_txn(txn): + def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]: where_clause = "1=1" if room_id_filter: where_clause = "room_id NOT IN (%s)" % ( @@ -937,7 +957,9 @@ async def get_min_depth(self, room_id: str) -> Optional[int]: "get_min_depth", self._get_min_depth_interaction, room_id ) - def _get_min_depth_interaction(self, txn, room_id): + def _get_min_depth_interaction( + self, txn: LoggingTransaction, room_id: str + ) -> Optional[int]: min_depth = self.db_pool.simple_select_one_onecol_txn( txn, table="room_depth", @@ -966,22 +988,24 @@ async def get_forward_extremities_for_room_at_stream_ordering( """ # We want to make the cache more effective, so we clamp to the last # change before the given ordering. - last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) + last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined] # We don't always have a full stream_to_exterm_id table, e.g. after # the upgrade that introduced it, so we make sure we never ask for a # stream_ordering from before a restart - last_change = max(self._stream_order_on_start, last_change) + last_change = max(self._stream_order_on_start, last_change) # type: ignore[attr-defined] # provided the last_change is recent enough, we now clamp the requested # stream_ordering to it. - if last_change > self.stream_ordering_month_ago: + if last_change > self.stream_ordering_month_ago: # type: ignore[attr-defined] stream_ordering = min(last_change, stream_ordering) return await self._get_forward_extremeties_for_room(room_id, stream_ordering) @cached(max_entries=5000, num_args=2) - async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): + async def _get_forward_extremeties_for_room( + self, room_id: str, stream_ordering: int + ) -> List[str]: """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -989,7 +1013,7 @@ async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): stream_orderings from that point. """ - if stream_ordering <= self.stream_ordering_month_ago: + if stream_ordering <= self.stream_ordering_month_ago: # type: ignore[attr-defined] raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,)) sql = """ @@ -1002,7 +1026,7 @@ async def _get_forward_extremeties_for_room(self, room_id, stream_ordering): WHERE room_id = ? """ - def get_forward_extremeties_for_room_txn(txn): + def get_forward_extremeties_for_room_txn(txn: LoggingTransaction) -> List[str]: txn.execute(sql, (stream_ordering, room_id)) return [event_id for event_id, in txn] @@ -1104,8 +1128,8 @@ def _get_connected_prev_event_backfill_results_txn( ] async def get_backfill_events( - self, room_id: str, seed_event_id_list: list, limit: int - ): + self, room_id: str, seed_event_id_list: List[str], limit: int + ) -> List[EventBase]: """Get a list of Events for a given topic that occurred before (and including) the events in seed_event_id_list. Return a list of max size `limit` @@ -1123,10 +1147,19 @@ async def get_backfill_events( ) events = await self.get_events_as_list(event_ids) return sorted( - events, key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering) + # type-ignore: mypy doesn't like negating the Optional[int] stream_ordering. + # But it's never None, because these events were previously persisted to the DB. + events, + key=lambda e: (-e.depth, -e.internal_metadata.stream_ordering), # type: ignore[operator] ) - def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit): + def _get_backfill_events( + self, + txn: LoggingTransaction, + room_id: str, + seed_event_id_list: List[str], + limit: int, + ) -> Set[str]: """ We want to make sure that we do a breadth-first, "depth" ordered search. We also handle navigating historical branches of history connected by @@ -1139,7 +1172,7 @@ def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit): limit, ) - event_id_results = set() + event_id_results: Set[str] = set() # In a PriorityQueue, the lowest valued entries are retrieved first. # We're using depth as the priority in the queue and tie-break based on @@ -1147,7 +1180,7 @@ def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit): # highest and newest-in-time message. We add events to the queue with a # negative depth so that we process the newest-in-time messages first # going backwards in time. stream_ordering follows the same pattern. - queue = PriorityQueue() + queue: "PriorityQueue[Tuple[int, int, str, str]]" = PriorityQueue() for seed_event_id in seed_event_id_list: event_lookup_result = self.db_pool.simple_select_one_txn( @@ -1253,7 +1286,13 @@ def _get_backfill_events(self, txn, room_id, seed_event_id_list, limit): return event_id_results - async def get_missing_events(self, room_id, earliest_events, latest_events, limit): + async def get_missing_events( + self, + room_id: str, + earliest_events: List[str], + latest_events: List[str], + limit: int, + ) -> List[EventBase]: ids = await self.db_pool.runInteraction( "get_missing_events", self._get_missing_events, @@ -1264,11 +1303,18 @@ async def get_missing_events(self, room_id, earliest_events, latest_events, limi ) return await self.get_events_as_list(ids) - def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): + def _get_missing_events( + self, + txn: LoggingTransaction, + room_id: str, + earliest_events: List[str], + latest_events: List[str], + limit: int, + ) -> List[str]: seen_events = set(earliest_events) front = set(latest_events) - seen_events - event_results = [] + event_results: List[str] = [] query = ( "SELECT prev_event_id FROM event_edges " @@ -1311,7 +1357,7 @@ async def get_successor_events(self, event_id: str) -> List[str]: @wrap_as_background_process("delete_old_forward_extrem_cache") async def _delete_old_forward_extrem_cache(self) -> None: - def _delete_old_forward_extrem_cache_txn(txn): + def _delete_old_forward_extrem_cache_txn(txn: LoggingTransaction) -> None: # Delete entries older than a month, while making sure we don't delete # the only entries for a room. sql = """ @@ -1324,7 +1370,7 @@ def _delete_old_forward_extrem_cache_txn(txn): ) AND stream_ordering < ? """ txn.execute( - sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) + sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago) # type: ignore[attr-defined] ) await self.db_pool.runInteraction( @@ -1382,7 +1428,9 @@ async def remove_received_event_from_staging( """ if self.db_pool.engine.supports_returning: - def _remove_received_event_from_staging_txn(txn): + def _remove_received_event_from_staging_txn( + txn: LoggingTransaction, + ) -> Optional[int]: sql = """ DELETE FROM federation_inbound_events_staging WHERE origin = ? AND event_id = ? @@ -1390,21 +1438,24 @@ def _remove_received_event_from_staging_txn(txn): """ txn.execute(sql, (origin, event_id)) - return txn.fetchone() + row = cast(Optional[Tuple[int]], txn.fetchone()) - row = await self.db_pool.runInteraction( + if row is None: + return None + + return row[0] + + return await self.db_pool.runInteraction( "remove_received_event_from_staging", _remove_received_event_from_staging_txn, db_autocommit=True, ) - if row is None: - return None - - return row[0] else: - def _remove_received_event_from_staging_txn(txn): + def _remove_received_event_from_staging_txn( + txn: LoggingTransaction, + ) -> Optional[int]: received_ts = self.db_pool.simple_select_one_onecol_txn( txn, table="federation_inbound_events_staging", @@ -1437,7 +1488,9 @@ async def get_next_staged_event_id_for_room( ) -> Optional[Tuple[str, str]]: """Get the next event ID in the staging area for the given room.""" - def _get_next_staged_event_id_for_room_txn(txn): + def _get_next_staged_event_id_for_room_txn( + txn: LoggingTransaction, + ) -> Optional[Tuple[str, str]]: sql = """ SELECT origin, event_id FROM federation_inbound_events_staging @@ -1448,7 +1501,7 @@ def _get_next_staged_event_id_for_room_txn(txn): txn.execute(sql, (room_id,)) - return txn.fetchone() + return cast(Optional[Tuple[str, str]], txn.fetchone()) return await self.db_pool.runInteraction( "get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn @@ -1461,7 +1514,9 @@ async def get_next_staged_event_for_room( ) -> Optional[Tuple[str, EventBase]]: """Get the next event in the staging area for the given room.""" - def _get_next_staged_event_for_room_txn(txn): + def _get_next_staged_event_for_room_txn( + txn: LoggingTransaction, + ) -> Optional[Tuple[str, str, str]]: sql = """ SELECT event_json, internal_metadata, origin FROM federation_inbound_events_staging @@ -1471,7 +1526,7 @@ def _get_next_staged_event_for_room_txn(txn): """ txn.execute(sql, (room_id,)) - return txn.fetchone() + return cast(Optional[Tuple[str, str, str]], txn.fetchone()) row = await self.db_pool.runInteraction( "get_next_staged_event_for_room", _get_next_staged_event_for_room_txn @@ -1599,18 +1654,20 @@ async def get_all_rooms_with_staged_incoming_events(self) -> List[str]: ) @wrap_as_background_process("_get_stats_for_federation_staging") - async def _get_stats_for_federation_staging(self): + async def _get_stats_for_federation_staging(self) -> None: """Update the prometheus metrics for the inbound federation staging area.""" - def _get_stats_for_federation_staging_txn(txn): + def _get_stats_for_federation_staging_txn( + txn: LoggingTransaction, + ) -> Tuple[int, int]: txn.execute("SELECT count(*) FROM federation_inbound_events_staging") - (count,) = txn.fetchone() + (count,) = cast(Tuple[int], txn.fetchone()) txn.execute( "SELECT min(received_ts) FROM federation_inbound_events_staging" ) - (received_ts,) = txn.fetchone() + (received_ts,) = cast(Tuple[Optional[int]], txn.fetchone()) # If there is nothing in the staging area default it to 0. age = 0 @@ -1651,19 +1708,21 @@ def __init__( self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth ) - async def clean_room_for_join(self, room_id): - return await self.db_pool.runInteraction( + async def clean_room_for_join(self, room_id: str) -> None: + await self.db_pool.runInteraction( "clean_room_for_join", self._clean_room_for_join_txn, room_id ) - def _clean_room_for_join_txn(self, txn, room_id): + def _clean_room_for_join_txn(self, txn: LoggingTransaction, room_id: str) -> None: query = "DELETE FROM event_forward_extremities WHERE room_id = ?" txn.execute(query, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) - async def _background_delete_non_state_event_auth(self, progress, batch_size): - def delete_event_auth(txn): + async def _background_delete_non_state_event_auth( + self, progress: JsonDict, batch_size: int + ) -> int: + def delete_event_auth(txn: LoggingTransaction) -> bool: target_min_stream_id = progress.get("target_min_stream_id_inclusive") max_stream_id = progress.get("max_stream_id_exclusive") diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 060ba5f5174f..e95dfdce2086 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -332,6 +332,7 @@ def test_backfill_floating_outlier_membership_auth(self) -> None: most_recent_prev_event_depth, ) = self.get_success(self.store.get_max_depth_of(prev_event_ids)) # mapping from (type, state_key) -> state_event_id + assert most_recent_prev_event_id is not None prev_state_map = self.get_success( self.state_store.get_state_ids_for_event(most_recent_prev_event_id) )