Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add keys query scheduler #1676

Merged
merged 2 commits into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions MatrixSDK.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -1995,6 +1995,10 @@
EDBCF33A281A8D3D00ED5044 /* MXSharedHistoryKeyService.m in Sources */ = {isa = PBXBuildFile; fileRef = EDBCF338281A8D3D00ED5044 /* MXSharedHistoryKeyService.m */; };
EDC2A0E628369E740039F3D6 /* CryptoTests.xctestplan in Resources */ = {isa = PBXBuildFile; fileRef = EDC2A0E528369E740039F3D6 /* CryptoTests.xctestplan */; };
EDC2A0E728369E740039F3D6 /* CryptoTests.xctestplan in Resources */ = {isa = PBXBuildFile; fileRef = EDC2A0E528369E740039F3D6 /* CryptoTests.xctestplan */; };
EDC8C4082968A993003792C5 /* MXKeysQueryScheduler.swift in Sources */ = {isa = PBXBuildFile; fileRef = EDC8C4072968A993003792C5 /* MXKeysQueryScheduler.swift */; };
EDC8C4092968A993003792C5 /* MXKeysQueryScheduler.swift in Sources */ = {isa = PBXBuildFile; fileRef = EDC8C4072968A993003792C5 /* MXKeysQueryScheduler.swift */; };
EDC8C40D2968C37E003792C5 /* MXKeysQuerySchedulerUnitTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EDC8C40A2968A9F7003792C5 /* MXKeysQuerySchedulerUnitTests.swift */; };
EDC8C40E2968C37F003792C5 /* MXKeysQuerySchedulerUnitTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EDC8C40A2968A9F7003792C5 /* MXKeysQuerySchedulerUnitTests.swift */; };
EDCB65E22912AB0C00F55D4D /* MXRoomEventDecryption.swift in Sources */ = {isa = PBXBuildFile; fileRef = EDCB65E12912AB0C00F55D4D /* MXRoomEventDecryption.swift */; };
EDCB65E32912AB0C00F55D4D /* MXRoomEventDecryption.swift in Sources */ = {isa = PBXBuildFile; fileRef = EDCB65E12912AB0C00F55D4D /* MXRoomEventDecryption.swift */; };
EDD4197E28DCAA5F007F3757 /* MXNativeKeyBackupEngine.h in Headers */ = {isa = PBXBuildFile; fileRef = EDD4197D28DCAA5F007F3757 /* MXNativeKeyBackupEngine.h */; };
Expand Down Expand Up @@ -3120,6 +3124,8 @@
EDBCF335281A8AB900ED5044 /* MXSharedHistoryKeyService.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = MXSharedHistoryKeyService.h; sourceTree = "<group>"; };
EDBCF338281A8D3D00ED5044 /* MXSharedHistoryKeyService.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = MXSharedHistoryKeyService.m; sourceTree = "<group>"; };
EDC2A0E528369E740039F3D6 /* CryptoTests.xctestplan */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = CryptoTests.xctestplan; sourceTree = "<group>"; };
EDC8C4072968A993003792C5 /* MXKeysQueryScheduler.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MXKeysQueryScheduler.swift; sourceTree = "<group>"; };
EDC8C40A2968A9F7003792C5 /* MXKeysQuerySchedulerUnitTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MXKeysQuerySchedulerUnitTests.swift; sourceTree = "<group>"; };
EDCB65E12912AB0C00F55D4D /* MXRoomEventDecryption.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MXRoomEventDecryption.swift; sourceTree = "<group>"; };
EDD4197D28DCAA5F007F3757 /* MXNativeKeyBackupEngine.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = MXNativeKeyBackupEngine.h; sourceTree = "<group>"; };
EDD4198028DCAA7B007F3757 /* MXNativeKeyBackupEngine.m */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.objc; path = MXNativeKeyBackupEngine.m; sourceTree = "<group>"; };
Expand Down Expand Up @@ -5395,6 +5401,7 @@
ED8F1D3A2885BB2D00F897E7 /* MXCryptoProtocols.swift */,
ED2DD112286C450600F06731 /* MXEventDecryptionResult+DecryptedEvent.swift */,
ED2DD113286C450600F06731 /* MXCryptoRequests.swift */,
EDC8C4072968A993003792C5 /* MXKeysQueryScheduler.swift */,
);
path = CryptoMachine;
sourceTree = "<group>";
Expand All @@ -5407,6 +5414,7 @@
ED2DD11B286C4F3E00F06731 /* MXCryptoRequestsUnitTests.swift */,
ED8F1D312885AC5700F897E7 /* Device+Stub.swift */,
ED1FE90A2912E13A0046F722 /* DecryptedEvent+Stub.swift */,
EDC8C40A2968A9F7003792C5 /* MXKeysQuerySchedulerUnitTests.swift */,
);
path = CryptoMachine;
sourceTree = "<group>";
Expand Down Expand Up @@ -7001,6 +7009,7 @@
ECD2897D26E8F06F00F268CF /* MXStoreRoomListDataFetcher.swift in Sources */,
32D2CC0323422462002BD8CA /* MX3PidAddSession.m in Sources */,
3283F7791EAF30F700C1688C /* MXBugReportRestClient.m in Sources */,
EDC8C4082968A993003792C5 /* MXKeysQueryScheduler.swift in Sources */,
32B0E33B23A2989A0054FF1A /* MXEventReferenceChunk.m in Sources */,
A780625027B2CE74005780C0 /* FileManager+AppGroupContainer.swift in Sources */,
9274AFE91EE580240009BEB6 /* MXCallKitAdapter.m in Sources */,
Expand Down Expand Up @@ -7253,6 +7262,7 @@
322A51D81D9E846800C8536D /* MXCryptoTests.m in Sources */,
B146D4FF21A5C0BD00D8C2C6 /* MXMediaScanStoreUnitTests.m in Sources */,
32BD34BE1E84134A006EDC0D /* MatrixSDKTestsE2EData.m in Sources */,
EDC8C40E2968C37F003792C5 /* MXKeysQuerySchedulerUnitTests.swift in Sources */,
ED751DAE28EDEC7E003748C3 /* MXKeyVerificationStateResolverUnitTests.swift in Sources */,
ED6E87A9294B3BAB00100D9C /* MXAnalyticsDestinationUnitTests.swift in Sources */,
B146D4FE21A5C0BD00D8C2C6 /* MXEventScanStoreUnitTests.m in Sources */,
Expand Down Expand Up @@ -7650,6 +7660,7 @@
EC1165B727107E330089FA56 /* MXStoreRoomListDataCounts.swift in Sources */,
B14EF2392397E90400758AF0 /* MXMediaManager.m in Sources */,
B14EF23A2397E90400758AF0 /* MXHTTPOperation.m in Sources */,
EDC8C4092968A993003792C5 /* MXKeysQueryScheduler.swift in Sources */,
A780625127B2CE74005780C0 /* FileManager+AppGroupContainer.swift in Sources */,
B14EF23B2397E90400758AF0 /* MXKeyBackupData.m in Sources */,
B14EF23C2397E90400758AF0 /* MXJSONModels.m in Sources */,
Expand Down Expand Up @@ -7902,6 +7913,7 @@
32B090FE26201C8D002924AA /* MXAsyncTaskQueueUnitTests.swift in Sources */,
B1E09A242397FCE90057C069 /* MXPeekingRoomTests.m in Sources */,
B1E09A452397FD990057C069 /* MXLazyLoadingTests.m in Sources */,
EDC8C40D2968C37E003792C5 /* MXKeysQuerySchedulerUnitTests.swift in Sources */,
ED751DAF28EDEC7E003748C3 /* MXKeyVerificationStateResolverUnitTests.swift in Sources */,
ED6E87AA294B3BAB00100D9C /* MXAnalyticsDestinationUnitTests.swift in Sources */,
B1E09A1C2397FCE90057C069 /* MXEventAnnotationUnitTests.swift in Sources */,
Expand Down
2 changes: 1 addition & 1 deletion MatrixSDK/Crypto/CrossSigning/MXCrossSigningV2.swift
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class MXCrossSigningV2: NSObject, MXCrossSigning {

Task {
do {
try await crossSigning.downloadKeys(users: [crossSigning.userId])
try await crossSigning.refreshCrossSigningStatus()
myUserCrossSigningKeys = infoSource.crossSigningInfo(userId: crossSigning.userId)

log.debug("Cross signing state refreshed")
Expand Down
83 changes: 45 additions & 38 deletions MatrixSDK/Crypto/CryptoMachine/MXCryptoMachine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class MXCryptoMachine {

private let machine: OlmMachine
private let requests: MXCryptoRequests
private let queryScheduler: MXKeysQueryScheduler<MXKeysQueryResponse>
private let getRoomAction: GetRoomAction

private let sessionsQueue = MXTaskQueue()
Expand All @@ -72,15 +73,25 @@ class MXCryptoMachine {

private let log = MXNamedLog(name: "MXCryptoMachine")

init(userId: String, deviceId: String, restClient: MXRestClient, getRoomAction: @escaping GetRoomAction) throws {
init(
userId: String,
deviceId: String,
restClient: MXRestClient,
getRoomAction: @escaping GetRoomAction
) throws {
let url = try Self.storeURL(for: userId)
machine = try OlmMachine(
userId: userId,
deviceId: deviceId,
path: url.path,
passphrase: nil
)
requests = MXCryptoRequests(restClient: restClient)
let requests = MXCryptoRequests(restClient: restClient)
self.requests = requests

queryScheduler = MXKeysQueryScheduler { users in
try await requests.queryKeys(users: users)
}
self.getRoomAction = getRoomAction

setLogger(logger: self)
Expand Down Expand Up @@ -211,6 +222,31 @@ extension MXCryptoMachine: MXCryptoSyncing {
return toDevice
}

func downloadKeysIfNecessary(users: [String]) async throws {
machine.updateTrackedUsers(users: users)
try await withThrowingTaskGroup(of: Void.self) { [weak self] group in
guard let self = self else { return }

for request in try machine.outgoingRequests() {
if case .keysQuery(_, let requestUsers) = request {
let usersInCommon = Set(requestUsers).intersection(users)
if !usersInCommon.isEmpty {
try await self.handleRequest(request)
return
}
}
}
}
}

@available(*, deprecated, message: "The application should not manually force reload keys, use `downloadKeysIfNecessary` instead")
func reloadKeys(users: [String]) async throws {
machine.updateTrackedUsers(users: users)
try await handleRequest(
.keysQuery(requestId: UUID().uuidString, users: users)
)
}

func processOutgoingRequests() async throws {
try await syncQueue.sync { [weak self] in
try await self?.handleOutgoingRequests()
Expand All @@ -234,7 +270,8 @@ extension MXCryptoMachine: MXCryptoSyncing {
try markRequestAsSent(requestId: requestId, requestType: .keysUpload, response: response.jsonString())

case .keysQuery(let requestId, let users):
let response = try await requests.queryKeys(users: users)
// Key queries go through a scheduler layer instead of directly through the rest client
let response = try await queryScheduler.query(users: Set(users))
try markRequestAsSent(requestId: requestId, requestType: .keysQuery, response: response.jsonString())

case .keysClaim(let requestId, let oneTimeKeys):
Expand Down Expand Up @@ -341,21 +378,6 @@ extension MXCryptoMachine: MXCryptoUserIdentitySource {
}
}

func isUserTracked(userId: String) -> Bool {
do {
return try machine.isUserTracked(userId: userId)
} catch {
log.error("Failed checking user tracking")
return false
}
}

func downloadKeys(users: [String]) async throws {
try await handleRequest(
.keysQuery(requestId: UUID().uuidString, users: users)
)
}

func verifyUser(userId: String) async throws {
let request = try machine.verifyIdentity(userId: userId)
try await requests.uploadSignatures(request: request)
Expand All @@ -378,7 +400,7 @@ extension MXCryptoMachine: MXCryptoRoomEventEncrypting {
settings: EncryptionSettings
) async throws {
try await sessionsQueue.sync { [weak self] in
try await self?.updateTrackedUsers(users: users)
try await self?.downloadKeysIfNecessary(users: users)
try await self?.getMissingSessions(users: users)
}

Expand Down Expand Up @@ -411,25 +433,6 @@ extension MXCryptoMachine: MXCryptoRoomEventEncrypting {

// MARK: - Private

private func updateTrackedUsers(users: [String]) async throws {
machine.updateTrackedUsers(users: users)
try await withThrowingTaskGroup(of: Void.self) { [weak self] group in
guard let self = self else { return }

for request in try machine.outgoingRequests() {
guard case .keysQuery = request else {
continue
}

group.addTask {
try await self.handleRequest(request)
}
}

try await group.waitForAll()
}
}

private func getMissingSessions(users: [String]) async throws {
guard
let request = try machine.getMissingSessions(users: users),
Expand Down Expand Up @@ -484,6 +487,10 @@ extension MXCryptoMachine: MXCryptoRoomEventDecrypting {
}

extension MXCryptoMachine: MXCryptoCrossSigning {
func refreshCrossSigningStatus() async throws {
try await reloadKeys(users: [userId])
}

func crossSigningStatus() -> CrossSigningStatus {
return machine.crossSigningStatus()
}
Expand Down
8 changes: 6 additions & 2 deletions MatrixSDK/Crypto/CryptoMachine/MXCryptoProtocols.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ protocol MXCryptoSyncing: MXCryptoIdentity {
) throws -> MXToDeviceSyncResponse

func processOutgoingRequests() async throws

func downloadKeysIfNecessary(users: [String]) async throws

@available(*, deprecated, message: "The application should not manually force reload keys, use `downloadKeysIfNecessary` instead")
func reloadKeys(users: [String]) async throws
}

/// Source of user devices and their cryptographic trust status
Expand All @@ -52,8 +57,6 @@ protocol MXCryptoDevicesSource: MXCryptoIdentity {
protocol MXCryptoUserIdentitySource: MXCryptoIdentity {
func userIdentity(userId: String) -> UserIdentity?
func isUserVerified(userId: String) -> Bool
func isUserTracked(userId: String) -> Bool
func downloadKeys(users: [String]) async throws
func verifyUser(userId: String) async throws
func verifyDevice(userId: String, deviceId: String) async throws
func setLocalTrust(userId: String, deviceId: String, trust: LocalTrust) throws
Expand All @@ -74,6 +77,7 @@ protocol MXCryptoRoomEventDecrypting: MXCryptoIdentity {

/// Cross-signing functionality
protocol MXCryptoCrossSigning: MXCryptoUserIdentitySource {
func refreshCrossSigningStatus() async throws
func crossSigningStatus() -> CrossSigningStatus
func bootstrapCrossSigning(authParams: [AnyHashable: Any]) async throws
func exportCrossSigningKeys() -> CrossSigningKeyExport?
Expand Down
11 changes: 9 additions & 2 deletions MatrixSDK/Crypto/CryptoMachine/MXCryptoRequests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,15 @@ struct MXCryptoRequests {
}

func queryKeys(users: [String]) async throws -> MXKeysQueryResponse {
return try await performCallbackRequest {
restClient.downloadKeys(forUsers: users, completion: $0)
return try await performCallbackRequest { completion in
_ = restClient.downloadKeysByChunk(
forUsers: users,
token: nil,
success: {
completion(.success($0))
}, failure: {
completion(.failure($0 ?? Error.unknownError))
})
}
}

Expand Down
112 changes: 112 additions & 0 deletions MatrixSDK/Crypto/CryptoMachine/MXKeysQueryScheduler.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
//
// Copyright 2023 The Matrix.org Foundation C.I.C
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

import Foundation

/// A schedule of `keys/query` requests that will ensure only one request
/// is in-flight at any given point in time, and all future queries aggregate
/// requested user ids into a single query.
public actor MXKeysQueryScheduler<Response> {
typealias QueryAction = ([String]) async throws -> Response

struct Query {
let users: Set<String>
let task: Task<Response, Error>

func contains(users: Set<String>) -> Bool {
users.subtracting(self.users).isEmpty
}
}

private let queryAction: QueryAction
private var nextUsers: Set<String>

private var currentQuery: Query?
private var nextTask: Task<Response, Error>?

init(queryAction: @escaping QueryAction) {
self.queryAction = queryAction
self.nextUsers = []
}

/// Query a list of user ids
///
/// If there is no ongoing query, it will be executed right away,
/// otherwise it will be scheduled for the next available run.
public func query(users: Set<String>) async throws -> Response {
log("Querying \(users.count) user(s) ...")

let task = currentOrNextQuery(users: users)
return try await task.value
}

private func currentOrNextQuery(users: Set<String>) -> Task<Response, Error> {
if let currentQuery = currentQuery {
if currentQuery.contains(users: users) {
log("... query already running")

return currentQuery.task

} else {
log("... queueing users for the next query")

nextUsers = nextUsers.union(users)

let task = nextTask ?? Task {
// Next task needs to await to completion of the currently running task
let _ = await currentQuery.task.result

// Extract and reset next users
let users = nextUsers
nextUsers = []

// Only then we can execute the actual work
return try await executeQuery(users: users)
}
nextTask = task
return task
}

} else {
log("... query starting")

let task = Task {
// Since we do not have any task running we can execute work right away
try await executeQuery(users: users)
}
currentQuery = .init(users: users, task: task)
return task
}
}

private func executeQuery(users: Set<String>) async throws -> Response {
defer {
if let next = nextTask {
log("... query completed, starting next pending query.")
currentQuery = .init(users: users, task: next)
} else {
log("... query completed, no other queries scheduled.")
currentQuery = nil
}
nextTask = nil
}
return try await queryAction(Array(users))
}

private func log(_ message: String) {
MXLog.debug("[MXKeysQueryScheduler]: \(message)")
}
}
Loading