diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 6d7f2792dd0d..34e9796b8c6e 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -471,9 +471,10 @@ async def claim_client_keys( # Create the expected payload shape. body: Dict[str, Dict[str, List[str]]] = {} - for user_id, device, algorithm, _count in query: - # Note that only a single OTK can be claimed this way. - body.setdefault(user_id, {}).setdefault(device, []).append(algorithm) + for user_id, device, algorithm, count in query: + body.setdefault(user_id, {}).setdefault(device, []).extend( + [algorithm] * count + ) uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim" try: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c3e03aec8d96..dee8e957a4a4 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -237,7 +237,7 @@ async def query_user_devices( async def claim_client_keys( self, destination: str, - content: Dict[str, Dict[str, Dict[str, int]]], + query: Dict[str, Dict[str, Dict[str, int]]], timeout: Optional[int], ) -> JsonDict: """Claims one-time keys for a device hosted on a remote server. @@ -251,24 +251,33 @@ async def claim_client_keys( """ sent_queries_counter.labels("client_one_time_keys").inc() - # Convert the query with counts into a legacy query and check if attempting - # to claim more than 1 OTK. - legacy_content: Dict[str, Dict[str, str]] = {} + # Convert the query with counts into a stable and unstable query and check + # if attempting to claim more than 1 OTK. + content: Dict[str, Dict[str, str]] = {} + unstable_content: Dict[str, Dict[str, List[str]]] = {} use_unstable = False - for user_id, one_time_keys in content.items(): + for user_id, one_time_keys in query.items(): for device_id, algorithms in one_time_keys.items(): if any(count > 1 for count in algorithms.values()): use_unstable = True if algorithms: - # Choose the first algorithm only. - legacy_content.setdefault(user_id, {})[device_id] = next( - iter(algorithms) + # Choose the first algorithm only for the stable query. + content.setdefault(user_id, {})[device_id] = next(iter(algorithms)) + # Flatten the map of algorithm -> count to a list repeating + # each algorithm count times for the unstable query. + unstable_content.setdefault(user_id, {})[device_id] = list( + itertools.chain( + *( + itertools.repeat(algorithm, count) + for algorithm, count in algorithms.items() + ) + ) ) if use_unstable: try: return await self.transport_layer.claim_client_keys_unstable( - destination, content, timeout + destination, unstable_content, timeout ) except HttpResponseException as e: # If an error is received that is due to an unrecognised endpoint, @@ -284,7 +293,7 @@ async def claim_client_keys( logger.debug("Skipping unstable claim client keys API") return await self.transport_layer.claim_client_keys( - destination, legacy_content, timeout + destination, content, timeout ) @trace diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 5bf0629b7fa0..36b0362504f5 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from collections import Counter from typing import ( TYPE_CHECKING, Dict, @@ -577,7 +578,7 @@ class FederationClientKeysClaimServlet(BaseFederationServerServlet): async def on_POST( self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] ) -> Tuple[int, JsonDict]: - # Flatten the request query. + # Generate a count for each algorithm, which is hard-coded to 1. key_query: List[Tuple[str, str, str, int]] = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): @@ -603,11 +604,12 @@ class FederationUnstableClientKeysClaimServlet(BaseFederationServerServlet): async def on_POST( self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] ) -> Tuple[int, JsonDict]: - # Flatten the request query. + # Generate a count for each algorithm. key_query: List[Tuple[str, str, str, int]] = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithms in device_keys.items(): - for algorithm, count in algorithms.items(): + counts = Counter(algorithms) + for algorithm, count in counts.items(): key_query.append((user_id, device_id, algorithm, count)) response = await self.handler.on_claim_client_keys( diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index cceffde7dbdc..9bbab5e6241e 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -16,6 +16,7 @@ import logging import re +from collections import Counter from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from synapse.api.errors import InvalidAPICallError, SynapseError @@ -290,7 +291,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - # Map the legacy request to the new request format. + # Generate a count for each algorithm, which is hard-coded to 1. query: Dict[str, Dict[str, Dict[str, int]]] = {} for user_id, one_time_keys in body.get("one_time_keys", {}).items(): for device_id, algorithm in one_time_keys.items(): @@ -312,9 +313,8 @@ class UnstableOneTimeKeyServlet(RestServlet): { "one_time_keys": { "": { - "": { - "": - } } } } + "": ["", ...] + } } } HTTP/1.1 200 OK { @@ -338,7 +338,13 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - query = body.get("one_time_keys", {}) + + # Generate a count for each algorithm. + query: Dict[str, Dict[str, Dict[str, int]]] = {} + for user_id, one_time_keys in body.get("one_time_keys", {}).items(): + for device_id, algorithms in one_time_keys.items(): + query.setdefault(user_id, {})[device_id] = Counter(algorithms) + result = await self.e2e_keys_handler.claim_one_time_keys( query, timeout, always_include_fallback_keys=True )