Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize query for fetching to-device messages in /sync #16805

Merged
merged 3 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16805.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Optimize query for fetching to-device messages in `/sync`.
149 changes: 72 additions & 77 deletions synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,33 +245,74 @@ async def get_messages_for_device(
* The last-processed stream ID. Subsequent calls of this function with the
same device should pass this value as 'from_stream_id'.
"""
(
user_id_device_id_to_messages,
last_processed_stream_id,
) = await self._get_device_messages(
user_ids=[user_id],
device_id=device_id,
from_stream_id=from_stream_id,
to_stream_id=to_stream_id,
limit=limit,
)

if not user_id_device_id_to_messages:
if not self._device_inbox_stream_cache.has_entity_changed(
user_id, from_stream_id
):
# There were no messages!
return [], to_stream_id

# Extract the messages, no need to return the user and device ID again
to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])
def get_device_messages_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
# Fetch the to-device messages stored in `device_inbox` for the single
# (user_id, device_id) pair taken from the enclosing scope, restricted to
# from_stream_id < stream_id <= to_stream_id, oldest first, and capped at
# `limit` rows.
#
# Returns the decoded messages together with the stream position that the
# caller should pass as the lower (exclusive) bound of its next query.
sql = """
SELECT stream_id, message_json FROM device_inbox
WHERE user_id = ? AND device_id = ?
AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (user_id, device_id, from_stream_id, to_stream_id, limit))

# Collect the decoded messages in ascending stream order, remembering the
# stream ID of the last row seen so far.
last_processed_stream_pos = to_stream_id
to_device_messages: List[JsonDict] = []
rowcount = 0
for row in txn:
rowcount += 1

# row is (stream_id, message_json), matching the SELECT column order.
last_processed_stream_pos = row[0]
message_dict = db_to_json(row[1])

# Store the decoded message
to_device_messages.append(message_dict)

# NOTE(review): the next line looks like a removed line left over from the
# diff rendering (`last_processed_stream_id` is never assigned inside this
# function) -- confirm against the merged file before relying on it.
return to_device_messages, last_processed_stream_id
# start a new span for each message, so that we can tag each separately
with start_active_span("get_to_device_message"):
set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"])
set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"])
set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id)
set_tag(
SynapseTags.TO_DEVICE_MSGID,
message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
)

if rowcount == limit:
# We ended up bumping up against the message limit. There may be more messages
# to retrieve. Return what we have, as well as the last stream position that
# was processed.
#
# The caller is expected to set this as the lower (exclusive) bound
# for the next query of this device.
return to_device_messages, last_processed_stream_pos

# The limit was not reached, thus we know that to_device_messages
# contains all to-device messages for the given device and stream id range.
#
# We return to_stream_id, which the caller should then provide as the lower
# (exclusive) bound on the next query of this device.
return to_device_messages, to_stream_id

return await self.db_pool.runInteraction(
"get_messages_for_device", get_device_messages_txn
)

async def _get_device_messages(
self,
user_ids: Collection[str],
from_stream_id: int,
to_stream_id: int,
device_id: Optional[str] = None,
limit: Optional[int] = None,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
"""
Retrieve pending to-device messages for a collection of user devices.
Expand All @@ -291,11 +332,7 @@ async def _get_device_messages(
user_ids: The user IDs to filter device messages by.
from_stream_id: The lower boundary of stream id to filter with (exclusive).
to_stream_id: The upper boundary of stream id to filter with (inclusive).
device_id: A device ID to query to-device messages for. If not provided, to-device
messages from all device IDs for the given user IDs will be queried. May not be
provided if `user_ids` contains more than one entry.
limit: The maximum number of to-device messages to return. Can only be used when
passing a single user ID / device ID tuple.


Returns:
A tuple containing:
Expand All @@ -308,30 +345,7 @@ async def _get_device_messages(
logger.warning("No users provided upon querying for device IDs")
return {}, to_stream_id

# Prevent a query for one user's device also retrieving another user's device with
# the same device ID (device IDs are not unique across users).
if len(user_ids) > 1 and device_id is not None:
raise AssertionError(
"Programming error: 'device_id' cannot be supplied to "
"_get_device_messages when >1 user_id has been provided"
)

# A limit can only be applied when querying for a single user ID / device ID tuple.
# See the docstring of this function for more details.
if limit is not None and device_id is None:
raise AssertionError(
"Programming error: _get_device_messages was passed 'limit' "
"without a specific user_id/device_id"
)

user_ids_to_query: Set[str] = set()
device_ids_to_query: Set[str] = set()

# Note that a device ID could be an empty str
if device_id is not None:
# If a device ID was passed, use it to filter results.
# Otherwise, device IDs will be derived from the given collection of user IDs.
device_ids_to_query.add(device_id)

# Determine which users have devices with pending messages
for user_id in user_ids:
Expand All @@ -355,20 +369,20 @@ def get_device_messages_txn(
# hidden devices should not receive to-device messages.
# Note that this is more efficient than just dropping `device_id` from the query,
# since device_inbox has an index on `(user_id, device_id, stream_id)`
if not device_ids_to_query:
user_device_dicts = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn,
table="devices",
column="user_id",
iterable=user_ids_to_query,
keyvalues={"hidden": False},
retcols=("device_id",),
),
)

device_ids_to_query.update({row[0] for row in user_device_dicts})
user_device_dicts = cast(
List[Tuple[str]],
self.db_pool.simple_select_many_txn(
txn,
table="devices",
column="user_id",
iterable=user_ids_to_query,
keyvalues={"hidden": False},
retcols=("device_id",),
),
)

device_ids_to_query = {row[0] for row in user_device_dicts}

if not device_ids_to_query:
# We've ended up with no devices to query.
Expand Down Expand Up @@ -400,22 +414,15 @@ def get_device_messages_txn(
to_stream_id,
)

# If a limit was provided, limit the data retrieved from the database
if limit is not None:
sql += "LIMIT ?"
sql_args += (limit,)

txn.execute(sql, sql_args)

# Create and fill a dictionary of (user ID, device ID) -> list of messages
# intended for each device.
last_processed_stream_pos = to_stream_id
recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
rowcount = 0
for row in txn:
rowcount += 1

last_processed_stream_pos = row[0]
recipient_user_id = row[1]
recipient_device_id = row[2]
message_dict = db_to_json(row[3])
Expand All @@ -436,18 +443,6 @@ def get_device_messages_txn(
message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
)

if limit is not None and rowcount == limit:
# We ended up bumping up against the message limit. There may be more messages
# to retrieve. Return what we have, as well as the last stream position that
# was processed.
#
# The caller is expected to set this as the lower (exclusive) bound
# for the next query of this device.
return recipient_device_to_messages, last_processed_stream_pos

# The limit was not reached, thus we know that recipient_device_to_messages
# contains all to-device messages for the given device and stream id range.
#
# We return to_stream_id, which the caller should then provide as the lower
# (exclusive) bound on the next query of this device.
return recipient_device_to_messages, to_stream_id
Expand Down
Loading