Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert state and stream stores and related code to async #8194

Merged
merged 8 commits into from
Aug 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8194.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
2 changes: 1 addition & 1 deletion synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ async def clone_existing_room(
old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values()
)
for k, old_event in old_room_member_state_events.items():
for old_event in old_room_member_state_events.values():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy complained here that k was changing from a Tuple[str, str] (essentially a state key) to str (event ID).

Turns out we don't use k in this loop, so it is easier to just look at the values.

# Only transfer ban events
if (
"membership" in old_event.content
Expand Down
19 changes: 10 additions & 9 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList

Expand Down Expand Up @@ -163,15 +164,15 @@ async def get_create_event_for_room(self, room_id: str) -> EventBase:
return create_event

@cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id):
async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.

Args:
room_id (str)
room_id: The room to get the state IDs of.

Returns:
deferred: dict of (type, state_key) -> event_id
The current state of the room.
"""

def _get_current_state_ids_txn(txn):
Expand All @@ -184,14 +185,14 @@ def _get_current_state_ids_txn(txn):

return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
)

# FIXME: how should this be cached?
def get_filtered_current_state_ids(
async def get_filtered_current_state_ids(
self, room_id: str, state_filter: StateFilter = StateFilter.all()
):
) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
Expand All @@ -202,14 +203,14 @@ def get_filtered_current_state_ids(
from the database.

Returns:
defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
Map from type/state_key to event ID.
"""

where_clause, where_args = state_filter.make_sql_filter_clause()

if not where_clause:
# We delegate to the cached version
return self.get_current_state_ids(room_id)
return await self.get_current_state_ids(room_id)

def _get_filtered_current_state_ids_txn(txn):
results = {}
Expand All @@ -231,7 +232,7 @@ def _get_filtered_current_state_ids_txn(txn):

return results

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)

Expand Down
21 changes: 11 additions & 10 deletions synapse/storage/databases/main/state_deltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
# limitations under the License.

import logging

from twisted.internet import defer
from typing import Any, Dict, List, Tuple

from synapse.storage._base import SQLBaseStore

logger = logging.getLogger(__name__)


class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
async def get_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
) -> Tuple[int, List[Dict[str, Any]]]:
"""Fetch a list of room state changes since the given stream id

Each entry in the result contains the following fields:
Expand All @@ -37,12 +38,12 @@ def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
if it's new state.

Args:
prev_stream_id (int): point to get changes since (exclusive)
max_stream_id (int): the point that we know has been correctly persisted
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
- ie, an upper limit to return changes from.

Returns:
Deferred[tuple[int, list[dict]]: A tuple consisting of:
A tuple consisting of:
- the stream id which these results go up to
- list of current_state_delta_stream rows. If it is empty, we are
up to date.
Expand All @@ -58,7 +59,7 @@ def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
return defer.succeed((max_stream_id, []))
return (max_stream_id, [])

def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
Expand Down Expand Up @@ -102,7 +103,7 @@ def get_current_state_deltas_txn(txn):
txn.execute(sql, (prev_stream_id, clipped_stream_id))
return clipped_stream_id, self.db_pool.cursor_to_dict(txn)

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
)

Expand All @@ -114,8 +115,8 @@ def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
retcol="COALESCE(MAX(stream_id), -1)",
)

def get_max_stream_id_in_current_state_deltas(self):
return self.db_pool.runInteraction(
async def get_max_stream_id_in_current_state_deltas(self):
return await self.db_pool.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
)
11 changes: 7 additions & 4 deletions synapse/storage/databases/main/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,16 +539,17 @@ async def get_recent_event_ids_for_room(

return rows, token

def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
async def get_room_event_before_stream_ordering(
self, room_id: str, stream_ordering: int
) -> Tuple[int, int, str]:
"""Gets details of the first event in a room at or before a stream ordering

Args:
room_id:
stream_ordering:

