Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Rewrite get_server_verify_keys, again. #5299

Merged
merged 2 commits into from
May 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/5299.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Rewrite get_server_verify_keys, again.
101 changes: 53 additions & 48 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,59 +280,21 @@ def _get_server_verify_keys(self, verify_requests):
verify_requests (list[VerifyKeyRequest]): list of verify requests
"""

remaining_requests = set(
(rq for rq in verify_requests if not rq.deferred.called)
)

@defer.inlineCallbacks
def do_iterations():
with Measure(self.clock, "get_server_verify_keys"):
# dict[str, set(str)]: keys to fetch for each server
missing_keys = {}
for verify_request in verify_requests:
missing_keys.setdefault(verify_request.server_name, set()).update(
verify_request.key_ids
)

for f in self._key_fetchers:
results = yield f.get_keys(missing_keys.items())

# We now need to figure out which verify requests we have keys
# for and which we don't
missing_keys = {}
requests_missing_keys = []
for verify_request in verify_requests:
if verify_request.deferred.called:
# We've already called this deferred, which probably
# means that we've already found a key for it.
continue

server_name = verify_request.server_name

# see if any of the keys we got this time are sufficient to
# complete this VerifyKeyRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
fetch_key_result = result_keys.get(key_id)
if fetch_key_result:
with PreserveLoggingContext():
verify_request.deferred.callback(
(
server_name,
key_id,
fetch_key_result.verify_key,
)
)
break
else:
# The else block is only reached if the loop above
# doesn't break.
missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)

if not missing_keys:
break
if not remaining_requests:
return
yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)

# look for any requests which weren't satisfied
with PreserveLoggingContext():
for verify_request in requests_missing_keys:
for verify_request in remaining_requests:
verify_request.deferred.errback(
SynapseError(
401,
Expand All @@ -343,13 +305,56 @@ def do_iterations():
)

def on_err(err):
# we don't really expect to get here, because any errors should already
# have been caught and logged. But if we do, let's log the error and make
# sure that all of the deferreds are resolved.
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
with PreserveLoggingContext():
for verify_request in verify_requests:
for verify_request in remaining_requests:
if not verify_request.deferred.called:
verify_request.deferred.errback(err)

run_in_background(do_iterations).addErrback(on_err)

@defer.inlineCallbacks
def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
"""Use a key fetcher to attempt to satisfy some key requests

Args:
fetcher (KeyFetcher): fetcher to use to fetch the keys
remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
Any successfully-completed requests will be reomved from the list.
"""
# dict[str, set(str)]: keys to fetch for each server
missing_keys = {}
for verify_request in remaining_requests:
# any completed requests should already have been removed
assert not verify_request.deferred.called
missing_keys.setdefault(verify_request.server_name, set()).update(
verify_request.key_ids
)

results = yield fetcher.get_keys(missing_keys.items())

completed = list()
for verify_request in remaining_requests:
server_name = verify_request.server_name

# see if any of the keys we got this time are sufficient to
# complete this VerifyKeyRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
key = result_keys.get(key_id)
if key:
with PreserveLoggingContext():
verify_request.deferred.callback(
(server_name, key_id, key.verify_key)
)
completed.append(verify_request)
break

remaining_requests.difference_update(completed)


class KeyFetcher(object):
def get_keys(self, server_name_and_key_ids):
Expand Down