Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Cache requests for user's devices from federation #15675

Merged
merged 5 commits into from
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15675.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cache requests for user's devices over federation.
4 changes: 4 additions & 0 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -1941,6 +1941,10 @@ def _add_device_change_to_stream_txn(
user_id,
stream_ids[-1],
)
txn.call_after(
self._get_e2e_device_keys_for_federation_query_inner.invalidate,
(user_id,),
)

min_stream_id = stream_ids[0]

Expand Down
63 changes: 61 additions & 2 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import abc
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
Expand All @@ -39,6 +40,7 @@
TransactionUnusedFallbackKeys,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams._base import DeviceListsStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
Expand Down Expand Up @@ -104,6 +106,23 @@ def __init__(
self.hs.config.federation.allow_device_name_lookup_over_federation
)

def process_replication_rows(
    self,
    stream_name: str,
    instance_name: str,
    token: int,
    rows: Iterable[Any],
) -> None:
    """Handle incoming replication rows, invalidating the per-user cached
    federation device-key query whenever the device lists stream reports a
    change for a user.
    """
    if stream_name == DeviceListsStream.NAME:
        for replication_row in rows:
            assert isinstance(
                replication_row, DeviceListsStream.DeviceListsStreamRow
            )
            entity = replication_row.entity
            # Rows whose entity does not start with "@" are presumably not
            # user IDs (e.g. remote destinations) — TODO confirm; only user
            # rows can affect this cache.
            if not entity.startswith("@"):
                continue
            self._get_e2e_device_keys_for_federation_query_inner.invalidate(
                (entity,)
            )

    super().process_replication_rows(stream_name, instance_name, token, rows)

async def get_e2e_device_keys_for_federation_query(
self, user_id: str
) -> Tuple[int, List[JsonDict]]:
Expand All @@ -114,6 +133,46 @@ async def get_e2e_device_keys_for_federation_query(
"""
now_stream_id = self.get_device_stream_token()

# We need to be careful with the caching here, as we need to always
# return *all* persisted devices, however there may be a lag between a
# new device being persisted and the cache being invalidated.
cached_results = (
self._get_e2e_device_keys_for_federation_query_inner.cache.get_immediate(
user_id, None
)
)
if cached_results is not None:
# Check that there have been no new devices added by another worker
# after the cache. This should be quick as there should be few rows
# with a higher stream ordering.
sql = """
SELECT user_id FROM device_lists_stream
WHERE stream_id >= ? AND user_id = ?
"""
rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check",
None,
sql,
now_stream_id,
user_id,
)
if not rows:
# No new rows, so cache is still valid.
return now_stream_id, cached_results

# There has, so let's invalidate the cache and run the query.
self._get_e2e_device_keys_for_federation_query_inner.invalidate((user_id,))

results = await self._get_e2e_device_keys_for_federation_query_inner(user_id)

return now_stream_id, results
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm having trouble figuring out whether this is correct.

cached_results will come from a time when the device stream token was less than now_stream_id. But we only discard the cached result if there are rows with a stream ordering >= now_stream_id, so wouldn't we be overlooking changes between cached_results's stream id and now_stream_id?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that case is fine, as for all changes less than now_stream_id we should have invalidated the cache?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, if now_stream_id can't advance without invalidating the cache then it looks okay.
We could probably add a comment to that effect.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added:

Note that we invalidate based on the device stream, so we only have to check for potential invalidations after the now_stream_id.

Does that make sense?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it makes enough sense.


@cached(iterable=True)
async def _get_e2e_device_keys_for_federation_query_inner(
self, user_id: str
) -> List[JsonDict]:
"""Get all devices (with any device keys) for a user"""

devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])

if devices:
Expand All @@ -134,9 +193,9 @@ async def get_e2e_device_keys_for_federation_query(

results.append(result)

return now_stream_id, results
return results

return now_stream_id, []
return []

@trace
@cancellable
Expand Down