diff --git a/changelog.d/16805.misc b/changelog.d/16805.misc new file mode 100644 index 00000000000..0b54ab0f742 --- /dev/null +++ b/changelog.d/16805.misc @@ -0,0 +1 @@ +Optimize query for fetching to-device messages in `/sync`. diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 40477b9da00..fa47b471e8c 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -245,33 +245,74 @@ async def get_messages_for_device( * The last-processed stream ID. Subsequent calls of this function with the same device should pass this value as 'from_stream_id'. """ - ( - user_id_device_id_to_messages, - last_processed_stream_id, - ) = await self._get_device_messages( - user_ids=[user_id], - device_id=device_id, - from_stream_id=from_stream_id, - to_stream_id=to_stream_id, - limit=limit, - ) - - if not user_id_device_id_to_messages: + if not self._device_inbox_stream_cache.has_entity_changed( + user_id, from_stream_id + ): # There were no messages! return [], to_stream_id - # Extract the messages, no need to return the user and device ID again - to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), []) + def get_device_messages_txn( + txn: LoggingTransaction, + ) -> Tuple[List[JsonDict], int]: + sql = """ + SELECT stream_id, message_json FROM device_inbox + WHERE user_id = ? AND device_id = ? + AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + txn.execute(sql, (user_id, device_id, from_stream_id, to_stream_id, limit)) + + # Create and fill a dictionary of (user ID, device ID) -> list of messages + # intended for each device. + last_processed_stream_pos = to_stream_id + to_device_messages: List[JsonDict] = [] + rowcount = 0 + for row in txn: + rowcount += 1 + + last_processed_stream_pos = row[0] + message_dict = db_to_json(row[1]) + + # Store the device details + to_device_messages.append(message_dict) - return to_device_messages, last_processed_stream_id + # start a new span for each message, so that we can tag each separately + with start_active_span("get_to_device_message"): + set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"]) + set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"]) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id) + set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id) + set_tag( + SynapseTags.TO_DEVICE_MSGID, + message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID), + ) + + if rowcount == limit: + # We ended up bumping up against the message limit. There may be more messages + # to retrieve. Return what we have, as well as the last stream position that + # was processed. + # + # The caller is expected to set this as the lower (exclusive) bound + # for the next query of this device. + return to_device_messages, last_processed_stream_pos + + # The limit was not reached, thus we know that recipient_device_to_messages + # contains all to-device messages for the given device and stream id range. + # + # We return to_stream_id, which the caller should then provide as the lower + # (exclusive) bound on the next query of this device. + return to_device_messages, to_stream_id + + return await self.db_pool.runInteraction( + "get_messages_for_device", get_device_messages_txn + ) async def _get_device_messages( self, user_ids: Collection[str], from_stream_id: int, to_stream_id: int, - device_id: Optional[str] = None, - limit: Optional[int] = None, ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]: """ Retrieve pending to-device messages for a collection of user devices. @@ -291,11 +332,7 @@ async def _get_device_messages( user_ids: The user IDs to filter device messages by. from_stream_id: The lower boundary of stream id to filter with (exclusive). to_stream_id: The upper boundary of stream id to filter with (inclusive). - device_id: A device ID to query to-device messages for. If not provided, to-device - messages from all device IDs for the given user IDs will be queried. May not be - provided if `user_ids` contains more than one entry. - limit: The maximum number of to-device messages to return. Can only be used when - passing a single user ID / device ID tuple. + Returns: A tuple containing: @@ -308,30 +345,7 @@ async def _get_device_messages( logger.warning("No users provided upon querying for device IDs") return {}, to_stream_id - # Prevent a query for one user's device also retrieving another user's device with - # the same device ID (device IDs are not unique across users). - if len(user_ids) > 1 and device_id is not None: - raise AssertionError( - "Programming error: 'device_id' cannot be supplied to " - "_get_device_messages when >1 user_id has been provided" - ) - - # A limit can only be applied when querying for a single user ID / device ID tuple. - # See the docstring of this function for more details. - if limit is not None and device_id is None: - raise AssertionError( - "Programming error: _get_device_messages was passed 'limit' " - "without a specific user_id/device_id" - ) - user_ids_to_query: Set[str] = set() - device_ids_to_query: Set[str] = set() - - # Note that a device ID could be an empty str - if device_id is not None: - # If a device ID was passed, use it to filter results. - # Otherwise, device IDs will be derived from the given collection of user IDs. - device_ids_to_query.add(device_id) # Determine which users have devices with pending messages for user_id in user_ids: @@ -355,20 +369,20 @@ def get_device_messages_txn( # hidden devices should not receive to-device messages. # Note that this is more efficient than just dropping `device_id` from the query, # since device_inbox has an index on `(user_id, device_id, stream_id)` - if not device_ids_to_query: - user_device_dicts = cast( - List[Tuple[str]], - self.db_pool.simple_select_many_txn( - txn, - table="devices", - column="user_id", - iterable=user_ids_to_query, - keyvalues={"hidden": False}, - retcols=("device_id",), - ), - ) - device_ids_to_query.update({row[0] for row in user_device_dicts}) + user_device_dicts = cast( + List[Tuple[str]], + self.db_pool.simple_select_many_txn( + txn, + table="devices", + column="user_id", + iterable=user_ids_to_query, + keyvalues={"hidden": False}, + retcols=("device_id",), + ), + ) + + device_ids_to_query = {row[0] for row in user_device_dicts} if not device_ids_to_query: # We've ended up with no devices to query. @@ -400,22 +414,15 @@ def get_device_messages_txn( to_stream_id, ) - # If a limit was provided, limit the data retrieved from the database - if limit is not None: - sql += "LIMIT ?" - sql_args += (limit,) - txn.execute(sql, sql_args) # Create and fill a dictionary of (user ID, device ID) -> list of messages # intended for each device. - last_processed_stream_pos = to_stream_id recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {} rowcount = 0 for row in txn: rowcount += 1 - last_processed_stream_pos = row[0] recipient_user_id = row[1] recipient_device_id = row[2] message_dict = db_to_json(row[3]) @@ -436,18 +443,6 @@ def get_device_messages_txn( message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID), ) - if limit is not None and rowcount == limit: - # We ended up bumping up against the message limit. There may be more messages - # to retrieve. Return what we have, as well as the last stream position that - # was processed. - # - # The caller is expected to set this as the lower (exclusive) bound - # for the next query of this device. - return recipient_device_to_messages, last_processed_stream_pos - - # The limit was not reached, thus we know that recipient_device_to_messages - # contains all to-device messages for the given device and stream id range. - # # We return to_stream_id, which the caller should then provide as the lower # (exclusive) bound on the next query of this device. return recipient_device_to_messages, to_stream_id