Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Better return type for get_all_entities_changed (#14604)
Browse files Browse the repository at this point in the history
Help callers from using the return value incorrectly by ensuring
that callers explicitly check if there was a cache hit or not.
  • Loading branch information
erikjohnston authored Dec 5, 2022
1 parent 6a8310f commit cee9445
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 76 deletions.
1 change: 1 addition & 0 deletions changelog.d/14604.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances.
4 changes: 2 additions & 2 deletions synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,8 +615,8 @@ async def _get_device_list_summary(
)

# Fetch the users who have modified their device list since then.
users_with_changed_device_lists = (
await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
users_with_changed_device_lists = await self.store.get_all_devices_changed(
from_key, to_key=new_key
)

# Filter out any users the application service is not interested in
Expand Down
12 changes: 7 additions & 5 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,10 +1692,12 @@ async def get_new_events(

if from_key is not None:
# First get all users that have had a presence update
updated_users = stream_change_cache.get_all_entities_changed(from_key)
result = stream_change_cache.get_all_entities_changed(from_key)

# Cross-reference users we're interested in with those that have had updates.
if updated_users is not None:
if result.hit:
updated_users = result.entities

# If we have the full list of changes for presence we can
# simply check which ones share a room with the user.
get_updates_counter.labels("stream").inc()
Expand Down Expand Up @@ -1767,9 +1769,9 @@ async def _filter_all_presence_updates_for_user(
updated_users = None
if from_key:
# Only return updates since the last sync
updated_users = self.store.presence_stream_cache.get_all_entities_changed(
from_key
)
result = self.store.presence_stream_cache.get_all_entities_changed(from_key)
if result.hit:
updated_users = result.entities

if updated_users is not None:
# Get the actual presence update for each change
Expand Down
6 changes: 4 additions & 2 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,10 +1528,12 @@ async def _generate_sync_entry_for_device_list(
#
# If we don't have that info cached then we get all the users that
# share a room with our user and check if those users have changed.
changed_users = self.store.get_cached_device_list_changes(
cache_result = self.store.get_cached_device_list_changes(
since_token.device_list_key
)
if changed_users is not None:
if cache_result.hit:
changed_users = cache_result.entities

result = await self.store.get_rooms_for_users(changed_users)

for changed_user_id, entries in result.items():
Expand Down
8 changes: 4 additions & 4 deletions synapse/handlers/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,11 @@ async def get_all_typing_updates(
if last_id == current_id:
return [], current_id, False

changed_rooms: Optional[
Iterable[str]
] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
result = self._typing_stream_change_cache.get_all_entities_changed(last_id)

if changed_rooms is None:
if result.hit:
changed_rooms: Iterable[str] = result.entities
else:
changed_rooms = self._room_serials

rows = []
Expand Down
111 changes: 71 additions & 40 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.stream_change_cache import (
AllEntitiesChangedResult,
StreamChangeCache,
)
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
Expand Down Expand Up @@ -799,18 +802,66 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]
def get_cached_device_list_changes(
self,
from_key: int,
) -> Optional[List[str]]:
) -> AllEntitiesChangedResult:
"""Get set of users whose devices have changed since `from_key`, or None
if that information is not in our cache.
"""

return self._device_list_stream_cache.get_all_entities_changed(from_key)

@cancellable
async def get_all_devices_changed(
self,
from_key: int,
to_key: int,
) -> Set[str]:
"""Get all users whose devices have changed in the given range.
Args:
from_key: The minimum device lists stream token to query device list
changes for, exclusive.
to_key: The maximum device lists stream token to query device list
changes for, inclusive.
Returns:
The set of user_ids whose devices have changed since `from_key`
(exclusive) until `to_key` (inclusive).
"""

result = self._device_list_stream_cache.get_all_entities_changed(from_key)

if result.hit:
# We know which users might have changed devices.
if not result.entities:
# If no users then we can return early.
return set()

# Otherwise we need to filter down the list
return await self.get_users_whose_devices_changed(
from_key, result.entities, to_key
)

# If the cache didn't tell us anything, we just need to query the full
# range.
sql = """
SELECT DISTINCT user_id FROM device_lists_stream
WHERE ? < stream_id AND stream_id <= ?
"""

rows = await self.db_pool.execute(
"get_all_devices_changed",
None,
sql,
from_key,
to_key,
)
return {u for u, in rows}

@cancellable
async def get_users_whose_devices_changed(
self,
from_key: int,
user_ids: Optional[Collection[str]] = None,
user_ids: Collection[str],
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
Expand All @@ -830,52 +881,32 @@ async def get_users_whose_devices_changed(
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
user_ids_to_check: Optional[Collection[str]]
if user_ids is None:
# Get set of all users that have had device list changes since 'from_key'
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
from_key
)
else:
# The same as above, but filter results to only those users in 'user_ids'
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)

# If an empty set was returned, there's nothing to do.
if user_ids_to_check is not None and not user_ids_to_check:
if not user_ids_to_check:
return set()

def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]

if to_key:
stream_id_where_clause += " AND stream_id <= ?"
sql_args.append(to_key)
if to_key is None:
to_key = self._device_list_id_gen.get_current_token()

sql = f"""
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
sql = """
SELECT DISTINCT user_id FROM device_lists_stream
WHERE {stream_id_where_clause}
WHERE ? < stream_id AND stream_id <= ? AND %s
"""

# If the stream change cache gave us no information, fetch *all*
# users between the stream IDs.
if user_ids_to_check is None:
txn.execute(sql, sql_args)
return {user_id for user_id, in txn}
changes: Set[str] = set()

# Otherwise, fetch changes for the given users.
else:
changes: Set[str] = set()

# Query device changes with a batch of users at a time
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql + " AND " + clause, sql_args + args)
changes.update(user_id for user_id, in txn)
# Query device changes with a batch of users at a time
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql % (clause,), [from_key, to_key] + args)
changes.update(user_id for user_id, in txn)

return changes

Expand Down
52 changes: 37 additions & 15 deletions synapse/util/caches/stream_change_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import math
from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union

import attr
from sortedcontainers import SortedDict

from synapse.util import caches
Expand All @@ -26,6 +27,29 @@
EntityType = str


@attr.s(auto_attribs=True, frozen=True, slots=True)
class AllEntitiesChangedResult:
"""Return type of `get_all_entities_changed`.
Callers must check that there was a cache hit, via `result.hit`, before
using the entities in `result.entities`.
This specifically does *not* implement helpers such as `__bool__` to ensure
that callers do the correct checks.
"""

_entities: Optional[List[EntityType]]

@property
def hit(self) -> bool:
return self._entities is not None

@property
def entities(self) -> List[EntityType]:
assert self._entities is not None
return self._entities


class StreamChangeCache:
"""
Keeps track of the stream positions of the latest change in a set of entities.
Expand Down Expand Up @@ -153,19 +177,19 @@ def get_entities_changed(
This will be all entities if the given stream position is at or earlier
than the earliest known stream position.
"""
changed_entities = self.get_all_entities_changed(stream_pos)
if changed_entities is not None:
cache_result = self.get_all_entities_changed(stream_pos)
if cache_result.hit:
# We now do an intersection, trying to do so in the most efficient
# way possible (some of these sets are *large*). First check in the
# given iterable is already a set that we can reuse, otherwise we
# create a set of the *smallest* of the two iterables and call
# `intersection(..)` on it (this can be twice as fast as the reverse).
if isinstance(entities, (set, frozenset)):
result = entities.intersection(changed_entities)
elif len(changed_entities) < len(entities):
result = set(changed_entities).intersection(entities)
result = entities.intersection(cache_result.entities)
elif len(cache_result.entities) < len(entities):
result = set(cache_result.entities).intersection(entities)
else:
result = set(entities).intersection(changed_entities)
result = set(entities).intersection(cache_result.entities)
self.metrics.inc_hits()
else:
result = set(entities)
Expand Down Expand Up @@ -202,36 +226,34 @@ def has_any_entity_changed(self, stream_pos: int) -> bool:
self.metrics.inc_hits()
return stream_pos < self._cache.peekitem()[0]

def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
"""
Returns all entities that have had changes after the given position.
If the stream change cache does not go far enough back, i.e. the position
is too old, it will return None.
If the stream change cache does not go far enough back, i.e. the
position is too old, it will return None.
Returns the entities in the order that they were changed.
Args:
stream_pos: The stream position to check for changes after.
Return:
Entities which have changed after the given stream position.
None if the given stream position is at or earlier than the earliest
known stream position.
A class indicating if we have the requested data cached, and if so
includes the entities in the order they were changed.
"""
assert isinstance(stream_pos, int)

# _cache is not valid at or before the earliest known stream position, so
# return None to mark that it is unknown if an entity has changed.
if stream_pos <= self._earliest_known_stream_pos:
return None
return AllEntitiesChangedResult(None)

changed_entities: List[EntityType] = []

for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
changed_entities.extend(self._cache[k])
return changed_entities
return AllEntitiesChangedResult(changed_entities)

def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
"""
Expand Down
20 changes: 12 additions & 8 deletions tests/util/test_stream_change_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,21 @@ def test_entity_has_changed_pops_off_start(self) -> None:
# The oldest item has been popped off
self.assertTrue("[email protected]" not in cache._entity_to_key)

self.assertEqual(cache.get_all_entities_changed(3), ["[email protected]"])
self.assertIsNone(cache.get_all_entities_changed(2))
self.assertEqual(
cache.get_all_entities_changed(3).entities, ["[email protected]"]
)
self.assertFalse(cache.get_all_entities_changed(2).hit)

# If we update an existing entity, it keeps the two existing entities
cache.entity_has_changed("[email protected]", 5)
self.assertEqual(
{"[email protected]", "[email protected]"}, set(cache._entity_to_key)
)
self.assertEqual(
cache.get_all_entities_changed(3),
cache.get_all_entities_changed(3).entities,
["[email protected]", "[email protected]"],
)
self.assertIsNone(cache.get_all_entities_changed(2))
self.assertFalse(cache.get_all_entities_changed(2).hit)

def test_get_all_entities_changed(self) -> None:
"""
Expand All @@ -105,10 +107,12 @@ def test_get_all_entities_changed(self) -> None:
# Results are ordered so either of these are valid.
ok1 = ["[email protected]", "[email protected]", "[email protected]"]
ok2 = ["[email protected]", "[email protected]", "[email protected]"]
self.assertTrue(r == ok1 or r == ok2)
self.assertTrue(r.entities == ok1 or r.entities == ok2)

self.assertEqual(cache.get_all_entities_changed(3), ["[email protected]"])
self.assertEqual(cache.get_all_entities_changed(1), None)
self.assertEqual(
cache.get_all_entities_changed(3).entities, ["[email protected]"]
)
self.assertFalse(cache.get_all_entities_changed(1).hit)

# ... later, things gest more updates
cache.entity_has_changed("[email protected]", 5)
Expand All @@ -128,7 +132,7 @@ def test_get_all_entities_changed(self) -> None:
"[email protected]",
]
r = cache.get_all_entities_changed(3)
self.assertTrue(r == ok1 or r == ok2)
self.assertTrue(r.entities == ok1 or r.entities == ok2)

def test_has_any_entity_changed(self) -> None:
"""
Expand Down

0 comments on commit cee9445

Please sign in to comment.