Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add some type hints to datastore #12485

Merged
merged 7 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12485.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add some type hints to datastore.
21 changes: 15 additions & 6 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple, cast

from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import (
IdGenerator,
MultiWriterIdGenerator,
Expand Down Expand Up @@ -266,7 +271,9 @@ async def get_users_paginate(
A tuple of a list of mappings from user to information and a count of total users.
"""

def get_users_paginate_txn(txn):
def get_users_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
filters = []
args = [self.hs.config.server.server_name]

Expand Down Expand Up @@ -301,7 +308,7 @@ def get_users_paginate_txn(txn):
"""
sql = "SELECT COUNT(*) as total_users " + sql_base
txn.execute(sql, args)
count = txn.fetchone()[0]
count = cast(Tuple[int], txn.fetchone())[0]

sql = f"""
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
Expand Down Expand Up @@ -338,7 +345,9 @@ async def search_users(self, term: str) -> Optional[List[JsonDict]]:
)


def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
def check_database_before_upgrade(
cur: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
) -> None:
"""Called before upgrading an existing database to check that it is broadly sane
compared with the configuration.
"""
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast

from synapse.appservice import (
ApplicationService,
Expand Down Expand Up @@ -83,7 +83,7 @@ def get_max_as_txn_id(txn: Cursor) -> int:
txn.execute(
"SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
)
return txn.fetchone()[0] # type: ignore
return cast(Tuple[int], txn.fetchone())[0]

self._as_txn_seq_gen = build_sequence_generator(
db_conn,
Expand Down
79 changes: 56 additions & 23 deletions synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,17 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
cast,
)

from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
Expand Down Expand Up @@ -118,7 +128,13 @@ def __init__(
prefilled_cache=device_outbox_prefill,
)

def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
) -> None:
if stream_name == ToDeviceStream.NAME:
# If replication is happening than postgres must be being used.
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
Expand All @@ -134,7 +150,7 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
)
return super().process_replication_rows(stream_name, instance_name, token, rows)

def get_to_device_stream_token(self):
def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token()

async def get_messages_for_user_devices(
Expand Down Expand Up @@ -301,7 +317,9 @@ async def _get_device_messages(
if not user_ids_to_query:
return {}, to_stream_id

def get_device_messages_txn(txn: LoggingTransaction):
def get_device_messages_txn(
txn: LoggingTransaction,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
# Build a query to select messages from any of the given devices that
# are between the given stream id bounds.

Expand Down Expand Up @@ -428,7 +446,7 @@ async def delete_messages_for_device(
log_kv({"message": "No changes in cache since last check"})
return 0

def delete_messages_for_device_txn(txn):
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
Expand All @@ -455,15 +473,14 @@ def delete_messages_for_device_txn(txn):

@trace
async def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit
) -> Tuple[List[dict], int]:
self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
) -> Tuple[List[JsonDict], int]:
"""
Args:
destination(str): The name of the remote server.
last_stream_id(int|long): The last position of the device message stream
destination: The name of the remote server.
last_stream_id: The last position of the device message stream
that the server sent up to.
current_stream_id(int|long): The current position of the device
message stream.
current_stream_id: The current position of the device message stream.
Returns:
A list of messages for the device and where in the stream the messages got to.
"""
Expand All @@ -485,7 +502,9 @@ async def get_new_device_msgs_for_remote(
return [], last_stream_id

@trace
def get_new_messages_for_remote_destination_txn(txn):
def get_new_messages_for_remote_destination_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?"
Expand Down Expand Up @@ -527,7 +546,7 @@ async def delete_device_msgs_for_remote(
up_to_stream_id: Where to delete messages up to.
"""

def delete_messages_for_remote_destination_txn(txn):
def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None:
sql = (
"DELETE FROM device_federation_outbox"
" WHERE destination = ?"
Expand Down Expand Up @@ -566,7 +585,9 @@ async def get_all_new_device_messages(
if last_id == current_id:
return [], current_id, False

def get_all_new_device_messages_txn(txn):
def get_all_new_device_messages_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
Expand Down Expand Up @@ -607,8 +628,8 @@ def get_all_new_device_messages_txn(txn):
@trace
async def add_messages_to_device_inbox(
self,
local_messages_by_user_then_device: dict,
remote_messages_by_destination: dict,
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
remote_messages_by_destination: Dict[str, JsonDict],
) -> int:
"""Used to send messages from this server.

Expand All @@ -624,7 +645,9 @@ async def add_messages_to_device_inbox(

assert self._can_write_to_device

def add_messages_txn(txn, now_ms, stream_id):
def add_messages_txn(
txn: LoggingTransaction, now_ms: int, stream_id: int
) -> None:
# Add the local messages directly to the local inbox.
self._add_messages_to_local_device_inbox_txn(
txn, stream_id, local_messages_by_user_then_device
Expand Down Expand Up @@ -677,11 +700,16 @@ def add_messages_txn(txn, now_ms, stream_id):
return self._device_inbox_id_gen.get_current_token()

async def add_messages_from_remote_to_device_inbox(
self, origin: str, message_id: str, local_messages_by_user_then_device: dict
self,
origin: str,
message_id: str,
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
) -> int:
assert self._can_write_to_device

def add_messages_txn(txn, now_ms, stream_id):
def add_messages_txn(
txn: LoggingTransaction, now_ms: int, stream_id: int
) -> None:
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
Expand Down Expand Up @@ -727,8 +755,11 @@ def add_messages_txn(txn, now_ms, stream_id):
return stream_id

def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
self,
txn: LoggingTransaction,
stream_id: int,
messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
) -> None:
assert self._can_write_to_device

local_by_user_then_device = {}
Expand Down Expand Up @@ -840,8 +871,10 @@ def __init__(
self._remove_dead_devices_from_device_inbox,
)

async def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
async def _background_drop_index_device_inbox(
self, progress: JsonDict, batch_size: int
) -> int:
def reindex_txn(conn: LoggingDatabaseConnection) -> None:
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
Expand Down
Loading