Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Better return type for get_all_entities_changed #14604

Merged
merged 7 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/14604.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances.
4 changes: 2 additions & 2 deletions synapse/handlers/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,8 +615,8 @@ async def _get_device_list_summary(
)

# Fetch the users who have modified their device list since then.
users_with_changed_device_lists = (
await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
users_with_changed_device_lists = await self.store.get_all_devices_changed(
from_key, to_key=new_key
)

# Filter out any users the application service is not interested in
Expand Down
12 changes: 7 additions & 5 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1692,10 +1692,12 @@ async def get_new_events(

if from_key is not None:
# First get all users that have had a presence update
updated_users = stream_change_cache.get_all_entities_changed(from_key)
result = stream_change_cache.get_all_entities_changed(from_key)

# Cross-reference users we're interested in with those that have had updates.
if updated_users is not None:
if result.hit:
updated_users = result.entities

# If we have the full list of changes for presence we can
# simply check which ones share a room with the user.
get_updates_counter.labels("stream").inc()
Expand Down Expand Up @@ -1767,9 +1769,9 @@ async def _filter_all_presence_updates_for_user(
updated_users = None
if from_key:
# Only return updates since the last sync
updated_users = self.store.presence_stream_cache.get_all_entities_changed(
from_key
)
result = self.store.presence_stream_cache.get_all_entities_changed(from_key)
if result.hit:
updated_users = result.entities

if updated_users is not None:
# Get the actual presence update for each change
Expand Down
6 changes: 4 additions & 2 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,10 +1528,12 @@ async def _generate_sync_entry_for_device_list(
#
# If we don't have that info cached then we get all the users that
# share a room with our user and check if those users have changed.
changed_users = self.store.get_cached_device_list_changes(
cache_result = self.store.get_cached_device_list_changes(
since_token.device_list_key
)
if changed_users is not None:
if cache_result.hit:
changed_users = cache_result.entities

result = await self.store.get_rooms_for_users(changed_users)

for changed_user_id, entries in result.items():
Expand Down
8 changes: 4 additions & 4 deletions synapse/handlers/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,11 @@ async def get_all_typing_updates(
if last_id == current_id:
return [], current_id, False

changed_rooms: Optional[
Iterable[str]
] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
result = self._typing_stream_change_cache.get_all_entities_changed(last_id)

if changed_rooms is None:
if result.hit:
changed_rooms: Iterable[str] = result.entities
else:
changed_rooms = self._room_serials

rows = []
Expand Down
86 changes: 46 additions & 40 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.stream_change_cache import (
AllEntitiesChangedResult,
StreamChangeCache,
)
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr
Expand Down Expand Up @@ -799,18 +802,41 @@ async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]
def get_cached_device_list_changes(
    self,
    from_key: int,
) -> AllEntitiesChangedResult:
    """Get the set of users whose devices have changed since `from_key`.

    Returns an `AllEntitiesChangedResult`; callers must check `result.hit`
    before reading `result.entities`, since the stream change cache may not
    cover positions as old as `from_key`.
    """

    return self._device_list_stream_cache.get_all_entities_changed(from_key)

@cancellable
async def get_all_devices_changed(
    self,
    from_key: int,
    to_key: int,
) -> Set[str]:
    """Get all users whose devices have changed in the given range.

    Args:
        from_key: The minimum device lists stream token to query device list
            changes for, exclusive.
        to_key: The maximum device lists stream token to query device list
            changes for, inclusive.

    Returns:
        The set of user_ids whose devices have changed since `from_key`
        (exclusive) until `to_key` (inclusive).
    """
    result = self._device_list_stream_cache.get_all_entities_changed(from_key)

    if result.hit:
        # The cache covers `from_key`, so we only need to check the users it
        # reports as possibly-changed against the database.
        return await self.get_users_whose_devices_changed(
            from_key, result.entities, to_key
        )

    # The cache cannot answer for a position this old; fall back to scanning
    # the full stream table for the range.
    sql = """
        SELECT DISTINCT user_id FROM device_lists_stream
        WHERE ? < stream_id AND stream_id <= ?
    """

    rows = await self.db_pool.execute(
        "get_all_devices_changed", None, sql, (from_key, to_key)
    )
    return {u for u, in rows}

@cancellable
async def get_users_whose_devices_changed(
self,
from_key: int,
user_ids: Optional[Collection[str]] = None,
user_ids: Collection[str],
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
Expand All @@ -830,52 +856,32 @@ async def get_users_whose_devices_changed(
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
user_ids_to_check: Optional[Collection[str]]
if user_ids is None:
# Get set of all users that have had device list changes since 'from_key'
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
from_key
)
else:
# The same as above, but filter results to only those users in 'user_ids'
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)

# If an empty set was returned, there's nothing to do.
if user_ids_to_check is not None and not user_ids_to_check:
if not user_ids_to_check:
return set()

def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]

if to_key:
stream_id_where_clause += " AND stream_id <= ?"
sql_args.append(to_key)
if to_key is None:
to_key = self._device_list_id_gen.get_current_token()

sql = f"""
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
sql = """
SELECT DISTINCT user_id FROM device_lists_stream
WHERE {stream_id_where_clause}
WHERE ? < stream_id AND stream_id <= ? AND %s
"""

# If the stream change cache gave us no information, fetch *all*
# users between the stream IDs.
if user_ids_to_check is None:
txn.execute(sql, sql_args)
return {user_id for user_id, in txn}
changes: Set[str] = set()

# Otherwise, fetch changes for the given users.
else:
changes: Set[str] = set()

# Query device changes with a batch of users at a time
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql + " AND " + clause, sql_args + args)
changes.update(user_id for user_id, in txn)
# Query device changes with a batch of users at a time
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql % (clause,), [from_key, to_key] + args)
changes.update(user_id for user_id, in txn)

return changes

Expand Down
42 changes: 33 additions & 9 deletions synapse/util/caches/stream_change_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import math
from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union

import attr
from sortedcontainers import SortedDict

from synapse.util import caches
Expand All @@ -26,6 +27,29 @@
EntityType = str


@attr.s(auto_attribs=True, frozen=True, slots=True)
class AllEntitiesChangedResult:
    """The result of a `get_all_entities_changed` cache query.

    Wraps an optional list of changed entities so that callers are forced to
    check whether the cache could answer at all (`hit`) before they read the
    entities. Deliberately provides no `__bool__` or similar shortcuts, so
    the two cases cannot be conflated by accident.
    """

    # None means the cache did not cover the requested stream position.
    _entities: Optional[List[EntityType]]

    @property
    def hit(self) -> bool:
        """Whether the cache was able to answer the query."""
        entities = self._entities
        return entities is not None

    @property
    def entities(self) -> List[EntityType]:
        """The changed entities. Only valid when `hit` is True."""
        entities = self._entities
        assert entities is not None
        return entities


class StreamChangeCache:
"""Keeps track of the stream positions of the latest change in a set of entities.

Expand Down Expand Up @@ -109,19 +133,19 @@ def get_entities_changed(
position. Entities unknown to the cache will be returned. If the
position is too old it will just return the given list.
"""
changed_entities = self.get_all_entities_changed(stream_pos)
if changed_entities is not None:
cache_result = self.get_all_entities_changed(stream_pos)
if cache_result.hit:
# We now do an intersection, trying to do so in the most efficient
# way possible (some of these sets are *large*). First check in the
# given iterable is already set that we can reuse, otherwise we
# create a set of the *smallest* of the two iterables and call
# `intersection(..)` on it (this can be twice as fast as the reverse).
if isinstance(entities, (set, frozenset)):
result = entities.intersection(changed_entities)
elif len(changed_entities) < len(entities):
result = set(changed_entities).intersection(entities)
result = entities.intersection(cache_result.entities)
elif len(cache_result.entities) < len(entities):
result = set(cache_result.entities).intersection(entities)
else:
result = set(entities).intersection(changed_entities)
result = set(entities).intersection(cache_result.entities)
self.metrics.inc_hits()
else:
result = set(entities)
Expand All @@ -144,7 +168,7 @@ def has_any_entity_changed(self, stream_pos: int) -> bool:
self.metrics.inc_misses()
return True

def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]:
def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult:
"""Returns all entities that have had new things since the given
position. If the position is too old it will return a result with
`hit` set to False.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -153,13 +177,13 @@ def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]
assert type(stream_pos) is int

if stream_pos < self._earliest_known_stream_pos:
return None
return AllEntitiesChangedResult(None)

changed_entities: List[EntityType] = []

for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)):
changed_entities.extend(self._cache[k])
return changed_entities
return AllEntitiesChangedResult(changed_entities)

def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None:
"""Informs the cache that the entity has been changed at the given
Expand Down
20 changes: 11 additions & 9 deletions tests/util/test_stream_change_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,21 @@ def test_entity_has_changed_pops_off_start(self):
self.assertTrue("[email protected]" not in cache._entity_to_key)

self.assertEqual(
cache.get_all_entities_changed(2),
cache.get_all_entities_changed(2).entities,
["[email protected]", "[email protected]"],
)
self.assertIsNone(cache.get_all_entities_changed(1))
self.assertFalse(cache.get_all_entities_changed(1).hit)

# If we update an existing entity, it keeps the two existing entities
cache.entity_has_changed("[email protected]", 5)
self.assertEqual(
{"[email protected]", "[email protected]"}, set(cache._entity_to_key)
)
self.assertEqual(
cache.get_all_entities_changed(2),
cache.get_all_entities_changed(2).entities,
["[email protected]", "[email protected]"],
)
self.assertIsNone(cache.get_all_entities_changed(1))
self.assertFalse(cache.get_all_entities_changed(1).hit)

def test_get_all_entities_changed(self):
"""
Expand Down Expand Up @@ -114,13 +114,15 @@ def test_get_all_entities_changed(self):
"[email protected]",
"[email protected]",
]
self.assertTrue(r == ok1 or r == ok2)
self.assertTrue(r.entities == ok1 or r.entities == ok2)

r = cache.get_all_entities_changed(2)
self.assertTrue(r == ok1[1:] or r == ok2[1:])
self.assertTrue(r.entities == ok1[1:] or r.entities == ok2[1:])

self.assertEqual(cache.get_all_entities_changed(3), ["[email protected]"])
self.assertEqual(cache.get_all_entities_changed(0), None)
self.assertEqual(
cache.get_all_entities_changed(3).entities, ["[email protected]"]
)
self.assertFalse(cache.get_all_entities_changed(0).hit)

# ... later, things get more updates
cache.entity_has_changed("[email protected]", 5)
Expand All @@ -140,7 +142,7 @@ def test_get_all_entities_changed(self):
"[email protected]",
]
r = cache.get_all_entities_changed(3)
self.assertTrue(r == ok1 or r == ok2)
self.assertTrue(r.entities == ok1 or r.entities == ok2)

def test_has_any_entity_changed(self):
"""
Expand Down