Skip to content

Commit

Permalink
Sliding sync: Store the per-connection state in the database. (#17599)
Browse files Browse the repository at this point in the history
Based on #17600

---------

Co-authored-by: Eric Eastwood <[email protected]>
  • Loading branch information
erikjohnston and MadLittleMods authored Aug 29, 2024
1 parent 2999a14 commit e43c2b0
Show file tree
Hide file tree
Showing 14 changed files with 692 additions and 116 deletions.
1 change: 1 addition & 0 deletions changelog.d/17599.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Store sliding sync per-connection state in the database.
2 changes: 2 additions & 0 deletions synapse/app/generic_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
from synapse.storage.databases.main.search import SearchStore
from synapse.storage.databases.main.session import SessionStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.stream import StreamWorkerStore
Expand Down Expand Up @@ -159,6 +160,7 @@ class GenericWorkerStore(
SessionStore,
TaskSchedulerWorkerStore,
ExperimentalFeaturesStore,
SlidingSyncStore,
):
# Properties that multiple storage classes define. Tell mypy what the
# expected type is.
Expand Down
9 changes: 2 additions & 7 deletions synapse/handlers/sliding_sync/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self, hs: "HomeServer"):
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
self.is_mine_id = hs.is_mine_id

self.connection_store = SlidingSyncConnectionStore()
self.connection_store = SlidingSyncConnectionStore(self.store)
self.extensions = SlidingSyncExtensionHandler(hs)
self.room_lists = SlidingSyncRoomLists(hs)

