diff --git a/Sources/PIRProcessDatabase/ProcessDatabase.swift b/Sources/PIRProcessDatabase/ProcessDatabase.swift index 605b9b35..3cedfe48 100644 --- a/Sources/PIRProcessDatabase/ProcessDatabase.swift +++ b/Sources/PIRProcessDatabase/ProcessDatabase.swift @@ -297,7 +297,7 @@ struct ResolvedArguments: CustomStringConvertible, Encodable { } @main -struct ProcessDatabase: ParsableCommand { +struct ProcessDatabase: AsyncParsableCommand { static let configuration: CommandConfiguration = .init( commandName: "PIRProcessDatabase") @@ -311,13 +311,18 @@ struct ProcessDatabase: ParsableCommand { """) var configFile: String + @Flag(name: .customLong("parallel"), + inversion: .prefixedNo, + help: "Enables parallel processing.") + var parallel = false + /// Performs the processing on the given database. /// - Parameters: /// - config: The configuration for the PIR processing. /// - scheme: The HE scheme. /// - Throws: Error upon processing the database. @inlinable - mutating func process(config: Arguments, scheme: Scheme.Type) throws { + mutating func process(config: Arguments, scheme: Scheme.Type) async throws { let database: [KeywordValuePair] = try Apple_SwiftHomomorphicEncryption_Pir_V1_KeywordDatabase(from: config.inputDatabase).native() @@ -339,69 +344,39 @@ struct ProcessDatabase: ParsableCommand { keyCompression: config.keyCompression, trialsPerShard: config.trialsPerShard) - var evaluationKeyConfig = EvaluationKeyConfig() let context = try Context(encryptionParameters: processArgs.encryptionParameters) let keywordDatabase = try KeywordDatabase(rows: database, sharding: processArgs.databaseConfig.sharding) - ProcessDatabase.logger - .info("Sharded database into \(keywordDatabase.shards.count) shards") - for (shardID, shard) in keywordDatabase.shards - .sorted(by: { $0.0.localizedStandardCompare($1.0) == .orderedAscending }) - { - func logEvent(event: ProcessKeywordDatabase.ProcessShardEvent) throws { - switch event { - case let .cuckooTableEvent(.createdTable(table)): - let summary = try table.summarize() - ProcessDatabase.logger.info("Created cuckoo table \(summary)") - case let .cuckooTableEvent(.expandingTable(table)): - let summary = try table.summarize() - ProcessDatabase.logger.info("Expanding cuckoo table \(summary)") - case let .cuckooTableEvent(.finishedExpandingTable(table)): - let summary = try table.summarize() - ProcessDatabase.logger.info("Finished expanding cuckoo table \(summary)") - case let .cuckooTableEvent(.insertedKeywordValuePair(index, _)): - let reportingPercentage = 10 - let shardFraction = shard.rows.count / reportingPercentage - if (index + 1).isMultiple(of: shardFraction) { - let percentage = Float(reportingPercentage * (index + 1)) / Float(shardFraction) - ProcessDatabase.logger - .info("Inserted \(index + 1) / \(shard.rows.count) keywords \(percentage)%") + ProcessDatabase.logger.info("Sharded database into \(keywordDatabase.shards.count) shards") + + let shards = keywordDatabase.shards.sorted { $0.0.localizedStandardCompare($1.0) == .orderedAscending } + + var evaluationKeyConfig = EvaluationKeyConfig() + if parallel { + try await withThrowingTaskGroup(of: EvaluationKeyConfig.self) { group in + for (shardID, shard) in shards { + group.addTask { @Sendable [self] in + try await processShard( + shardID: shardID, + shard: shard, + config: config, + context: context, + processArgs: processArgs) } } - } - ProcessDatabase.logger.info("Processing shard \(shardID) with \(shard.rows.count) rows") - let processed = try ProcessKeywordDatabase.processShard( - shard: shard, - with: processArgs, - onEvent: logEvent) - if config.trialsPerShard > 0 { - guard let row = shard.rows.first else { - throw PirError.emptyDatabase + for try await processedEvaluationKeyConfig in group { + evaluationKeyConfig = [evaluationKeyConfig, processedEvaluationKeyConfig].union() } - ProcessDatabase.logger.info("Validating shard \(shardID)") - let validationResults = try ProcessKeywordDatabase - .validateShard(shard: processed, - row: KeywordValuePair(keyword: row.key, value: row.value), - trials: config.trialsPerShard, context: context) - let description = try validationResults.description() - ProcessDatabase.logger.info("ValidationResults \(description)") } - - let outputDatabaseFilename = config.outputDatabase.replacingOccurrences( - of: "SHARD_ID", - with: String(shardID)) - try processed.database.save(to: outputDatabaseFilename) - ProcessDatabase.logger.info("Saved shard \(shardID) to \(outputDatabaseFilename)") - - let shardEvaluationKeyConfig = processed.evaluationKeyConfig - evaluationKeyConfig = [evaluationKeyConfig, shardEvaluationKeyConfig].union() - - let shardPirParameters = try processed.proto(context: context) - let outputParametersFilename = config.outputPirParameters.replacingOccurrences( - of: "SHARD_ID", - with: String(shardID)) - try shardPirParameters.save(to: outputParametersFilename) - ProcessDatabase.logger.info("Saved shard \(shardID) PIR parameters to \(outputParametersFilename)") + } else { + for (shardID, shard) in shards { + let processedEvaluationKeyConfig = try await processShard( + shardID: shardID, + shard: shard, config: + config, context: context, + processArgs: processArgs) + evaluationKeyConfig = [evaluationKeyConfig, processedEvaluationKeyConfig].union() + } } if let evaluationKeyConfigFile = config.outputEvaluationKeyConfig { @@ -411,14 +386,80 @@ struct ProcessDatabase: ParsableCommand { } } - mutating func run() throws { + private func processShard( + shardID: String, + shard: KeywordDatabaseShard, + config: ResolvedArguments, + context: Context, + processArgs: ProcessKeywordDatabase.Arguments) async throws -> EvaluationKeyConfig + { + var logger = ProcessDatabase.logger + logger[metadataKey: "shardID"] = .string(shardID) + + func logEvent(event: ProcessKeywordDatabase.ProcessShardEvent) throws { + switch event { + case let .cuckooTableEvent(.createdTable(table)): + let summary = try table.summarize() + logger.info("Created cuckoo table \(summary)") + case let .cuckooTableEvent(.expandingTable(table)): + let summary = try table.summarize() + logger.info("Expanding cuckoo table \(summary)") + case let .cuckooTableEvent(.finishedExpandingTable(table)): + let summary = try table.summarize() + logger.info("Finished expanding cuckoo table \(summary)") + case let .cuckooTableEvent(.insertedKeywordValuePair(index, _)): + let reportingPercentage = 10 + let shardFraction = shard.rows.count / reportingPercentage + if (index + 1).isMultiple(of: shardFraction) { + let percentage = Float(reportingPercentage * (index + 1)) / Float(shardFraction) + logger.info("Inserted \(index + 1) / \(shard.rows.count) keywords \(percentage)%") + } + } + } + + logger.info("Processing shard \(shardID) with \(shard.rows.count) rows") + let processed = try ProcessKeywordDatabase.processShard( + shard: shard, + with: processArgs, + onEvent: logEvent) + + if config.trialsPerShard > 0 { + guard let row = shard.rows.first else { + throw PirError.emptyDatabase + } + logger.info("Validating shard \(shardID)") + let validationResults = try ProcessKeywordDatabase + .validateShard(shard: processed, + row: KeywordValuePair(keyword: row.key, value: row.value), + trials: config.trialsPerShard, context: context) + let description = try validationResults.description() + logger.info("ValidationResults \(description)") + } + + let outputDatabaseFilename = config.outputDatabase.replacingOccurrences( + of: "SHARD_ID", + with: String(shardID)) + try processed.database.save(to: outputDatabaseFilename) + logger.info("Saved shard \(shardID) to \(outputDatabaseFilename)") + + let shardPirParameters = try processed.proto(context: context) + let outputParametersFilename = config.outputPirParameters.replacingOccurrences( + of: "SHARD_ID", + with: String(shardID)) + try shardPirParameters.save(to: outputParametersFilename) + logger.info("Saved shard \(shardID) PIR parameters to \(outputParametersFilename)") + + return processed.evaluationKeyConfig + } + + mutating func run() async throws { let configURL = URL(fileURLWithPath: configFile) let configData = try Data(contentsOf: configURL) let config = try JSONDecoder().decode(Arguments.self, from: configData) if config.rlweParameters.supportsScalar(UInt32.self) { - try process(config: config, scheme: Bfv.self) + try await process(config: config, scheme: Bfv.self) } else { - try process(config: config, scheme: Bfv.self) + try await process(config: config, scheme: Bfv.self) } } } diff --git a/Sources/PrivateInformationRetrieval/KeywordDatabase.swift b/Sources/PrivateInformationRetrieval/KeywordDatabase.swift index cf5a80a3..9f7a4b4f 100644 --- a/Sources/PrivateInformationRetrieval/KeywordDatabase.swift +++ b/Sources/PrivateInformationRetrieval/KeywordDatabase.swift @@ -148,7 +148,7 @@ extension Sharding { } /// A shard of a ``KeywordDatabase``. -public struct KeywordDatabaseShard: Hashable, Codable { +public struct KeywordDatabaseShard: Hashable, Codable, Sendable { /// Identifier for the shard. public let shardID: String /// Rows in the database. @@ -204,7 +204,7 @@ extension KeywordDatabaseShard: Collection { } /// Configuration for a ``KeywordDatabase``. -public struct KeywordDatabaseConfig: Hashable, Codable { +public struct KeywordDatabaseConfig: Hashable, Codable, Sendable { public let sharding: Sharding public let keywordPirConfig: KeywordPirConfig @@ -264,7 +264,7 @@ public struct KeywordDatabase { /// Utilities for processing a ``KeywordDatabase``. public enum ProcessKeywordDatabase { /// Arguments for processing a keyword database. - public struct Arguments: Codable { + public struct Arguments: Codable, Sendable { /// Database configuration. public let databaseConfig: KeywordDatabaseConfig /// Encryption parameters. diff --git a/Sources/PrivateInformationRetrieval/KeywordPirProtocol.swift b/Sources/PrivateInformationRetrieval/KeywordPirProtocol.swift index bb61c2e4..6d3db0d1 100644 --- a/Sources/PrivateInformationRetrieval/KeywordPirProtocol.swift +++ b/Sources/PrivateInformationRetrieval/KeywordPirProtocol.swift @@ -16,7 +16,7 @@ import Foundation import HomomorphicEncryption /// Configuration for a ``KeywordDatabase``. -public struct KeywordPirConfig: Hashable, Codable { +public struct KeywordPirConfig: Hashable, Codable, Sendable { /// Number of dimensions in the database. @usableFromInline let dimensionCount: Int