Returns:
Deferred[(int, int, str)]:
(stream ordering, topological ordering, event_id)
A tuple of (stream ordering, topological ordering, event_id)
"""

def _f(txn):
Expand All @@ -563,7 +564,9 @@ def _f(txn):
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()

return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f)
return await self.db_pool.runInteraction(
"get_room_event_before_stream_ordering", _f
)

async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
"""Returns the current token for rooms stream.
Expand Down
26 changes: 13 additions & 13 deletions synapse/storage/databases/state/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from collections import namedtuple
from typing import Dict, Iterable, List, Set, Tuple

from twisted.internet import defer

from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
Expand Down Expand Up @@ -103,7 +101,7 @@ def get_max_state_group_txn(txn: Cursor):
)

@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
async def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
the old and the new.

Expand Down Expand Up @@ -135,7 +133,7 @@ def _get_state_group_delta_txn(txn):
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
)

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_state_group_delta", _get_state_group_delta_txn
)

Expand Down Expand Up @@ -367,9 +365,9 @@ def _insert_into_cache(
fetched_keys=non_member_types,
)

def store_state_group(
async def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
) -> int:
"""Store a new set of state, returning a newly assigned state group.

Args:
Expand All @@ -383,7 +381,7 @@ def store_state_group(
to event_id.

Returns:
Deferred[int]: The state group ID
The state group ID
"""

def _store_state_group_txn(txn):
Expand Down Expand Up @@ -484,11 +482,13 @@ def _store_state_group_txn(txn):

return state_group

return self.db_pool.runInteraction("store_state_group", _store_state_group_txn)
return await self.db_pool.runInteraction(
"store_state_group", _store_state_group_txn
)

def purge_unreferenced_state_groups(
async def purge_unreferenced_state_groups(
self, room_id: str, state_groups_to_delete
) -> defer.Deferred:
) -> None:
"""Deletes no longer referenced state groups and de-deltas any state
groups that reference them.

Expand All @@ -499,7 +499,7 @@ def purge_unreferenced_state_groups(
to delete.
"""

return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"purge_unreferenced_state_groups",
self._purge_unreferenced_state_groups,
room_id,
Expand Down Expand Up @@ -594,15 +594,15 @@ async def get_previous_state_groups(

return {row["state_group"]: row["prev_state_group"] for row in rows}

def purge_room_state(self, room_id, state_groups_to_delete):
async def purge_room_state(self, room_id, state_groups_to_delete):
"""Deletes all record of a room from state tables

Args:
room_id (str):
state_groups_to_delete (list[int]): State groups to delete
"""

return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"purge_room_state",
self._purge_room_state_txn,
room_id,
Expand Down
16 changes: 8 additions & 8 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,19 +333,19 @@ class StateGroupStorage(object):
def __init__(self, hs, stores):
self.stores = stores

def get_state_group_delta(self, state_group: int):
async def get_state_group_delta(self, state_group: int):
"""Given a state group try to return a previous group and a delta between
the old and the new.

Args:
state_group: The state group used to retrieve state deltas.

Returns:
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
Tuple[Optional[int], Optional[StateMap[str]]]:
(prev_group, delta_ids)
"""

return self.stores.state.get_state_group_delta(state_group)
return await self.stores.state.get_state_group_delta(state_group)

async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
Expand Down Expand Up @@ -525,7 +525,7 @@ async def get_state_ids_for_event(
state_filter: The state filter used to fetch state from the database.

Returns:
A deferred dict from (type, state_key) -> state_event
A dict from (type, state_key) -> state_event
"""
state_map = await self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
Expand All @@ -546,14 +546,14 @@ def _get_state_for_groups(
"""
return self.stores.state._get_state_for_groups(groups, state_filter)

def store_state_group(
async def store_state_group(
self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[dict],
current_state_ids: dict,
):
) -> int:
"""Store a new set of state, returning a newly assigned state group.

Args:
Expand All @@ -567,8 +567,8 @@ def store_state_group(
to event_id.

Returns:
Deferred[int]: The state group ID
The state group ID
"""
return self.stores.state.store_state_group(
return await self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)