This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Use auto_attribs/native type hints for attrs classes. #11692

Merged (20 commits) on Jan 13, 2022
Changes from 1 commit
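This commit converts attrs classes from the older `attr.ib(type=...)` style to `auto_attribs=True` with native type annotations. A minimal sketch of the pattern applied throughout (the class and field names here are illustrative, not taken from the diff):

import attr

# Before: field types passed as an argument to attr.ib().
@attr.s(slots=True)
class OldStyle:
    name = attr.ib(type=str)
    count = attr.ib(type=int, default=0)

# After: auto_attribs=True turns PEP 526 annotations into fields.
@attr.s(slots=True, auto_attribs=True)
class NewStyle:
    name: str
    count: int = 0

# Defaults and the generated __init__ carry over unchanged.
assert NewStyle("x").count == 0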
8 changes: 4 additions & 4 deletions synapse/storage/database.py
@@ -143,17 +143,17 @@ def make_conn(
     return db_conn


-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class LoggingDatabaseConnection:
     """A wrapper around a database connection that returns `LoggingTransaction`
     as its cursor class.

     This is mainly used on startup to ensure that queries get logged correctly
     """

-    conn = attr.ib(type=Connection)
-    engine = attr.ib(type=BaseDatabaseEngine)
-    default_txn_name = attr.ib(type=str)
+    conn: Connection
+    engine: BaseDatabaseEngine
+    default_txn_name: str

     def cursor(
         self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/end_to_end_keys.py
@@ -50,16 +50,16 @@
 from synapse.server import HomeServer


-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class DeviceKeyLookupResult:
     """The type returned by get_e2e_device_keys_and_signatures"""

-    display_name = attr.ib(type=Optional[str])
+    display_name: Optional[str]

     # the key data from e2e_device_keys_json. Typically includes fields like
     # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
     # key) and "signatures" (a map from (user id) to (key id/device_id) to signature.)
-    keys = attr.ib(type=Optional[JsonDict])
+    keys: Optional[JsonDict]


 class EndToEndKeyBackgroundStore(SQLBaseStore):
14 changes: 7 additions & 7 deletions synapse/storage/databases/main/events.py
@@ -69,7 +69,7 @@
 )


-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class DeltaState:
     """Deltas to use to update the `current_state_events` table.

@@ -80,9 +80,9 @@ class DeltaState:
     should e.g. be removed from `current_state_events` table.
     """

-    to_delete = attr.ib(type=List[Tuple[str, str]])
-    to_insert = attr.ib(type=StateMap[str])
-    no_longer_in_room = attr.ib(type=bool, default=False)
+    to_delete: List[Tuple[str, str]]
+    to_insert: StateMap[str]
+    no_longer_in_room: bool = False


 class PersistEventsStore:
@@ -2226,17 +2226,17 @@ def _update_backward_extremeties(self, txn, events):
         )


-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _LinkMap:
     """A helper type for tracking links between chains."""

     # Stores the set of links as nested maps: source chain ID -> target chain ID
     # -> source sequence number -> target sequence number.
-    maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
+    maps: Dict[int, Dict[int, Dict[int, int]]] = attr.Factory(dict)

     # Stores the links that have been added (with new set to true), as tuples of
     # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
-    additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
+    additions: Set[Tuple[int, int, int, int]] = attr.Factory(set)

     def add_link(
         self,
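Fields that previously used `attr.ib(..., factory=dict)` become annotations whose default is `attr.Factory(...)`, since a literal `= {}` default would be shared by every instance of the class. A minimal sketch of the equivalence, with an illustrative class name:

import attr
from typing import Dict

@attr.s(slots=True, auto_attribs=True)
class LinkTracker:  # illustrative stand-in for _LinkMap
    # attr.Factory(dict) builds a fresh dict per instance, matching the
    # behaviour of the old attr.ib(factory=dict) form.
    maps: Dict[int, Dict[int, int]] = attr.Factory(dict)

a, b = LinkTracker(), LinkTracker()
a.maps[1] = {2: 3}
assert b.maps == {}  # no state shared between instances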
12 changes: 6 additions & 6 deletions synapse/storage/databases/main/events_bg_updates.py
@@ -65,22 +65,22 @@ class _BackgroundUpdates:
     REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column"


-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class _CalculateChainCover:
     """Return value for _calculate_chain_cover_txn."""

     # The last room_id/depth/stream processed.
-    room_id = attr.ib(type=str)
-    depth = attr.ib(type=int)
-    stream = attr.ib(type=int)
+    room_id: str
+    depth: int
+    stream: int

     # Number of rows processed
-    processed_count = attr.ib(type=int)
+    processed_count: int

     # Map from room_id to last depth/stream processed for each room that we have
     # processed all events for (i.e. the rooms we can flip the
     # `has_auth_chain_index` for)
-    finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+    finished_room_map: Dict[str, Tuple[int, int]]


 class EventsBackgroundUpdatesStore(SQLBaseStore):
18 changes: 9 additions & 9 deletions synapse/storage/databases/main/registration.py
@@ -51,7 +51,7 @@ class ExternalIDReuseException(Exception):
     pass


-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
 class TokenLookupResult:
     """Result of looking up an access token.

@@ -69,14 +69,14 @@ class TokenLookupResult:
     cached.
     """

-    user_id = attr.ib(type=str)
-    is_guest = attr.ib(type=bool, default=False)
-    shadow_banned = attr.ib(type=bool, default=False)
-    token_id = attr.ib(type=Optional[int], default=None)
-    device_id = attr.ib(type=Optional[str], default=None)
-    valid_until_ms = attr.ib(type=Optional[int], default=None)
-    token_owner = attr.ib(type=str)
-    token_used = attr.ib(type=bool, default=False)
+    user_id: str
+    is_guest: bool = False
+    shadow_banned: bool = False
+    token_id: Optional[int] = None
+    device_id: Optional[str] = None
+    valid_until_ms: Optional[int] = None
+    token_owner: str = attr.ib()
+    token_used: bool = False

     # Make the token owner default to the user ID, which is the common case.
     @token_owner.default
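`token_owner` is the one field that keeps an explicit `attr.ib()`: its default is supplied by the `@token_owner.default` decorator below it, and a bare annotation has nowhere to hang that decorator. A minimal sketch of the pattern, assuming illustrative names:

import attr

@attr.s(frozen=True, slots=True, auto_attribs=True)
class TokenResult:  # illustrative stand-in for TokenLookupResult
    user_id: str
    # Annotated but still attr.ib(): the decorated method below is
    # registered as this field's default factory.
    token_owner: str = attr.ib()

    @token_owner.default
    def _default_token_owner(self):
        return self.user_id  # default the owner to the user ID

assert TokenResult(user_id="@alice:example.com").token_owner == "@alice:example.com"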
6 changes: 3 additions & 3 deletions synapse/storage/databases/main/roommember.py
@@ -1177,18 +1177,18 @@ def f(txn):
         await self.db_pool.runInteraction("forget_membership", f)


-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _JoinedHostsCache:
     """The cached data used by the `_get_joined_hosts_cache`."""

     # Dict of host to the set of their users in the room at the state group.
-    hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict)
+    hosts_to_joined_users: Dict[str, Set[str]] = attr.Factory(dict)

     # The state group `hosts_to_joined_users` is derived from. Will be an object
     # if the instance is newly created or if the state is not based on a state
     # group. (An object is used as a sentinel value to ensure that it never is
     # equal to anything else).
-    state_group = attr.ib(type=Union[object, int], factory=object)
+    state_group: Union[object, int] = attr.Factory(object)

     def __len__(self):
         return sum(len(v) for v in self.hosts_to_joined_users.values())
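`state_group` relies on a sentinel: `attr.Factory(object)` gives each new instance a fresh `object()` that compares equal to nothing but itself, so it can never collide with a real integer state group ID. A short sketch of why that holds, again with an illustrative name:

import attr
from typing import Union

@attr.s(slots=True, auto_attribs=True)
class HostsCache:  # illustrative stand-in for _JoinedHostsCache
    # A fresh object() per instance: equal only to itself, so it can
    # never be mistaken for a real state group ID.
    state_group: Union[object, int] = attr.Factory(object)

cache = HostsCache()
assert cache.state_group != 42
assert HostsCache().state_group != cache.state_group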
12 changes: 6 additions & 6 deletions synapse/storage/databases/main/ui_auth.py
@@ -23,19 +23,19 @@
 from synapse.util import json_encoder, stringutils


-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class UIAuthSessionData:
-    session_id = attr.ib(type=str)
+    session_id: str
     # The dictionary from the client root level, not the 'auth' key.
-    clientdict = attr.ib(type=JsonDict)
+    clientdict: JsonDict
     # The URI and method the session was initiated with. These are checked at
     # each stage of the authentication to ensure that the asked-for operation
     # has not changed.
-    uri = attr.ib(type=str)
-    method = attr.ib(type=str)
+    uri: str
+    method: str
     # A string description of the operation that the current authentication is
     # authorising.
-    description = attr.ib(type=str)
+    description: str


 class UIAuthWorkerStore(SQLBaseStore):
6 changes: 3 additions & 3 deletions synapse/storage/keys.py
@@ -21,7 +21,7 @@
 logger = logging.getLogger(__name__)


-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class FetchKeyResult:
-    verify_key = attr.ib(type=VerifyKey)  # the key itself
-    valid_until_ts = attr.ib(type=int)  # how long we can use this key for
+    verify_key: VerifyKey  # the key itself
+    valid_until_ts: int  # how long we can use this key for
6 changes: 3 additions & 3 deletions synapse/storage/prepare_database.py
@@ -696,7 +696,7 @@ def _get_or_create_schema_state(
     )


-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _DirectoryListing:
     """Helper class to store schema file name and the
     absolute path to it.

@@ -705,5 +705,5 @@ class _DirectoryListing:
     `file_name` attr is kept first.
     """

-    file_name = attr.ib(type=str)
-    absolute_path = attr.ib(type=str)
+    file_name: str
+    absolute_path: str
20 changes: 10 additions & 10 deletions synapse/storage/relations.py
@@ -23,7 +23,7 @@
 logger = logging.getLogger(__name__)


-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class PaginationChunk:
     """Returned by relation pagination APIs.

@@ -35,9 +35,9 @@ class PaginationChunk:
     None then there are no previous results.
     """

-    chunk = attr.ib(type=List[JsonDict])
-    next_batch = attr.ib(type=Optional[Any], default=None)
-    prev_batch = attr.ib(type=Optional[Any], default=None)
+    chunk: List[JsonDict]
+    next_batch: Optional[Any] = None
+    prev_batch: Optional[Any] = None

     def to_dict(self) -> Dict[str, Any]:
         d = {"chunk": self.chunk}
@@ -51,7 +51,7 @@ def to_dict(self) -> Dict[str, Any]:
         return d


-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
 class RelationPaginationToken:
     """Pagination token for relation pagination API.

@@ -64,8 +64,8 @@ class RelationPaginationToken:
     stream: The stream ordering of the boundary event.
     """

-    topological = attr.ib(type=int)
-    stream = attr.ib(type=int)
+    topological: int
+    stream: int

     @staticmethod
     def from_string(string: str) -> "RelationPaginationToken":
@@ -82,7 +82,7 @@ def as_tuple(self) -> Tuple[Any, ...]:
         return attr.astuple(self)


-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, auto_attribs=True)
 class AggregationPaginationToken:
     """Pagination token for relation aggregation pagination API.

@@ -94,8 +94,8 @@ class AggregationPaginationToken:
     stream: The MAX stream ordering in the boundary group.
     """

-    count = attr.ib(type=int)
-    stream = attr.ib(type=int)
+    count: int
+    stream: int

     @staticmethod
     def from_string(string: str) -> "AggregationPaginationToken":
6 changes: 3 additions & 3 deletions synapse/storage/state.py
@@ -45,7 +45,7 @@
 T = TypeVar("T")


-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class StateFilter:
     """A filter used when querying for state.

@@ -58,8 +58,8 @@ class StateFilter:
     appear in `types`.
     """

-    types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
-    include_others = attr.ib(default=False, type=bool)
+    types: "frozendict[str, Optional[FrozenSet[str]]]"
+    include_others: bool = False

     def __attrs_post_init__(self):
         # If `include_others` is set we canonicalise the filter by removing
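`types` keeps its annotation quoted. Under `auto_attribs` a string annotation is treated as a forward reference and never evaluated at class-definition time, which matters when the runtime class (older frozendict releases, for instance) does not support subscription. A minimal sketch of the idea, assuming an illustrative class name:

import attr
from typing import FrozenSet, Optional

@attr.s(slots=True, frozen=True, auto_attribs=True)
class Filter:  # illustrative stand-in for StateFilter
    # Quoted annotation: stored as a string, so frozendict need not be
    # subscriptable (or even imported) for the class to be defined.
    types: "frozendict[str, Optional[FrozenSet[str]]]"
    include_others: bool = False

# attrs records the annotation verbatim; nothing is enforced at runtime.
assert attr.fields(Filter).types.type == "frozendict[str, Optional[FrozenSet[str]]]"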
8 changes: 4 additions & 4 deletions synapse/storage/util/id_generators.py
@@ -762,13 +762,13 @@ async def __aexit__(
         return self.inner.__exit__(exc_type, exc, tb)


-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _MultiWriterCtxManager:
     """Async context manager returned by MultiWriterIdGenerator"""

-    id_gen = attr.ib(type=MultiWriterIdGenerator)
-    multiple_ids = attr.ib(type=Optional[int], default=None)
-    stream_ids = attr.ib(type=List[int], factory=list)
+    id_gen: MultiWriterIdGenerator
+    multiple_ids: Optional[int] = None
+    stream_ids: List[int] = attr.Factory(list)

     async def __aenter__(self) -> Union[int, List[int]]:
         # It's safe to run this in autocommit mode as fetching values from a