Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert additional databases to async/await part 3 #8201

Merged
merged 8 commits into from
Sep 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8201.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
4 changes: 2 additions & 2 deletions synapse/storage/background_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,15 +433,15 @@ async def _end_background_update(self, update_name: str) -> None:
"background_updates", keyvalues={"update_name": update_name}
)

def _background_update_progress(self, update_name: str, progress: dict):
async def _background_update_progress(self, update_name: str, progress: dict):
"""Update the progress of a background update

Args:
update_name: The name of the background update task
progress: The progress of the update.
"""

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"background_update_progress",
self._background_update_progress_txn,
update_name,
Expand Down
59 changes: 32 additions & 27 deletions synapse/storage/databases/main/account_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@

import abc
import logging
from typing import List, Optional, Tuple

from twisted.internet import defer
from typing import Dict, List, Optional, Tuple

from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
Expand Down Expand Up @@ -58,14 +56,16 @@ def get_max_account_data_stream_id(self):
raise NotImplementedError()

@cached()
def get_account_data_for_user(self, user_id):
async def get_account_data_for_user(
self, user_id: str
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a user.

Args:
user_id(str): The user to get the account_data for.
user_id: The user to get the account_data for.
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
A 2-tuple of a dict of global account_data and a dict mapping from
room_id string to per room account_data dicts.
"""

def get_account_data_for_user_txn(txn):
Expand Down Expand Up @@ -94,7 +94,7 @@ def get_account_data_for_user_txn(txn):

return global_account_data, by_room

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)

Expand All @@ -120,14 +120,16 @@ async def get_global_account_data_by_type_for_user(
return None

@cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
async def get_account_data_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
"""Get all the client account_data for a user for a room.

Args:
user_id(str): The user to get the account_data for.
room_id(str): The room to get the account_data for.
user_id: The user to get the account_data for.
room_id: The room to get the account_data for.
Returns:
A deferred dict of the room account_data
A dict of the room account_data
"""

def get_account_data_for_room_txn(txn):
Expand All @@ -142,21 +144,22 @@ def get_account_data_for_room_txn(txn):
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)

@cached(num_args=3, max_entries=5000)
def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
async def get_account_data_for_room_and_type(
self, user_id: str, room_id: str, account_data_type: str
) -> Optional[JsonDict]:
"""Get the client account_data of given type for a user for a room.

Args:
user_id(str): The user to get the account_data for.
room_id(str): The room to get the account_data for.
account_data_type (str): The account data type to get.
user_id: The user to get the account_data for.
room_id: The room to get the account_data for.
account_data_type: The account data type to get.
Returns:
A deferred of the room account_data for that type, or None if
there isn't any set.
The room account_data for that type, or None if there isn't any set.
"""

def get_account_data_for_room_and_type_txn(txn):
Expand All @@ -174,7 +177,7 @@ def get_account_data_for_room_and_type_txn(txn):

return db_to_json(content_json) if content_json else None

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)

Expand Down Expand Up @@ -238,12 +241,14 @@ def get_updated_room_account_data_txn(txn):
"get_updated_room_account_data", get_updated_room_account_data_txn
)

def get_updated_account_data_for_user(self, user_id, stream_id):
async def get_updated_account_data_for_user(
self, user_id: str, stream_id: int
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a that's changed for a user

Args:
user_id(str): The user to get the account_data for.
stream_id(int): The point in the stream since which to get updates
user_id: The user to get the account_data for.
stream_id: The point in the stream since which to get updates
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
Expand Down Expand Up @@ -277,9 +282,9 @@ def get_updated_account_data_for_user_txn(txn):
user_id, int(stream_id)
)
if not changed:
return defer.succeed(({}, {}))
return ({}, {})

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)

Expand Down Expand Up @@ -416,7 +421,7 @@ async def add_account_data_for_user(

return self._account_data_id_gen.get_current_token()

def _update_max_stream_id(self, next_id: int):
async def _update_max_stream_id(self, next_id: int) -> None:
"""Update the max stream_id

Args:
Expand All @@ -435,4 +440,4 @@ def _update(txn):
)
txn.execute(update_max_id_sql, (next_id, next_id))

return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
54 changes: 34 additions & 20 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@


class EndToEndKeyWorkerStore(SQLBaseStore):
def get_e2e_device_keys_for_federation_query(self, user_id: str):
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
) -> Tuple[int, List[JsonDict]]:
"""Get all devices (with any device keys) for a user

Returns:
Deferred which resolves to (stream_id, devices)
(stream_id, devices)
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_e2e_device_keys_for_federation_query",
self._get_e2e_device_keys_for_federation_query_txn,
user_id,
Expand Down Expand Up @@ -290,10 +292,12 @@ def _add_e2e_one_time_keys(txn):
)

@cached(max_entries=10000)
def count_e2e_one_time_keys(self, user_id, device_id):
async def count_e2e_one_time_keys(
self, user_id: str, device_id: str
) -> Dict[str, int]:
""" Count the number of one time keys the server has for a device
Returns:
Dict mapping from algorithm to number of keys for that algorithm.
A mapping from algorithm to number of keys for that algorithm.
"""

def _count_e2e_one_time_keys(txn):
Expand All @@ -308,7 +312,7 @@ def _count_e2e_one_time_keys(txn):
result[algorithm] = key_count
return result

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)

Expand Down Expand Up @@ -346,24 +350,23 @@ def _get_bare_e2e_cross_signing_keys(self, user_id):
list_name="user_ids",
num_args=1,
)
def _get_bare_e2e_cross_signing_keys_bulk(
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.

Args:
user_ids (list[str]): the users whose keys are being requested
user_ids: the users whose keys are being requested

Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. If a user's cross-signing keys were not found, either
their user ID will not be in the dict, or their user ID will map
to None.
A mapping from user ID to key type to key data. If a user's cross-signing
keys were not found, either their user ID will not be in the dict, or
their user ID will map to None.

"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
Expand Down Expand Up @@ -586,7 +589,9 @@ def get_device_stream_token(self) -> int:


class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
async def set_e2e_device_keys(
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
Expand Down Expand Up @@ -622,12 +627,21 @@ def _set_e2e_device_keys_txn(txn):
log_kv({"message": "Device keys stored."})
return True

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)

def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
) -> Dict[str, Dict[str, Dict[str, bytes]]]:
"""Take a list of one time keys out of the database.

Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).

Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
"""

@trace
def _claim_e2e_one_time_keys(txn):
Expand Down Expand Up @@ -663,11 +677,11 @@ def _claim_e2e_one_time_keys(txn):
)
return result

return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)

def delete_e2e_keys_by_device(self, user_id, device_id):
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn):
log_kv(
{
Expand All @@ -690,7 +704,7 @@ def delete_e2e_keys_by_device_txn(txn):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)

return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)

Expand Down
Loading