Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Use an enum for direction. (#14927)
Browse files Browse the repository at this point in the history
For better type safety we  use an enum instead of strings to
configure direction (backwards or forwards).
  • Loading branch information
clokep authored Jan 27, 2023
1 parent fc35e06 commit 265735d
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 44 deletions.
1 change: 1 addition & 0 deletions changelog.d/14927.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints.
7 changes: 7 additions & 0 deletions synapse/api/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

"""Contains constants from the specification."""

import enum

from typing_extensions import Final

# the max size of a (canonical-json-encoded) event
Expand Down Expand Up @@ -290,3 +292,8 @@ class ApprovalNoticeMedium:

NONE = "org.matrix.msc3866.none"
EMAIL = "org.matrix.msc3866.email"


class Direction(enum.Enum):
BACKWARDS = "b"
FORWARDS = "f"
4 changes: 2 additions & 2 deletions synapse/handlers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set

from synapse.api.constants import Membership
from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID
from synapse.visibility import filter_events_for_client
Expand Down Expand Up @@ -197,7 +197,7 @@ async def export_user_data(self, user_id: str, writer: "ExfiltrationWriter") ->
# efficient method perhaps but it does guarantee we get everything.
while True:
events, _ = await self.store.paginate_room_events(
room_id, from_key, to_key, limit=100, direction="f"
room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
)
if not events:
break
Expand Down
16 changes: 14 additions & 2 deletions synapse/handlers/initial_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast

from synapse.api.constants import AccountDataTypes, EduTypes, EventTypes, Membership
from synapse.api.constants import (
AccountDataTypes,
Direction,
EduTypes,
EventTypes,
Membership,
)
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig
Expand Down Expand Up @@ -57,7 +63,13 @@ def __init__(self, hs: "HomeServer"):
self.validator = EventValidator()
self.snapshot_cache: ResponseCache[
Tuple[
str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool
str,
Optional[StreamToken],
Optional[StreamToken],
Direction,
int,
bool,
bool,
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
Expand Down
6 changes: 3 additions & 3 deletions synapse/handlers/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from twisted.python.failure import Failure

from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig
Expand Down Expand Up @@ -448,7 +448,7 @@ async def get_messages(

if pagin_config.from_token:
from_token = pagin_config.from_token
elif pagin_config.direction == "f":
elif pagin_config.direction == Direction.FORWARDS:
from_token = (
await self.hs.get_event_sources().get_start_token_for_pagination(
room_id
Expand Down Expand Up @@ -476,7 +476,7 @@ async def get_messages(
room_id, requester, allow_departed_users=True
)

if pagin_config.direction == "b":
if pagin_config.direction == Direction.BACKWARDS:
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
Expand Down
8 changes: 6 additions & 2 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import attr

from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.constants import Direction, EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.logging.context import make_deferred_yieldable, run_in_background
Expand Down Expand Up @@ -413,7 +413,11 @@ async def _get_threads_for_events(

# Attempt to find another event to use as the latest event.
potential_events, _ = await self._main_store.get_relations_for_event(
event_id, event, room_id, RelationTypes.THREAD, direction="f"
event_id,
event,
room_id,
RelationTypes.THREAD,
direction=Direction.FORWARDS,
)

# Filter out ignored users.
Expand Down
8 changes: 4 additions & 4 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import attr

from synapse.api.constants import MAIN_TIMELINE, RelationTypes
from synapse.api.constants import MAIN_TIMELINE, Direction, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
Expand Down Expand Up @@ -168,7 +168,7 @@ async def get_relations_for_event(
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
limit: int = 5,
direction: str = "b",
direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
Expand All @@ -181,8 +181,8 @@ async def get_relations_for_event(
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
direction: Whether to fetch the most recent first (backwards) or the
oldest first (forwards).
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
Expand Down
59 changes: 31 additions & 28 deletions synapse/storage/databases/main/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

from twisted.internet import defer

from synapse.api.constants import Direction
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
Expand Down Expand Up @@ -86,7 +87,6 @@
_STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"


# Used as return values for pagination APIs
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventDictReturn:
Expand All @@ -104,7 +104,7 @@ class _EventsAround:


def generate_pagination_where_clause(
direction: str,
direction: Direction,
column_names: Tuple[str, str],
from_token: Optional[Tuple[Optional[int], int]],
to_token: Optional[Tuple[Optional[int], int]],
Expand All @@ -130,27 +130,26 @@ def generate_pagination_where_clause(
token, but include those that match the to token.
Args:
direction: Whether we're paginating backwards("b") or forwards ("f").
direction: Whether we're paginating backwards or forwards.
column_names: The column names to bound. Must *not* be user defined as
these get inserted directly into the SQL statement without escapes.
from_token: The start point for the pagination. This is an exclusive
minimum bound if direction is "f", and an inclusive maximum bound if
direction is "b".
minimum bound if direction is forwards, and an inclusive maximum bound if
direction is backwards.
to_token: The endpoint point for the pagination. This is an inclusive
maximum bound if direction is "f", and an exclusive minimum bound if
direction is "b".
maximum bound if direction is forwards, and an exclusive minimum bound if
direction is backwards.
engine: The database engine to generate the clauses for
Returns:
The sql expression
"""
assert direction in ("b", "f")

where_clause = []
if from_token:
where_clause.append(
_make_generic_sql_bound(
bound=">=" if direction == "b" else "<",
bound=">=" if direction == Direction.BACKWARDS else "<",
column_names=column_names,
values=from_token,
engine=engine,
Expand All @@ -160,7 +159,7 @@ def generate_pagination_where_clause(
if to_token:
where_clause.append(
_make_generic_sql_bound(
bound="<" if direction == "b" else ">=",
bound="<" if direction == Direction.BACKWARDS else ">=",
column_names=column_names,
values=to_token,
engine=engine,
Expand All @@ -171,7 +170,7 @@ def generate_pagination_where_clause(


def generate_pagination_bounds(
direction: str,
direction: Direction,
from_token: Optional[RoomStreamToken],
to_token: Optional[RoomStreamToken],
) -> Tuple[
Expand All @@ -181,7 +180,7 @@ def generate_pagination_bounds(
Generate a start and end point for this page of events.
Args:
direction: Whether pagination is going forwards or backwards. One of "f" or "b".
direction: Whether pagination is going forwards or backwards.
from_token: The token to start pagination at, or None to start at the first value.
to_token: The token to end pagination at, or None to not limit the end point.
Expand All @@ -201,7 +200,7 @@ def generate_pagination_bounds(
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
if direction == "b":
if direction == Direction.BACKWARDS:
order = "DESC"
else:
order = "ASC"
Expand All @@ -215,7 +214,7 @@ def generate_pagination_bounds(
if from_token:
if from_token.topological is not None:
from_bound = from_token.as_historical_tuple()
elif direction == "b":
elif direction == Direction.BACKWARDS:
from_bound = (
None,
from_token.get_max_stream_pos(),
Expand All @@ -230,7 +229,7 @@ def generate_pagination_bounds(
if to_token:
if to_token.topological is not None:
to_bound = to_token.as_historical_tuple()
elif direction == "b":
elif direction == Direction.BACKWARDS:
to_bound = (
None,
to_token.stream,
Expand All @@ -245,20 +244,20 @@ def generate_pagination_bounds(


def generate_next_token(
direction: str, last_topo_ordering: int, last_stream_ordering: int
direction: Direction, last_topo_ordering: int, last_stream_ordering: int
) -> RoomStreamToken:
"""
Generate the next room stream token based on the currently returned data.
Args:
direction: Whether pagination is going forwards or backwards. One of "f" or "b".
direction: Whether pagination is going forwards or backwards.
last_topo_ordering: The last topological ordering being returned.
last_stream_ordering: The last stream ordering being returned.
Returns:
A new RoomStreamToken to return to the client.
"""
if direction == "b":
if direction == Direction.BACKWARDS:
# Tokens are positions between events.
# This token points *after* the last event in the chunk.
# We need it to point to the event before it in the chunk
Expand Down Expand Up @@ -1201,7 +1200,7 @@ def _get_events_around_txn(
txn,
room_id,
before_token,
direction="b",
direction=Direction.BACKWARDS,
limit=before_limit,
event_filter=event_filter,
)
Expand All @@ -1211,7 +1210,7 @@ def _get_events_around_txn(
txn,
room_id,
after_token,
direction="f",
direction=Direction.FORWARDS,
limit=after_limit,
event_filter=event_filter,
)
Expand Down Expand Up @@ -1374,7 +1373,7 @@ def _paginate_room_events_txn(
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,
direction: str = "b",
direction: Direction = Direction.BACKWARDS,
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
Expand All @@ -1385,8 +1384,8 @@ def _paginate_room_events_txn(
room_id
from_token: The token used to stream from
to_token: A token which if given limits the results to only those before
direction: Either 'b' or 'f' to indicate whether we are paginating
forwards or backwards from `from_key`.
direction: Indicates whether we are paginating forwards or backwards
from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to
those that match the filter.
Expand Down Expand Up @@ -1489,8 +1488,12 @@ def _paginate_room_events_txn(
_EventDictReturn(event_id, topological_ordering, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn
if _filter_results(
lower_token=to_token if direction == "b" else from_token,
upper_token=from_token if direction == "b" else to_token,
lower_token=to_token
if direction == Direction.BACKWARDS
else from_token,
upper_token=from_token
if direction == Direction.BACKWARDS
else to_token,
instance_name=instance_name,
topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
Expand All @@ -1514,7 +1517,7 @@ async def paginate_room_events(
room_id: str,
from_key: RoomStreamToken,
to_key: Optional[RoomStreamToken] = None,
direction: str = "b",
direction: Direction = Direction.BACKWARDS,
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], RoomStreamToken]:
Expand All @@ -1524,8 +1527,8 @@ async def paginate_room_events(
room_id
from_key: The token used to stream from
to_key: A token which if given limits the results to only those before
direction: Either 'b' or 'f' to indicate whether we are paginating
forwards or backwards from `from_key`.
direction: Indicates whether we are paginating forwards or backwards
from `from_key`.
limit: The maximum number of events to return.
event_filter: If provided filters the events to those that match the filter.
Expand Down
11 changes: 8 additions & 3 deletions synapse/streams/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import attr

from synapse.api.constants import Direction
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
Expand All @@ -34,7 +35,7 @@ class PaginationConfig:

from_token: Optional[StreamToken]
to_token: Optional[StreamToken]
direction: str
direction: Direction
limit: int

@classmethod
Expand All @@ -45,9 +46,13 @@ async def from_request(
default_limit: int,
default_dir: str = "f",
) -> "PaginationConfig":
direction = parse_string(
request, "dir", default=default_dir, allowed_values=["f", "b"]
direction_str = parse_string(
request,
"dir",
default=default_dir,
allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value],
)
direction = Direction(direction_str)

from_tok_str = parse_string(request, "from")
to_tok_str = parse_string(request, "to")
Expand Down

0 comments on commit 265735d

Please sign in to comment.