Expand Down Expand Up @@ -210,16 +210,11 @@ async def current_sync_for_user(
# amount of time (more with round-trips and re-processing) in the end to
# get everything again.
previous_connection_state = (
await self.connection_store.get_per_connection_state(
await self.connection_store.get_and_clear_connection_positions(
sync_config, from_token
)
)

await self.connection_store.mark_token_seen(
sync_config=sync_config,
from_token=from_token,
)

# Get all of the room IDs that the user should be able to see in the sync
# response
has_lists = sync_config.lists is not None and len(sync_config.lists) > 0
Expand Down
142 changes: 35 additions & 107 deletions synapse/handlers/sliding_sync/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
#

import logging
from typing import TYPE_CHECKING, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Optional

import attr

from synapse.api.errors import SlidingSyncUnknownPosition
from synapse.logging.opentracing import trace
from synapse.storage.databases.main import DataStore
from synapse.types import SlidingSyncStreamToken
from synapse.types.handlers.sliding_sync import (
MutablePerConnectionState,
Expand Down Expand Up @@ -61,22 +61,9 @@ class SlidingSyncConnectionStore:
to mapping of room ID to `HaveSentRoom`.
"""

# `(user_id, conn_id)` -> `connection_position` -> `PerConnectionState`
_connections: Dict[Tuple[str, str], Dict[int, PerConnectionState]] = attr.Factory(
dict
)
store: "DataStore"

async def is_valid_token(
self, sync_config: SlidingSyncConfig, connection_token: int
) -> bool:
"""Return whether the connection token is valid/recognized"""
if connection_token == 0:
return True

conn_key = self._get_connection_key(sync_config)
return connection_token in self._connections.get(conn_key, {})

async def get_per_connection_state(
async def get_and_clear_connection_positions(
self,
sync_config: SlidingSyncConfig,
from_token: Optional[SlidingSyncStreamToken],
Expand All @@ -86,23 +73,21 @@ async def get_per_connection_state(
Raises:
SlidingSyncUnknownPosition if the connection_token is unknown
"""
if from_token is None:
# If this is our first request, there is no previous connection state to fetch out of the database
if from_token is None or from_token.connection_position == 0:
return PerConnectionState()

connection_position = from_token.connection_position
if connection_position == 0:
# Initial sync (request without a `from_token`) starts at `0` so
# there is no existing per-connection state
return PerConnectionState()

conn_key = self._get_connection_key(sync_config)
sync_statuses = self._connections.get(conn_key, {})
connection_state = sync_statuses.get(connection_position)
conn_id = sync_config.conn_id or ""

if connection_state is None:
raise SlidingSyncUnknownPosition()
device_id = sync_config.requester.device_id
assert device_id is not None

return connection_state
return await self.store.get_and_clear_connection_positions(
sync_config.user.to_string(),
device_id,
conn_id,
from_token.connection_position,
)

@trace
async def record_new_state(
Expand All @@ -116,85 +101,28 @@ async def record_new_state(
If there are no changes to the state this may return the same token as
the existing per-connection state.
"""
prev_connection_token = 0
if from_token is not None:
prev_connection_token = from_token.connection_position

if not new_connection_state.has_updates():
return prev_connection_token

conn_key = self._get_connection_key(sync_config)
sync_statuses = self._connections.setdefault(conn_key, {})

# Generate a new token, removing any existing entries in that token
# (which can happen if requests get resent).
new_store_token = prev_connection_token + 1
sync_statuses.pop(new_store_token, None)

# We copy the `MutablePerConnectionState` so that the inner `ChainMap`s
# don't grow forever.
sync_statuses[new_store_token] = new_connection_state.copy()

return new_store_token
if from_token is not None:
return from_token.connection_position
else:
return 0

# A from token with a zero connection position means there was no
# previously stored connection state, so we treat a zero the same as
# there being no previous position.
previous_connection_position = None
if from_token is not None and from_token.connection_position != 0:
previous_connection_position = from_token.connection_position

@trace
async def mark_token_seen(
self,
sync_config: SlidingSyncConfig,
from_token: Optional[SlidingSyncStreamToken],
) -> None:
"""We have received a request with the given token, so we can clear out
any other tokens associated with the connection.
If there is no from token then we have started afresh, and so we delete
all tokens associated with the device.
"""
# Clear out any tokens for the connection that doesn't match the one
# from the request.

conn_key = self._get_connection_key(sync_config)
sync_statuses = self._connections.pop(conn_key, {})
if from_token is None:
return

sync_statuses = {
connection_token: room_statuses
for connection_token, room_statuses in sync_statuses.items()
if connection_token == from_token.connection_position
}
if sync_statuses:
self._connections[conn_key] = sync_statuses

@staticmethod
def _get_connection_key(sync_config: SlidingSyncConfig) -> Tuple[str, str]:
"""Return a unique identifier for this connection.
The first part is simply the user ID.
The second part is generally a combination of device ID and conn_id.
However, both these two are optional (e.g. puppet access tokens don't
have device IDs), so this handles those edge cases.
We use this over the raw `conn_id` to avoid clashes between different
clients that use the same `conn_id`. Imagine a user uses a web client
that uses `conn_id: main_sync_loop` and an Android client that also has
a `conn_id: main_sync_loop`.
"""

user_id = sync_config.user.to_string()

# Only one sliding sync connection is allowed per given conn_id (empty
# or not).
conn_id = sync_config.conn_id or ""

if sync_config.requester.device_id:
return (user_id, f"D/{sync_config.requester.device_id}/{conn_id}")

if sync_config.requester.access_token_id:
# If we don't have a device, then the access token ID should be a
# stable ID.
return (user_id, f"A/{sync_config.requester.access_token_id}/{conn_id}")
device_id = sync_config.requester.device_id
assert device_id is not None

# If we have neither then its likely an AS or some weird token. Either
# way we can just fail here.
raise Exception("Cannot use sliding sync with access token type")
return await self.store.persist_per_connection_state(
sync_config.user.to_string(),
device_id,
conn_id,
previous_connection_position,
new_connection_state,
)
43 changes: 43 additions & 0 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor, SQLQueryParameters
from synapse.types import StrCollection
from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter

Expand Down Expand Up @@ -1096,6 +1097,48 @@ def simple_insert_txn(

txn.execute(sql, vals)

@staticmethod
def simple_insert_returning_txn(
txn: LoggingTransaction,
table: str,
values: Dict[str, Any],
returning: StrCollection,
) -> Tuple[Any, ...]:
"""Executes a `INSERT INTO... RETURNING...` statement (or equivalent for
SQLite versions that don't support it).
"""

if txn.database_engine.supports_returning:
sql = "INSERT INTO %s (%s) VALUES(%s) RETURNING %s" % (
table,
", ".join(k for k in values.keys()),
", ".join("?" for _ in values.keys()),
", ".join(k for k in returning),
)

txn.execute(sql, list(values.values()))
row = txn.fetchone()
assert row is not None
return row
else:
# For old versions of SQLite we do a standard insert and then can
# use `last_insert_rowid` to get at the row we just inserted
DatabasePool.simple_insert_txn(
txn,
table=table,
values=values,
)
txn.execute("SELECT last_insert_rowid()")
row = txn.fetchone()
assert row is not None
(rowid,) = row

row = DatabasePool.simple_select_one_txn(
txn, table=table, keyvalues={"rowid": rowid}, retcols=returning
)
assert row is not None
return row

async def simple_insert_many(
self,
table: str,
Expand Down
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
Expand Down Expand Up @@ -156,6 +157,7 @@ class DataStore(
LockStore,
SessionStore,
TaskSchedulerWorkerStore,
SlidingSyncStore,
):
def __init__(
self,
Expand Down
Loading

0 comments on commit e43c2b0

Please sign in to comment.