Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Return keys for unwhitelisted servers from /_matrix/key/v2/query #13683

Merged
merged 1 commit into from
Sep 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13683.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug which meant that keys for unwhitelisted servers were not returned by `/_matrix/key/v2/query`.
41 changes: 21 additions & 20 deletions synapse/rest/key/v2/remote_key_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,6 @@ async def query_keys(

store_queries = []
for server_name, key_ids in query.items():
if (
self.federation_domain_whitelist is not None
and server_name not in self.federation_domain_whitelist
):
logger.debug("Federation denied with %s", server_name)
continue

if not key_ids:
key_ids = (None,)
for key_id in key_ids:
Expand All @@ -153,21 +146,28 @@ async def query_keys(

time_now_ms = self.clock.time_msec()

# Note that the value is unused.
# Map server_name->key_id->int. Note that the value of the init is unused.
# XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), key_results in cached.items():
results = [(result["ts_added_ms"], result) for result in key_results]

if not results and key_id is not None:
cache_misses.setdefault(server_name, {})[key_id] = 0
if key_id is None:
# all keys were requested. Just return what we have without worrying
# about validity
for _, result in results:
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
continue

if key_id is not None:
miss = False
if not results:
miss = True
else:
ts_added_ms, most_recent_result = max(results)
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
miss = False
if req_valid_until is not None:
if ts_valid_until_ms < req_valid_until:
logger.debug(
Expand Down Expand Up @@ -211,19 +211,20 @@ async def query_keys(
ts_valid_until_ms,
time_now_ms,
)

if miss:
cache_misses.setdefault(server_name, {})[key_id] = 0
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
else:
for _, result in results:
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))

if miss and query_remote_on_cache_miss:
# only bother attempting to fetch keys from servers on our whitelist
if (
self.federation_domain_whitelist is None
or server_name in self.federation_domain_whitelist
):
cache_misses.setdefault(server_name, {})[key_id] = 0

# If there is a cache miss, request the missing keys, then recurse (and
# ensure the result is sent).
if cache_misses and query_remote_on_cache_miss:
if cache_misses:
await yieldable_gather_results(
lambda t: self.fetcher.get_keys(*t),
(
Expand Down