Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Type hints for RegistrationStore #8615

Merged
merged 5 commits into from
Oct 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8615.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Type hints for `RegistrationStore`.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ files =
synapse/spam_checker_api,
synapse/state,
synapse/storage/databases/main/events.py,
synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
Expand Down
1 change: 0 additions & 1 deletion synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def __init__(self, database: DatabasePool, db_conn, hs):
db_conn, "e2e_cross_signing_keys", "stream_id"
)

self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
Expand Down
156 changes: 83 additions & 73 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,33 @@
# limitations under the License.
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.types import Cursor
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Connection, Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
from synapse.server import HomeServer

THIRTY_MINUTES_IN_MS = 30 * 60 * 1000

logger = logging.getLogger(__name__)


class RegistrationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.config = hs.config
self.clock = hs.get_clock()

# Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is
Expand All @@ -55,7 +59,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):

# Create a background job for culling expired 3PID validity tokens
if hs.config.run_background_tasks:
self.clock.looping_call(
self._clock.looping_call(
self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
)

Expand Down Expand Up @@ -92,7 +96,7 @@ async def is_trial_user(self, user_id: str) -> bool:
if not info:
return False

now = self.clock.time_msec()
now = self._clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
return is_trial
Expand Down Expand Up @@ -257,7 +261,7 @@ def select_users_txn(txn, now_ms, renew_at):
return await self.db_pool.runInteraction(
"get_users_expiring_soon",
select_users_txn,
self.clock.time_msec(),
self._clock.time_msec(),
self.config.account_validity.renew_at,
)

Expand Down Expand Up @@ -328,13 +332,17 @@ def set_server_admin_txn(txn):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)

def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
" access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)
sql = """
SELECT users.name,
users.is_guest,
users.shadow_banned,
access_tokens.id as token_id,
access_tokens.device_id,
access_tokens.valid_until_ms
FROM users
INNER JOIN access_tokens on users.name = access_tokens.user_id
WHERE token = ?
"""

txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)
Expand Down Expand Up @@ -803,7 +811,7 @@ def cull_expired_threepid_validation_tokens_txn(txn, ts):
await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
self._clock.time_msec(),
)

@wrap_as_background_process("account_validity_set_expiration_dates")
Expand Down Expand Up @@ -890,10 +898,10 @@ async def del_user_pending_deactivation(self, user_id: str) -> None:


class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.clock = hs.get_clock()
self._clock = hs.get_clock()
self.config = hs.config

self.db_pool.updates.register_background_index_update(
Expand Down Expand Up @@ -1016,13 +1024,56 @@ def _bg_user_threepids_grandfather_txn(txn):

return 1

async def set_user_deactivated_status(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why were these moved?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're used elsewhere in that class

self, user_id: str, deactivated: bool
) -> None:
"""Set the `deactivated` property for the provided user to the provided value.

Args:
user_id: The ID of the user to set the status for.
deactivated: The value to set for `deactivated`.
"""

await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
deactivated,
)

def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))

@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)

return res if res else False


class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors

self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")

async def add_access_token_to_user(
self,
user_id: str,
Expand Down Expand Up @@ -1138,19 +1189,19 @@ async def register_user(
def _register_user(
self,
txn,
user_id,
password_hash,
was_guest,
make_guest,
appservice_id,
create_profile_with_displayname,
admin,
user_type,
shadow_banned,
user_id: str,
password_hash: Optional[str],
was_guest: bool,
make_guest: bool,
appservice_id: Optional[str],
create_profile_with_displayname: Optional[str],
admin: bool,
user_type: Optional[str],
shadow_banned: bool,
):
user_id_obj = UserID.from_string(user_id)

now = int(self.clock.time())
now = int(self._clock.time())

try:
if was_guest:
Expand Down Expand Up @@ -1374,18 +1425,6 @@ def f(txn):

await self.db_pool.runInteraction("delete_access_token", f)

@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)

return res if res else False

async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're
Expand Down Expand Up @@ -1479,7 +1518,7 @@ def validate_threepid_session_txn(txn):
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
updatevalues={"validated_at": self.clock.time_msec()},
updatevalues={"validated_at": self._clock.time_msec()},
)

return next_link
Expand Down Expand Up @@ -1547,35 +1586,6 @@ def start_or_continue_validation_session_txn(txn):
start_or_continue_validation_session_txn,
)

async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
"""Set the `deactivated` property for the provided user to the provided value.

Args:
user_id: The ID of the user to set the status for.
deactivated: The value to set for `deactivated`.
"""

await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
deactivated,
)

def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))


def find_max_generated_user_id_localpart(cur: Cursor) -> int:
"""
Expand Down