diff --git a/Sources/PrivateInformationRetrieval/MulPir.swift b/Sources/PrivateInformationRetrieval/MulPir.swift index 8c899769..c9819bed 100644 --- a/Sources/PrivateInformationRetrieval/MulPir.swift +++ b/Sources/PrivateInformationRetrieval/MulPir.swift @@ -77,7 +77,7 @@ public enum MulPir: IndexPirProtocol { static func evaluationKeyConfiguration( expandedQueryCount: Int, degree: Int, - keyCompression: PirKeyCompressionStrategy) -> HomomorphicEncryption.EvaluationKeyConfiguration + keyCompression: PirKeyCompressionStrategy) -> EvaluationKeyConfiguration { let maxExpansionDepth = min(expandedQueryCount, degree).ceilLog2 let smallestPower = degree.log2 - maxExpansionDepth + 1 diff --git a/Sources/PrivateNearestNeighborsSearch/CiphertextMatrix.swift b/Sources/PrivateNearestNeighborsSearch/CiphertextMatrix.swift index 7c28fc00..7d50cd63 100644 --- a/Sources/PrivateNearestNeighborsSearch/CiphertextMatrix.swift +++ b/Sources/PrivateNearestNeighborsSearch/CiphertextMatrix.swift @@ -17,7 +17,7 @@ import HomomorphicEncryption /// Stores a matrix of scalars as ciphertexts. public struct CiphertextMatrix: Equatable, Sendable { /// Dimensions of the matrix. - @usableFromInline let dimensions: MatrixDimensions + @usableFromInline var dimensions: MatrixDimensions /// Dimensions of the scalar matrix in a SIMD-encoded plaintext. @usableFromInline let simdDimensions: SimdEncodingDimensions @@ -26,7 +26,7 @@ public struct CiphertextMatrix: Equatable, @usableFromInline let packing: MatrixPacking /// Encrypted data. - @usableFromInline let ciphertexts: [Ciphertext] + @usableFromInline package var ciphertexts: [Ciphertext] /// The parameter context. @usableFromInline var context: Context { @@ -111,8 +111,8 @@ extension CiphertextMatrix { ciphertexts: evalCiphertexts) } - /// Converts the plaintext matrix to ``Coeff`` format. - /// - Returns: The converted plaintext ciphertext. + /// Converts the ciphertext matrix to ``Coeff`` format. + /// - Returns: The converted ciphertext matrix. /// - Throws: Error upon failure to convert the ciphertext matrix. @inlinable public func convertToCoeffFormat() throws -> CiphertextMatrix { @@ -126,6 +126,33 @@ extension CiphertextMatrix { packing: packing, ciphertexts: coeffCiphertexts) } + + /// Converts the ciphertext matrix to canonical format. + /// - Returns: The converted ciphertext matrix. + /// - Throws: Error upon failure to convert the ciphertext matrix. + @inlinable + public func convertToCanonicalFormat() throws -> CiphertextMatrix { + if Scheme.CanonicalCiphertextFormat.self == Coeff.self { + // swiftlint:disable:next force_cast + return try convertToCoeffFormat() as! CiphertextMatrix + } + if Scheme.CanonicalCiphertextFormat.self == Eval.self { + // swiftlint:disable:next force_cast + return try convertToEvalFormat() as! CiphertextMatrix + } + fatalError("Unsupported Format \(Format.description)") + } + + /// Performs modulus switching to a single modulus. + /// + /// If the ciphertexts already have a single modulus, this is a no-op. + /// - Throws: Error upon failure to modulus switch. + @inlinable + public mutating func modSwitchDownToSingle() throws where Format == Scheme.CanonicalCiphertextFormat { + for index in 0.. Double { + try ciphertexts.map { ciphertext in + try ciphertext.noiseBudget(using: secretKey, variableTime: variableTime) + }.min() ?? -Double.infinity + } } diff --git a/Sources/PrivateNearestNeighborsSearch/Client.swift b/Sources/PrivateNearestNeighborsSearch/Client.swift index 19fc0efe..a5494977 100644 --- a/Sources/PrivateNearestNeighborsSearch/Client.swift +++ b/Sources/PrivateNearestNeighborsSearch/Client.swift @@ -12,51 +12,57 @@ // See the License for the specific language governing permissions and // limitations under the License. -import Algorithms import Foundation import HomomorphicEncryption /// Private nearest neighbors client. -struct Client { +public struct Client { /// Configuration. - let config: ClientConfig + public let config: ClientConfig /// One context per plaintext modulus. - let contexts: [Context] + @usableFromInline let contexts: [Context] /// Performs composition of the plaintext CRT responses. - let crtComposer: CrtComposer + @usableFromInline let crtComposer: CrtComposer /// Context for the plaintext CRT moduli. - let plaintextContext: PolyContext + @usableFromInline let plaintextContext: PolyContext - var evaluationKeyConfiguration: HomomorphicEncryption.EvaluationKeyConfiguration { + /// The evaluation key configuration used by the ``Server``. + public var evaluationKeyConfiguration: EvaluationKeyConfiguration { config.evaluationKeyConfig } /// Creates a new ``Client``. - /// - Parameter config: Client configuration. - /// - Throws: Error upon failure to create a new client. + /// - Parameters: + /// - config: Client configuration. + /// - contexts: Contexts for HE computation, one per plaintext modulus. + /// - Throws: Error upon failure to create the client. @inlinable - init(config: ClientConfig) throws { + public init(config: ClientConfig, contexts: [Context] = []) throws { guard config.distanceMetric == .cosineSimilarity else { throw PnnsError.wrongDistanceMetric(got: config.distanceMetric, expected: .cosineSimilarity) } self.config = config - let extraEncryptionParams = try config.extraPlaintextModuli.map { plaintextModulus in - try EncryptionParameters( - polyDegree: config.encryptionParams.polyDegree, - plaintextModulus: plaintextModulus, - coefficientModuli: config.encryptionParams.coefficientModuli, - errorStdDev: config.encryptionParams.errorStdDev, - securityLevel: config.encryptionParams.securityLevel) - } - let encryptionParams = [config.encryptionParams] + extraEncryptionParams - self.contexts = try encryptionParams.map { encryptionParams in - try Context(encryptionParameters: encryptionParams) + + if !contexts.isEmpty { + precondition(contexts.count == config.encryptionParameters.count) + for (context, encryptionParameters) in zip(contexts, config.encryptionParameters) { + guard context.encryptionParameters == encryptionParameters else { + throw PnnsError.wrongEncryptionParameters( + got: context.encryptionParameters, + expected: encryptionParameters) + } + } + self.contexts = contexts + } else { + self.contexts = try config.encryptionParameters.map { encryptionParams in + try Context(encryptionParameters: encryptionParams) + } } self.plaintextContext = try PolyContext( - degree: config.encryptionParams.polyDegree, + degree: config.encryptionParameters[0].polyDegree, moduli: config.plaintextModuli) self.crtComposer = try CrtComposer(polyContext: plaintextContext) } @@ -68,17 +74,17 @@ struct Client { /// - Returns: The query. /// - Throws: Error upon failure to generate the query. @inlinable - func generateQuery(vectors: Array2d, using secretKey: SecretKey) throws -> Query { - let scaledVectors: Array2d = vectors.normalizedRows(norm: Array2d.Norm.Lp(p: 2.0)) - .scaled(by: Float(config.scalingFactor)).rounded() - let dimensions = try MatrixDimensions(rowCount: vectors.rowCount, columnCount: vectors.columnCount) - + public func generateQuery(for vectors: Array2d, + using secretKey: SecretKey) throws -> Query + { + let scaledVectors: Array2d = vectors + .normalizedScaledAndRounded(scalingFactor: Float(config.scalingFactor)) let matrices = try contexts.map { context in // For a single plaintext modulus, reduction isn't necessary let shouldReduce = contexts.count > 1 let plaintextMatrix = try PlaintextMatrix( context: context, - dimensions: dimensions, + dimensions: MatrixDimensions(vectors.shape), packing: config.queryPacking, signedValues: scaledVectors.data, reduce: shouldReduce) @@ -94,7 +100,7 @@ struct Client { /// - Returns: The distances from the query vectors to the database rows. /// - Throws: Error upon failure to decrypt the response. @inlinable - func decrypt(response: Response, using secretKey: SecretKey) throws -> DatabaseDistances { + public func decrypt(response: Response, using secretKey: SecretKey) throws -> DatabaseDistances { guard let dimensions = response.ciphertextMatrices.first?.dimensions else { throw PnnsError.emptyCiphertextArray } @@ -123,6 +129,14 @@ struct Client { entryMetadatas: response.entryMetadatas) } + /// Generates a secret key for query encryption and response decryption. + /// - Returns: A freshly generated secret key. + /// - Throws: Error upon failure to generate a secret key. + @inlinable + public func generateSecretKey() throws -> SecretKey { + try contexts[0].generateSecretKey() + } + /// Generates an ``EvaluationKey`` for use in nearest neighbors search. /// - Parameter secretKey: Secret key used to generate the evaluation key. /// - Returns: The evaluation key. @@ -130,58 +144,7 @@ struct Client { /// - Warning: Uses the first context to generate the evaluation key. So either the HE scheme should generate /// evaluation keys independent of the plaintext modulus (as in BFV), or there should be just one plaintext modulus. @inlinable - func generateEvaluationKey(using secretKey: SecretKey) throws -> EvaluationKey { + public func generateEvaluationKey(using secretKey: SecretKey) throws -> EvaluationKey { try contexts[0].generateEvaluationKey(configuration: evaluationKeyConfiguration, using: secretKey) } } - -extension Array2d where T == Float { - /// A mapping from vectors to non-negative real numbers. - @usableFromInline - enum Norm { - case Lp(p: Float) // sum_i (|x_i|^p)^{1/p} - } - - /// Normalizes each row in the matrix. - @inlinable - func normalizedRows(norm: Norm) -> Array2d { - switch norm { - case let Norm.Lp(p): - let normalizedValues = data.chunks(ofCount: columnCount).flatMap { row in - let sumOfPowers = row.map { pow($0, p) }.reduce(0, +) - let norm = pow(sumOfPowers, 1 / p) - return row.map { value in - if sumOfPowers.isZero { - Float.zero - } else { - value / norm - } - } - } - return Array2d( - data: normalizedValues, - rowCount: rowCount, - columnCount: columnCount) - } - } - - /// Returns the matrix where each entry is rounded to the closest integer. - @inlinable - func rounded() -> Array2d { - Array2d( - data: data.map { value in V(value.rounded()) }, - rowCount: rowCount, - columnCount: columnCount) - } - - /// Returns the matrix where each entry has been multiplied by a scaling factor. - /// - Parameter scalingFactor: The factor to multiply each entry by. - /// - Returns: The scaled matrix. - @inlinable - func scaled(by scalingFactor: Float) -> Array2d { - Array2d( - data: data.map { value in value * scalingFactor }, - rowCount: rowCount, - columnCount: columnCount) - } -} diff --git a/Sources/PrivateNearestNeighborsSearch/Config.swift b/Sources/PrivateNearestNeighborsSearch/Config.swift index 30e50707..c1e42bfc 100644 --- a/Sources/PrivateNearestNeighborsSearch/Config.swift +++ b/Sources/PrivateNearestNeighborsSearch/Config.swift @@ -23,7 +23,7 @@ public enum DistanceMetric: CaseIterable, Codable, Equatable, Hashable, Sendable /// Client configuration. public struct ClientConfig: Codable, Equatable, Hashable, Sendable { /// Encryption parameters. - public let encryptionParams: EncryptionParameters + public let encryptionParameters: [EncryptionParameters] /// Factor by which to scale floating-point entries before rounding to integers. public let scalingFactor: Int /// Packing for the query. @@ -40,9 +40,7 @@ public struct ClientConfig: Codable, Equatable, Hashable, Send public let extraPlaintextModuli: [Scheme.Scalar] /// The plaintext CRT moduli. - var plaintextModuli: [Scheme.Scalar] { - [encryptionParams.plaintextModulus] + extraPlaintextModuli - } + public var plaintextModuli: [Scheme.Scalar] { encryptionParameters.map(\.plaintextModulus) } /// Creates a new ``ClientConfig``. /// - Parameters: @@ -54,6 +52,7 @@ public struct ClientConfig: Codable, Equatable, Hashable, Send /// - distanceMetric: Metric for nearest neighbors computation /// - extraPlaintextModuli: For plaintext CRT, the list of extra plaintext moduli. The first plaintext modulus /// will be the one in ``ClientConfig/encryptionParams``. + /// - Throws: Error upon failure to create a new client config. public init( encryptionParams: EncryptionParameters, scalingFactor: Int, @@ -61,9 +60,17 @@ public struct ClientConfig: Codable, Equatable, Hashable, Send vectorDimension: Int, evaluationKeyConfig: EvaluationKeyConfiguration, distanceMetric: DistanceMetric, - extraPlaintextModuli: [Scheme.Scalar] = []) + extraPlaintextModuli: [Scheme.Scalar] = []) throws { - self.encryptionParams = encryptionParams + let extraEncryptionParams = try extraPlaintextModuli.map { plaintextModulus in + try EncryptionParameters( + polyDegree: encryptionParams.polyDegree, + plaintextModulus: plaintextModulus, + coefficientModuli: encryptionParams.coefficientModuli, + errorStdDev: encryptionParams.errorStdDev, + securityLevel: encryptionParams.securityLevel) + } + self.encryptionParameters = [encryptionParams] + extraEncryptionParams self.scalingFactor = scalingFactor self.queryPacking = queryPacking self.vectorDimension = vectorDimension @@ -72,30 +79,15 @@ public struct ClientConfig: Codable, Equatable, Hashable, Send self.extraPlaintextModuli = extraPlaintextModuli } - static func maxScalingFactor(vectorDimension: Int, distanceMetric: DistanceMetric, - plaintextModuli: [Scheme.Scalar]) -> Int + @inlinable + package static func maxScalingFactor(distanceMetric: DistanceMetric, vectorDimension: Int, + plaintextModuli: [Scheme.Scalar]) -> Int { precondition(distanceMetric == .cosineSimilarity) let t = plaintextModuli.map { Float($0) }.reduce(1, *) let scalingFactor = (((t - 1) / 2).squareRoot() - Float(vectorDimension).squareRoot() / 2).rounded(.down) return Int(scalingFactor) } - - /// Computes the encryption parameters, one per plaintext modulus. - /// - /// - Returns: The encryption parameters - /// - Throws: Error upon failure to generate the encryption parameters. - func encryptionParameters() throws -> [EncryptionParameters] { - let extraEncryptionParams = try extraPlaintextModuli.map { plaintextModulus in - try EncryptionParameters( - polyDegree: encryptionParams.polyDegree, - plaintextModulus: plaintextModulus, - coefficientModuli: encryptionParams.coefficientModuli, - errorStdDev: encryptionParams.errorStdDev, - securityLevel: encryptionParams.securityLevel) - } - return [encryptionParams] + extraEncryptionParams - } } /// Server configuration. @@ -110,6 +102,10 @@ public struct ServerConfig: Codable, Equatable, Hashable, Send public var plaintextModuli: [Scheme.Scalar] { clientConfig.plaintextModuli } /// Distance metric. public var distanceMetric: DistanceMetric { clientConfig.distanceMetric } + /// The encryption parameters, one per plaintext modulus. + public var encryptionParameters: [EncryptionParameters] { + clientConfig.encryptionParameters + } /// Creates a new ``ServerConfig``. /// - Parameters: @@ -122,12 +118,4 @@ public struct ServerConfig: Codable, Equatable, Hashable, Send self.clientConfig = clientConfig self.databasePacking = databasePacking } - - /// Computes the encryption parameters, one per plaintext modulus. - /// - /// - Returns: The encryption parameters - /// - Throws: Error upon failure to generate the encryption parameters. - public func encryptionParameters() throws -> [EncryptionParameters] { - try clientConfig.encryptionParameters() - } } diff --git a/Sources/PrivateNearestNeighborsSearch/Error.swift b/Sources/PrivateNearestNeighborsSearch/Error.swift index 8e243509..702e650a 100644 --- a/Sources/PrivateNearestNeighborsSearch/Error.swift +++ b/Sources/PrivateNearestNeighborsSearch/Error.swift @@ -15,16 +15,31 @@ import Foundation import HomomorphicEncryption +public enum InvalidQueryReason: Error, Equatable { + case wrongCiphertextMatrixCount(got: Int, expected: Int) +} + +extension InvalidQueryReason: LocalizedError { + public var errorDescription: String? { + switch self { + case let .wrongCiphertextMatrixCount(got, expected): + "Wrong ciphertext matrix count \(got), expected \(expected)" + } + } +} + /// Error type for ``PrivateNearestNeighborsSearch``. public enum PnnsError: Error, Equatable { case emptyCiphertextArray case emptyPlaintextArray case invalidMatrixDimensions(_ dimensions: MatrixDimensions) + case invalidQuery(reason: InvalidQueryReason) case simdEncodingNotSupported(_ description: String) case wrongCiphertextCount(got: Int, expected: Int) case wrongContext(gotDescription: String, expectedDescription: String) case wrongDistanceMetric(got: DistanceMetric, expected: DistanceMetric) case wrongEncodingValuesCount(got: Int, expected: Int) + case wrongEncryptionParameters(gotDescription: String, expectedDescription: String) case wrongMatrixPacking(got: MatrixPacking, expected: MatrixPacking) case wrongPlaintextCount(got: Int, expected: Int) } @@ -39,6 +54,14 @@ extension PnnsError { static func wrongContext(got: Context, expected: Context) -> Self { PnnsError.wrongContext(gotDescription: got.description, expectedDescription: expected.description) } + + @inlinable + static func wrongEncryptionParameters( + got: EncryptionParameters, + expected: EncryptionParameters) -> Self + { + PnnsError.wrongEncryptionParameters(gotDescription: got.description, expectedDescription: expected.description) + } } extension PnnsError: LocalizedError { @@ -52,14 +75,18 @@ extension PnnsError: LocalizedError { "Invalid matrix dimensions: rowCount \(dimensions.rowCount), columnCount \(dimensions.columnCount)" case let .simdEncodingNotSupported(encryptionParameters): "SIMD encoding is not supported for encryption parameters \(encryptionParameters)" + case let .invalidQuery(reason): + "Invalid query due to \(reason)" case let .wrongCiphertextCount(got, expected): "Wrong ciphertext count \(got), expected \(expected)" case let .wrongContext(gotDescription, expectedDescription): - "Wrong context: got \(gotDescription), expected \(expectedDescription)" + "Wrong context \(gotDescription), expected \(expectedDescription)" case let .wrongDistanceMetric(got, expected): - "Wrong distance metric: got \(got), expected \(expected)" + "Wrong distance metric \(got), expected \(expected)" case let .wrongEncodingValuesCount(got, expected): "Wrong encoding values count \(got), expected \(expected)" + case let .wrongEncryptionParameters(got, expected): + "Wrong encryption parameters \(got), expected \(expected)" case let .wrongMatrixPacking(got: got, expected: expected): "Wrong matrix packing \(got), expected \(expected)" case let .wrongPlaintextCount(got, expected): diff --git a/Sources/PrivateNearestNeighborsSearch/MatrixMultiplication.swift b/Sources/PrivateNearestNeighborsSearch/MatrixMultiplication.swift index 6c465db2..49a989df 100644 --- a/Sources/PrivateNearestNeighborsSearch/MatrixMultiplication.swift +++ b/Sources/PrivateNearestNeighborsSearch/MatrixMultiplication.swift @@ -32,18 +32,28 @@ public struct BabyStepGiantStep: Codable, Equatable, Hashable, Sendable { self.giantStep = giantStep } - public init(vectorDimension: Int) { - let dimension = Int32(vectorDimension).nextPowerOfTwo - let babyStep = Int32(Double(dimension).squareRoot().rounded(.up)) + @inlinable + public init(vectorDimension: Int, babyStep: Int) { + let dimension = vectorDimension.nextPowerOfTwo let giantStep = dimension.dividingCeil(babyStep, variableTime: true) + self.init( + vectorDimension: vectorDimension, + babyStep: babyStep, + giantStep: giantStep) + } - self.init(vectorDimension: Int(dimension), babyStep: Int(babyStep), giantStep: Int(giantStep)) + @inlinable + public init(vectorDimension: Int) { + let dimension = vectorDimension.nextPowerOfTwo + let babyStep = Int(Double(dimension).squareRoot().rounded(.up)) + self.init(vectorDimension: dimension, babyStep: babyStep) } } /// Helper function to compute evaluation key used in computing multiplication with a vector. -enum MatrixMultiplication { - static func evaluationKeyConfig( +package enum MatrixMultiplication { + @inlinable + package static func evaluationKeyConfig( plaintextMatrixDimensions: MatrixDimensions, encryptionParameters: EncryptionParameters) throws -> EvaluationKeyConfiguration { @@ -69,6 +79,7 @@ extension PlaintextMatrix { /// - evaluationKey: Evaluation key to perform BabyStepGiantStep rotations. /// - Returns: Encrypted dense-column packed vector containing dot products. /// - Throws: Error upon failure to compute the inner product. + @inlinable func mul( ciphertextVector: CiphertextMatrix, using evaluationKey: EvaluationKey) throws -> CiphertextMatrix @@ -157,7 +168,7 @@ extension PlaintextMatrix { } } let ciphertexMatrixDimensions = try MatrixDimensions( - rowCount: resultCiphertextCount * context.encryptionParameters.polyDegree, + rowCount: dimensions.rowCount, columnCount: 1) return try CiphertextMatrix( dimensions: ciphertexMatrixDimensions, diff --git a/Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift b/Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift index 6185ae6f..9212993b 100644 --- a/Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift +++ b/Sources/PrivateNearestNeighborsSearch/PlaintextMatrix.swift @@ -60,6 +60,14 @@ public struct MatrixDimensions: Equatable, Sendable { throw PnnsError.invalidMatrixDimensions(self) } } + + /// Initializes a ``MatrixDimensions``. + /// - Parameter shape: The (rowCount, columnCount). + /// - Throws: Error upon failure to initialize the dimensions. + @inlinable + public init(_ shape: (Int, Int)) throws { + try self.init(rowCount: shape.0, columnCount: shape.1) + } } /// Stores a matrix of scalars as plaintexts. @@ -446,10 +454,11 @@ public struct PlaintextMatrix: Equatable, chunk += repeatElement(0, count: n - chunk.count) let i = (plaintexts.count - chunkIndex) / plaintextsPerColumn let rotationStep = i.previousMultiple(of: bsgs.babyStep, variableTime: true) - let middle = min(chunk.endIndex, chunk.startIndex + n / 2) - chunk[chunk.startIndex..: Sendable { - // Encrypted query; one matrix per plaintext CRT modulus + /// Encrypted query; one matrix per plaintext CRT modulus. public let ciphertextMatrices: [CiphertextMatrix] + + /// Creates a ``Query``. + /// - Parameter ciphertextMatrices: Encrypted query. + public init(ciphertextMatrices: [CiphertextMatrix]) { + self.ciphertextMatrices = ciphertextMatrices + } } /// A nearest neighbor search response. public struct Response: Sendable { - // Encrypted response; one matrix per plaintext CRT modulus + /// Encrypted distances; one matrix per plaintext CRT modulus. public let ciphertextMatrices: [CiphertextMatrix] - // The entry identifiers the server computed distances for. + /// The entry identifiers the server computed distances for. public let entryIds: [UInt64] - // Metadata for each entry the server computed distances for. + /// Metadata for each entry the server computed distances for. public let entryMetadatas: [[UInt8]] /// Creates a new ``Response``. /// - Parameters: - /// - ciphertextMatrices: Ciphertext matrices. + /// - ciphertextMatrices: Encrypted distances; one matrix per plaintext CRT modulus. /// - entryIds: An identifiers the server computed distances for. /// - entryMetadatas: Metadata for each entry the server computed distances for. public init( @@ -46,11 +52,46 @@ public struct Response: Sendable { } /// Distances from one or more query vector to the database rows. -struct DatabaseDistances: Sendable { - /// The distance from each query vector (outer dimension) to each database row (inner dimension). - let distances: Array2d - // Identifier for each entry in the database. - let entryIds: [UInt64] - // Metadata for each entry in the database. - let entryMetadatas: [[UInt8]] +public struct DatabaseDistances: Sendable { + /// Each row contains the distances from a database entry to each query vector. + public let distances: Array2d + /// Identifier for each entry in the database. + public let entryIds: [UInt64] + /// Metadata for each entry in the database. + public let entryMetadatas: [[UInt8]] + + /// Creates a new ``DatabaseDistances``. + /// - Parameters: + /// - distances: Each row contains the distances from a database entry to each query vector. + /// - entryIds: Identifier for each entry in the database + /// - entryMetadatas: Metadata for each entry in the database + public init( + distances: Array2d, + entryIds: [UInt64], + entryMetadatas: [[UInt8]]) + { + self.distances = distances + self.entryIds = entryIds + self.entryMetadatas = entryMetadatas + } +} + +extension Response { + /// Computes the noise budget of the ciphertext. + /// + /// The *noise budget* of the ciphertext decreases throughout HE operations. Once a ciphertext's noise budget is + /// below + /// ``HeScheme/minNoiseBudget``, decryption may yield inaccurate plaintexts. + /// - Parameters: + /// - secretKey: Secret key. + /// - variableTime: If `true`, indicates the secret key coefficients may be leaked through timing. + /// - Returns: The noise budget. + /// - Throws: Error upon failure to compute the noise budget. + /// - Warning: Leaks `secretKey` through timing. Should be used for testing only. + @inlinable + public func noiseBudget(using secretKey: Scheme.SecretKey, variableTime: Bool) throws -> Double { + try ciphertextMatrices.map { ciphertextMatrix in + try ciphertextMatrix.noiseBudget(using: secretKey, variableTime: variableTime) + }.min() ?? -Double.infinity + } } diff --git a/Sources/PrivateNearestNeighborsSearch/ProcessedDatabase.swift b/Sources/PrivateNearestNeighborsSearch/ProcessedDatabase.swift index 070d3211..e9e24ac4 100644 --- a/Sources/PrivateNearestNeighborsSearch/ProcessedDatabase.swift +++ b/Sources/PrivateNearestNeighborsSearch/ProcessedDatabase.swift @@ -15,7 +15,8 @@ import HomomorphicEncryption public struct ProcessedDatabase: Equatable, Sendable { - let contexts: [Context] + /// One context per plaintext modulus. + public let contexts: [Context] /// The processed vectors in the database. public let plaintextMatrices: [PlaintextMatrix] @@ -29,6 +30,22 @@ public struct ProcessedDatabase: Equatable, Sendable { /// Server configuration. public let serverConfig: ServerConfig + @inlinable + public init( + contexts: [Context], + plaintextMatrices: [PlaintextMatrix], + entryIds: [UInt64], + entryMetadatas: [[UInt8]], + serverConfig: ServerConfig) + { + precondition(contexts.count == plaintextMatrices.count) + self.contexts = contexts + self.plaintextMatrices = plaintextMatrices + self.entryIds = entryIds + self.entryMetadatas = entryMetadatas + self.serverConfig = serverConfig + } + /// Serializes the processed database. /// - Returns: The serialized processed database. /// - Throws: Error upon failure to serialize. @@ -44,39 +61,54 @@ public struct ProcessedDatabase: Equatable, Sendable { extension Database { /// Processes the database for neareset neighbors computation. - /// - Parameter config: Configuration to process with. + /// - Parameters: + /// - config: Configuration to process with. + /// - contexts: Contexts for HE computation, one per plaintext modulus. /// - Returns: The processed database. /// - Throws: Error upon failure to process the database. - public func process(with config: ServerConfig) throws -> ProcessedDatabase { + @inlinable + public func process(config: ServerConfig, + contexts: [Context] = []) throws -> ProcessedDatabase + { guard config.distanceMetric == .cosineSimilarity else { throw PnnsError.wrongDistanceMetric(got: config.distanceMetric, expected: .cosineSimilarity) } + var contexts = contexts + if contexts.isEmpty { + contexts = try config.encryptionParameters.map { encryptionParams in + try Context(encryptionParameters: encryptionParams) + } + } else { + precondition(contexts.count == config.encryptionParameters.count) + for (context, encryptionParameters) in zip(contexts, config.encryptionParameters) { + guard context.encryptionParameters == encryptionParameters else { + throw PnnsError.wrongEncryptionParameters( + got: context.encryptionParameters, + expected: encryptionParameters) + } + } + } + let vectors = Array2d(data: rows.map { row in row.vector }) - let roundedVectors: Array2d = vectors - .normalizedRows(norm: .Lp(p: 2.0)) - .scaled(by: Float(config.clientConfig.scalingFactor)).rounded() + let roundedVectors: Array2d = vectors.normalizedScaledAndRounded( + scalingFactor: Float(config.scalingFactor)) - let contexts = try config.encryptionParameters().map { encryptionParams in - try Context(encryptionParameters: encryptionParams) - } let plaintextMatrices: [PlaintextMatrix] = try contexts.map { context in // For a single plaintext modulus, reduction isn't necessary let shouldReduce = contexts.count > 1 return try PlaintextMatrix( context: context, - dimensions: MatrixDimensions( - rowCount: roundedVectors.rowCount, - columnCount: roundedVectors.columnCount), + dimensions: MatrixDimensions(roundedVectors.shape), packing: config.databasePacking, signedValues: roundedVectors.data, reduce: shouldReduce).convertToEvalFormat() } - + let hasMetadata = rows.contains { row in !row.entryMetadata.isEmpty } return ProcessedDatabase( contexts: contexts, plaintextMatrices: plaintextMatrices, entryIds: rows.map { row in row.entryId }, - entryMetadatas: rows.map { row in row.entryMetadata }, + entryMetadatas: hasMetadata ? rows.map { row in row.entryMetadata } : [], serverConfig: config) } } diff --git a/Sources/PrivateNearestNeighborsSearch/Server.swift b/Sources/PrivateNearestNeighborsSearch/Server.swift new file mode 100644 index 00000000..6dfd5b99 --- /dev/null +++ b/Sources/PrivateNearestNeighborsSearch/Server.swift @@ -0,0 +1,86 @@ +// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import HomomorphicEncryption + +/// Private nearest neighbors server. +public struct Server { + /// Configuration. + public let config: ServerConfig + + /// The database. + public let database: ProcessedDatabase + + /// Configuration needed for private nearest neighbors search.. + public var evaluationKeyConfiguration: EvaluationKeyConfiguration { + config.clientConfig.evaluationKeyConfig + } + + /// One context per plaintext modulus. + public var contexts: [Context] { + database.contexts + } + + /// Creates a new ``Server``. + /// - Parameters: + /// - database: Processed database. + /// - config: Server configuration. + /// - Throws: Error upon failure to create the server. + @inlinable + public init(database: ProcessedDatabase, config: ServerConfig) throws { + guard config.distanceMetric == .cosineSimilarity else { + throw PnnsError.wrongDistanceMetric(got: config.distanceMetric, expected: .cosineSimilarity) + } + self.config = config + self.database = database + } + + /// Compute the encrypted response to a query. + /// - Parameters: + /// - query: Query. + /// - evaluationKey: Evaluation key to aid in the server computation. + /// - Returns: The response. + /// - Throws: Error upon failure to compute a response. + @inlinable + public func computeResponse(to query: Query, + using evaluationKey: EvaluationKey) throws -> Response + { + guard query.ciphertextMatrices.count == database.plaintextMatrices.count else { + throw PnnsError.invalidQuery(reason: InvalidQueryReason.wrongCiphertextMatrixCount( + got: query.ciphertextMatrices.count, + expected: database.plaintextMatrices.count)) + } + + let responseMatrices = try zip(query.ciphertextMatrices, database.plaintextMatrices) + .map { ciphertextMatrix, plaintextMatrix in + // Client query has transposed dimensions + // TODO: remove (and make CiphertextMatrix.dimensions `let` instead of `var) + var ciphertextMatrix = ciphertextMatrix + ciphertextMatrix.dimensions = try MatrixDimensions( + rowCount: ciphertextMatrix.columnCount, + columnCount: ciphertextMatrix.rowCount) + var responseMatrix = try plaintextMatrix.mul( + ciphertextVector: ciphertextMatrix.convertToCanonicalFormat(), + using: evaluationKey) + // Reduce response size by mod-switching to a single modulus. + try responseMatrix.modSwitchDownToSingle() + return try responseMatrix.convertToCoeffFormat() + } + + return Response( + ciphertextMatrices: responseMatrices, + entryIds: database.entryIds, + entryMetadatas: database.entryMetadatas) + } +} diff --git a/Sources/PrivateNearestNeighborsSearch/Util.swift b/Sources/PrivateNearestNeighborsSearch/Util.swift new file mode 100644 index 00000000..d4cac711 --- /dev/null +++ b/Sources/PrivateNearestNeighborsSearch/Util.swift @@ -0,0 +1,158 @@ +// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Algorithms +import Foundation +import HomomorphicEncryption + +extension Array2d where T == Float { + /// A mapping from vectors to non-negative real numbers. + @usableFromInline + package enum Norm: Equatable { + case Lp(p: Float) // sum_i (|x_i|^p)^{1/p} + } + + /// Normalizes each row in the matrix. + @inlinable + package func normalizedRows(norm: Norm) -> Array2d { + switch norm { + case let Norm.Lp(p): + let normalizedValues = data.chunks(ofCount: columnCount).flatMap { row in + let sumOfPowers = row.map { pow($0, p) }.reduce(0, +) + let norm = pow(sumOfPowers, 1 / p) + return row.map { value in + if sumOfPowers.isZero { + Float.zero + } else { + value / norm + } + } + } + return Array2d( + data: normalizedValues, + rowCount: rowCount, + columnCount: columnCount) + } + } + + /// Returns the matrix where each entry is rounded to the closest integer. + @inlinable + package func rounded() -> Array2d { + Array2d( + data: data.map { value in V(value.rounded()) }, + rowCount: rowCount, + columnCount: columnCount) + } + + /// Returns the matrix where each entry has been multiplied by a scaling factor. + /// - Parameter scalingFactor: The factor to multiply each entry by. + /// - Returns: The scaled matrix. + @inlinable + package func scaled(by scalingFactor: Float) -> Array2d { + Array2d( + data: data.map { value in value * scalingFactor }, + rowCount: rowCount, + columnCount: columnCount) + } + + /// Normalizes the each rows' vector with L2 norm, then scales and rounds each entry. + /// - Parameter scalingFactor: The factor to multiply each entry by. + /// - Returns: The matrix after the normalization, scaling, and rounding. + @inlinable + package func normalizedScaledAndRounded(scalingFactor: Float) -> Array2d { + let normalizedValues = data.chunks(ofCount: columnCount).flatMap { row in + let norm = row.map { $0 * $0 }.reduce(0, +).squareRoot() + return row.map { value in + if norm.isZero { + V.zero + } else { + V((value * scalingFactor / norm).rounded()) + } + } + } + return Array2d( + data: normalizedValues, + rowCount: rowCount, + columnCount: columnCount) + } +} + +extension Array2d where T: SignedScalarType { + /// Performs modular matrix multiplication. + /// - Parameters: + /// - rhs: Matrix to multiply with. + /// - modulus: Modulus. + /// - Returns: The matrix product; each value is in `[-floor(modulus/2), floor(modulus-1)/2]` + @inlinable + package func mul(_ rhs: Self, modulus: T.UnsignedScalar) -> Self { + precondition(columnCount == rhs.rowCount) + let signedModulus = T(modulus) + var result = Array2d.zero(rowCount: rowCount, columnCount: rhs.columnCount) + for row in 0..= signedModulus { + sum -= signedModulus + } + result[row, column] = sum + } + result[row, column] = T.UnsignedScalar(result[row, column]).remainderToCentered(modulus: modulus) + } + } + return result + } +} + +extension Array2d where T == Float { + @inlinable + package func mul(_ rhs: Self) -> Self { + precondition(columnCount == rhs.rowCount) + var result = Array2d.zero(rowCount: rowCount, columnCount: rhs.columnCount) + for row in 0..(_ rhs: Self, modulus: V, + scalingFactor: Float) throws -> Self + { + let lhsScaled: Array2d = normalizedScaledAndRounded(scalingFactor: scalingFactor) + let rhsScaled: Array2d = rhs.transposed() + .normalizedScaledAndRounded(scalingFactor: scalingFactor) + .transposed() + let product = lhsScaled.mul(rhsScaled, modulus: modulus) + return product.map { Float($0) / (scalingFactor * scalingFactor) } + } +} + +@inlinable +package func fixedPointCosineSimilarityError(innerDimension _: Int, scalingFactor: Int) -> Float { + // With scaling factor 10, 0.45 would round to 0.50, for error 0.05 + let scaleAndRoundError = 1.0 / Float(2 * scalingFactor) + return pow(1 + scaleAndRoundError, 2) - 1.0 +} diff --git a/Sources/PrivateNearestNeighborsSearchProtobuf/ConversionPnns.swift b/Sources/PrivateNearestNeighborsSearchProtobuf/ConversionPnns.swift index b8f1f57d..da60b2b4 100644 --- a/Sources/PrivateNearestNeighborsSearchProtobuf/ConversionPnns.swift +++ b/Sources/PrivateNearestNeighborsSearchProtobuf/ConversionPnns.swift @@ -138,7 +138,7 @@ extension ClientConfig { /// - Throws: Error upon unsupported object. public func proto() throws -> Apple_SwiftHomomorphicEncryption_Pnns_V1_ClientConfig { try Apple_SwiftHomomorphicEncryption_Pnns_V1_ClientConfig.with { config in - config.encryptionParameters = try encryptionParams.proto() + config.encryptionParameters = try encryptionParameters[0].proto() config.scalingFactor = UInt64(scalingFactor) config.queryPacking = try queryPacking.proto() config.vectorDimension = UInt32(vectorDimension) @@ -311,3 +311,21 @@ extension Query { try ciphertextMatrices.map { matrix in try matrix.serialize().proto() } } } + +extension Query { + package func size() throws -> Int { + try proto().map { matrix in try matrix.serializedData().count }.sum() + } +} + +extension Response { + package func size() throws -> Int { + try proto().serializedData().count + } +} + +extension EvaluationKey { + package func size() throws -> Int { + try serialize().proto().serializedData().count + } +} diff --git a/Tests/PrivateNearestNeighborsSearchProtobufTests/ConversionTests.swift b/Tests/PrivateNearestNeighborsSearchProtobufTests/ConversionTests.swift index fdabd68e..2ba0d589 100644 --- a/Tests/PrivateNearestNeighborsSearchProtobufTests/ConversionTests.swift +++ b/Tests/PrivateNearestNeighborsSearchProtobufTests/ConversionTests.swift @@ -63,7 +63,10 @@ class ConversionTests: XCTestCase { vectorDimension: vectorDimension, evaluationKeyConfig: EvaluationKeyConfiguration(galoisElements: [3]), distanceMetric: .cosineSimilarity, - extraPlaintextModuli: [536_903_681]) + extraPlaintextModuli: Scheme.Scalar.generatePrimes( + significantBitCounts: [15], + preferringSmall: true, + nttDegree: 8)) XCTAssertEqual(try clientConfig.proto().native(), clientConfig) let serverConfig = ServerConfig( @@ -142,23 +145,16 @@ class ConversionTests: XCTestCase { dimensions: dimensions, packing: .denseColumn, values: scalars.flatMap { $0 }) - let ciphertextMatrix = try plaintextMatrix.encrypt(using: secretKey) - let serialized = try ciphertextMatrix.serialize() - let serializedProto = try serialized.proto() - XCTAssertEqual(try serializedProto.native(), serialized) - let serializedSize = try serializedProto.serializedData().count - - let serializedForDecryption = try ciphertextMatrix.serialize(forDecryption: true) - let serializedForDecryptionSize = try serializedForDecryption.proto().serializedData().count - XCTAssertLessThanOrEqual(serializedForDecryptionSize, serializedSize) - let deserialized = try CiphertextMatrix( - deserialize: serializedForDecryption, - context: context) - let decrypted = try deserialized.decrypt(using: secretKey) - XCTAssertEqual(decrypted, plaintextMatrix) - + // Check Canonical Format + do { + let ciphertextMatrix = try plaintextMatrix.encrypt(using: secretKey) + let serialized = try ciphertextMatrix.serialize() + let serializedProto = try serialized.proto() + XCTAssertEqual(try serializedProto.native(), serialized) + } // Check Evaluation format do { + let ciphertextMatrix = try plaintextMatrix.encrypt(using: secretKey) let evalCiphertextMatrix = try ciphertextMatrix.convertToEvalFormat() let serialized = try evalCiphertextMatrix.serialize() XCTAssertEqual(try serialized.proto().native(), serialized) @@ -167,6 +163,24 @@ class ConversionTests: XCTestCase { context: context) XCTAssertEqual(deserialized, evalCiphertextMatrix) } + // Check serializeForDecryption + do { + var ciphertextMatrix = try plaintextMatrix.encrypt(using: secretKey) + try ciphertextMatrix.modSwitchDownToSingle() + let serializedForDecryption = try ciphertextMatrix.serialize(forDecryption: true) + let serializedForDecryptionSize = try serializedForDecryption.proto().serializedData().count + + let serialized = try ciphertextMatrix.serialize() + let serializedProto = try serialized.proto() + let serializedSize = try serializedProto.serializedData().count + + XCTAssertLessThan(serializedForDecryptionSize, serializedSize) + let deserialized = try CiphertextMatrix( + deserialize: serializedForDecryption, + context: context, moduliCount: 1) + let decrypted = try deserialized.decrypt(using: secretKey) + XCTAssertEqual(decrypted, plaintextMatrix) + } } try runTest(Bfv.self) @@ -207,7 +221,7 @@ class ConversionTests: XCTestCase { .diagonal( babyStepGiantStep: BabyStepGiantStep(vectorDimension: vectorDimension))) - let processed = try database.process(with: serverConfig) + let processed = try database.process(config: serverConfig) let serialized = try processed.serialize() XCTAssertEqual(try serialized.proto().native(), serialized) } diff --git a/Tests/PrivateNearestNeighborsSearchTests/CiphertextMatrixTests.swift b/Tests/PrivateNearestNeighborsSearchTests/CiphertextMatrixTests.swift index b7b26110..8aa1079c 100644 --- a/Tests/PrivateNearestNeighborsSearchTests/CiphertextMatrixTests.swift +++ b/Tests/PrivateNearestNeighborsSearchTests/CiphertextMatrixTests.swift @@ -43,9 +43,16 @@ final class CiphertextMatrixTests: XCTestCase { packing: .denseRow, values: encodeValues.flatMap { $0 }) let secretKey = try context.generateSecretKey() - let ciphertextMatrix = try plaintextMatrix.encrypt(using: secretKey) - let plaintextMatrixroundTrip = try ciphertextMatrix.decrypt(using: secretKey) - XCTAssertEqual(plaintextMatrixroundTrip, plaintextMatrix) + var ciphertextMatrix = try plaintextMatrix.encrypt(using: secretKey) + let plaintextMatrixRoundTrip = try ciphertextMatrix.decrypt(using: secretKey) + XCTAssertEqual(plaintextMatrixRoundTrip, plaintextMatrix) + + // modSwitchDownToSingle + do { + try ciphertextMatrix.modSwitchDownToSingle() + let plaintextMatrixRoundTrip = try ciphertextMatrix.decrypt(using: secretKey) + XCTAssertEqual(plaintextMatrixRoundTrip, plaintextMatrix) + } } try runTest(for: NoOpScheme.self) try runTest(for: Bfv.self) @@ -103,9 +110,7 @@ final class CiphertextMatrixTests: XCTestCase { for rowCount in 1..<(2 * degree) { for columnCount in 1...maxScalingFactor( - vectorDimension: 128, distanceMetric: .cosineSimilarity, + vectorDimension: 128, plaintextModuli: Array(plaintextModuli.prefix(1))) let maxScalingFactor2 = ClientConfig.maxScalingFactor( - vectorDimension: 128, distanceMetric: .cosineSimilarity, + vectorDimension: 128, plaintextModuli: plaintextModuli) XCTAssertGreaterThan(maxScalingFactor2, maxScalingFactor1) @@ -75,6 +75,12 @@ final class ClientTests: XCTestCase { } let rounded: Array2d = scaled.rounded() XCTAssertEqual(rounded.data, testCase.rounded.flatMap { $0 }) + + if testCase.norm == Array2d.Norm.Lp(p: 2.0) { + let normalizedScaledAndRounded: Array2d = floatMatrix.normalizedScaledAndRounded( + scalingFactor: testCase.scalingFactor) + XCTAssertEqual(normalizedScaledAndRounded.data, testCase.rounded.flatMap { $0 }) + } } let testCases: [TestCase] = [ @@ -96,7 +102,7 @@ final class ClientTests: XCTestCase { } } - func testQuery() throws { + func testQueryAsResponse() throws { func runTest(for _: Scheme.Type) throws { let degree = 512 let encryptionParams = try EncryptionParameters( @@ -127,7 +133,7 @@ final class ClientTests: XCTestCase { significantBitCounts: [17], preferringSmall: true, nttDegree: degree)] { - let config = ClientConfig( + let config = try ClientConfig( encryptionParams: encryptionParams, scalingFactor: scalingFactor, queryPacking: .denseRow, @@ -136,11 +142,12 @@ final class ClientTests: XCTestCase { distanceMetric: .cosineSimilarity, extraPlaintextModuli: extraPlaintextModuli) let client = try Client(config: config) - let query = try client.generateQuery(vectors: queryValues, using: secretKey) + let query = try client.generateQuery(for: queryValues, using: secretKey) XCTAssertEqual(query.ciphertextMatrices.count, config.plaintextModuli.count) let entryIds = [UInt64(42)] let entryMetadatas = [42.littleEndianBytes] + // Treat the query as a response let response = Response( ciphertextMatrices: query.ciphertextMatrices, entryIds: entryIds, entryMetadatas: entryMetadatas) @@ -149,8 +156,7 @@ final class ClientTests: XCTestCase { XCTAssertEqual(databaseDistances.entryMetadatas, entryMetadatas) let scaledQuery: Array2d = queryValues - .normalizedRows(norm: Array2d.Norm.Lp(p: 2.0)).scaled(by: Float(config.scalingFactor)) - .rounded() + .normalizedScaledAndRounded(scalingFactor: Float(config.scalingFactor)) // Cosine similarity response returns result scaled by scalingFactor^2 let expectedDistances = scaledQuery.map { value in Float(value) / Float(config.scalingFactor * config.scalingFactor) @@ -161,4 +167,101 @@ final class ClientTests: XCTestCase { try runTest(for: Bfv.self) try runTest(for: Bfv.self) } + + func testClientServer() throws { + func runSingleTest( + encryptionParams: EncryptionParameters, + dimensions: MatrixDimensions, + plaintextModuli: [Scheme.Scalar], + queryCount: Int) throws + { + let vectorDimension = dimensions.columnCount + let scalingFactor = ClientConfig.maxScalingFactor( + distanceMetric: .cosineSimilarity, + vectorDimension: vectorDimension, + plaintextModuli: plaintextModuli) + let evaluatonKeyConfig = try MatrixMultiplication.evaluationKeyConfig( + plaintextMatrixDimensions: dimensions, + encryptionParameters: encryptionParams) + let clientConfig = try ClientConfig( + encryptionParams: encryptionParams, + scalingFactor: scalingFactor, + queryPacking: .denseRow, + vectorDimension: vectorDimension, + evaluationKeyConfig: evaluatonKeyConfig, + distanceMetric: .cosineSimilarity, + extraPlaintextModuli: Array(plaintextModuli[1...])) + let serverConfig = ServerConfig( + clientConfig: clientConfig, + databasePacking: .diagonal(babyStepGiantStep: BabyStepGiantStep(vectorDimension: vectorDimension))) + + let database = getDatabaseForTesting(config: DatabaseConfig( + rowCount: dimensions.rowCount, + vectorDimension: dimensions.columnCount)) + let processed = try database.process(config: serverConfig) + + let client = try Client(config: clientConfig, contexts: processed.contexts) + let server = try Server(database: processed, config: serverConfig) + + // We query exact matches from rows in the database + let queryVectors = Array2d(data: database.rows.prefix(queryCount).map { row in row.vector }) + let secretKey = try client.generateSecretKey() + let query = try client.generateQuery(for: queryVectors, using: secretKey) + let evaluationKey = try client.generateEvaluationKey(using: secretKey) + + let response = try server.computeResponse(to: query, using: evaluationKey) + let noiseBudget = try response.noiseBudget(using: secretKey, variableTime: true) + XCTAssertGreaterThan(noiseBudget, 0) + let decrypted = try client.decrypt(response: response, using: secretKey) + + XCTAssertEqual(decrypted.entryIds, processed.entryIds) + XCTAssertEqual(decrypted.entryMetadatas, processed.entryMetadatas) + + let vectors = Array2d(data: database.rows.map { row in row.vector }) + let modulus: UInt64 = client.config.plaintextModuli.map { UInt64($0) }.reduce(1, *) + let expected = try vectors.fixedPointCosineSimilarity( + queryVectors.transposed(), + modulus: modulus, + scalingFactor: Float(scalingFactor)) + XCTAssertEqual(decrypted.distances, expected) + } + + func runTest(for _: Scheme.Type) throws { + let degree = 64 + let maxPlaintextModuliCount = 2 + let plaintextModuli = try Scheme.Scalar.generatePrimes( + significantBitCounts: Array(repeating: 10, count: maxPlaintextModuliCount), + preferringSmall: true, + nttDegree: degree) + let coefficientModuli = try Scheme.Scalar.generatePrimes( + significantBitCounts: Array( + repeating: Scheme.Scalar.bitWidth - 4, + count: 3), + preferringSmall: false, + nttDegree: degree) + let encryptionParams = try EncryptionParameters( + polyDegree: degree, + plaintextModulus: plaintextModuli[0], + coefficientModuli: coefficientModuli, + errorStdDev: .stdDev32, + securityLevel: .unchecked) + XCTAssert(encryptionParams.supportsSimdEncoding) + + let queryCount = 1 + for rowCount in [degree / 2, degree, degree + 1, 3 * degree] { + for dimensions in try [MatrixDimensions(rowCount: rowCount, columnCount: 16)] { + for plaintextModuliCount in 1...maxPlaintextModuliCount { + try runSingleTest( + encryptionParams: encryptionParams, + dimensions: dimensions, + plaintextModuli: Array(plaintextModuli.prefix(plaintextModuliCount)), + queryCount: queryCount) + } + } + } + } + + try runTest(for: Bfv.self) + try runTest(for: Bfv.self) + } } diff --git a/Tests/PrivateNearestNeighborsSearchTests/CosineSimilarityTests.swift b/Tests/PrivateNearestNeighborsSearchTests/CosineSimilarityTests.swift new file mode 100644 index 00000000..4875ba5e --- /dev/null +++ b/Tests/PrivateNearestNeighborsSearchTests/CosineSimilarityTests.swift @@ -0,0 +1,64 @@ +// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import HomomorphicEncryption +@testable import PrivateNearestNeighborsSearch +import TestUtilities +import XCTest + +final class CosineSimilarityTests: XCTestCase { + func testNormalizeRowsAndScale() throws { + struct TestCase { + let scalingFactor: Float + let norm: Array2d.Norm + let input: [[Float]] + let normalized: [[Float]] + let scaled: [[Float]] + let rounded: [[T]] + } + + func runTestCase(testCase: TestCase) throws { + let floatMatrix = Array2d(data: testCase.input) + let normalized = floatMatrix.normalizedRows(norm: testCase.norm) + for (normalized, expected) in zip(normalized.data, testCase.normalized.flatMap { $0 }) { + XCTAssertIsClose(normalized, expected) + } + + let scaled = normalized.scaled(by: testCase.scalingFactor) + for (scaled, expected) in zip(scaled.data, testCase.scaled.flatMap { $0 }) { + XCTAssertIsClose(scaled, expected) + } + let rounded: Array2d = scaled.rounded() + XCTAssertEqual(rounded.data, testCase.rounded.flatMap { $0 }) + } + + let testCases: [TestCase] = [ + TestCase(scalingFactor: 10.0, + norm: Array2d.Norm.Lp(p: 1.0), + input: [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + normalized: [[1.0 / 3.0, 2.0 / 3.0], [3.0 / 7.0, 4.0 / 7.0], [5.0 / 11.0, 6.0 / 11.0]], + scaled: [[10.0 / 3.0, 20.0 / 3.0], [30.0 / 7.0, 40.0 / 7.0], [50.0 / 11.0, 60.0 / 11.0]], + rounded: [[3, 7], [4, 6], [5, 5]]), + TestCase(scalingFactor: 100.0, + norm: Array2d.Norm.Lp(p: 2.0), + input: [[3.0, 4.0], [-5.0, 12.0]], + normalized: [[3.0 / 5.0, 4.0 / 5.0], [-5.0 / 13.0, 12.0 / 13.0]], + scaled: [[300.0 / 5.0, 400.0 / 5.0], [-500.0 / 13.0, 1200.0 / 13.0]], + rounded: [[60, 80], [-38, 92]]), + ] + for testCase in testCases { + try runTestCase(testCase: testCase) + } + } +} diff --git a/Tests/PrivateNearestNeighborsSearchTests/MatrixMultiplicationTests.swift b/Tests/PrivateNearestNeighborsSearchTests/MatrixMultiplicationTests.swift index b96f38ea..268951c5 100644 --- a/Tests/PrivateNearestNeighborsSearchTests/MatrixMultiplicationTests.swift +++ b/Tests/PrivateNearestNeighborsSearchTests/MatrixMultiplicationTests.swift @@ -17,49 +17,32 @@ import HomomorphicEncryption import TestUtilities import XCTest -extension Array where Element: Collection, Element.Element: ScalarType, Element.Index == Int { - typealias BaseElement = Element.Element - - func mul(_ vector: [BaseElement], modulus: BaseElement) throws -> [BaseElement] { - map { row in - precondition(row.count == vector.count) - return zip(row, vector).reduce(0) { sum, multiplicands in - let product = multiplicands.0.multiplyMod(multiplicands.1, modulus: modulus, variableTime: true) - return sum.addMod(product, modulus: modulus) - } - } - } -} - final class MatrixMultiplicationTests: XCTestCase { func testMulVector() throws { func checkProduct( _: Scheme.Type, _ plaintextRows: [[Scheme.Scalar]], - _ dimensions: MatrixDimensions, + _ plaintextMatrixDimensions: MatrixDimensions, _ queryValues: [Scheme.Scalar]) throws { let encryptionParameters = try EncryptionParameters(from: .n_4096_logq_27_28_28_logt_16) let context = try Context(encryptionParameters: encryptionParameters) let secretKey = try context.generateSecretKey() + let queryCount = queryValues.count / plaintextMatrixDimensions.columnCount - var expected: [Scheme.Scalar] = try plaintextRows.mul( + let expected: [Scheme.Scalar] = try plaintextRows.mul( queryValues, modulus: encryptionParameters.plaintextModulus) - let n = encryptionParameters.polyDegree - if expected.count % n > 0 { - expected += Array(repeating: 0, count: n - (expected.count % n)) - } let babyStepGiantStep = BabyStepGiantStep(vectorDimension: queryValues.count) let plaintextMatrix = try PlaintextMatrix( context: context, - dimensions: dimensions, + dimensions: plaintextMatrixDimensions, packing: .diagonal(babyStepGiantStep: babyStepGiantStep), values: plaintextRows.flatMap { $0 }) let evaluationKeyConfig = try MatrixMultiplication.evaluationKeyConfig( - plaintextMatrixDimensions: dimensions, + plaintextMatrixDimensions: plaintextMatrixDimensions, encryptionParameters: encryptionParameters) let evaluationKey = try context.generateEvaluationKey( configuration: evaluationKeyConfig, @@ -74,10 +57,15 @@ final class MatrixMultiplicationTests: XCTestCase { values: queryValues).encrypt(using: secretKey) let dotProduct = try plaintextMatrix.mul(ciphertextVector: ciphertextVector, using: evaluationKey) - let expectedCiphertextsCount = dimensions.rowCount.dividingCeil( + let expectedCiphertextsCount = plaintextMatrixDimensions.rowCount.dividingCeil( encryptionParameters.polyDegree, variableTime: true) XCTAssertEqual(dotProduct.ciphertexts.count, expectedCiphertextsCount) + XCTAssertEqual( + dotProduct.dimensions, + try MatrixDimensions( + rowCount: plaintextMatrixDimensions.rowCount, + columnCount: queryCount)) let resultMatrix = try dotProduct.decrypt(using: secretKey) let resultValues: [Scheme.Scalar] = try resultMatrix.unpack() diff --git a/Tests/PrivateNearestNeighborsSearchTests/PlaintextMatrixTests.swift b/Tests/PrivateNearestNeighborsSearchTests/PlaintextMatrixTests.swift index 028abda6..a87433ad 100644 --- a/Tests/PrivateNearestNeighborsSearchTests/PlaintextMatrixTests.swift +++ b/Tests/PrivateNearestNeighborsSearchTests/PlaintextMatrixTests.swift @@ -33,9 +33,7 @@ final class PlaintextMatrixTests: XCTestCase { guard encryptionParams.supportsSimdEncoding, encryptionParams.polyDegree <= 16 else { return } - let rowCount = encryptionParams.polyDegree - let columnCount = 2 - let dims = try MatrixDimensions(rowCount: rowCount, columnCount: columnCount) + let dims = try MatrixDimensions(rowCount: encryptionParams.polyDegree, columnCount: 2) let packing = MatrixPacking.denseRow let context = try Context(encryptionParameters: encryptionParams) let values = TestUtils.getRandomPlaintextData( @@ -93,7 +91,7 @@ final class PlaintextMatrixTests: XCTestCase { // Wrong number of values do { - let wrongDims = try MatrixDimensions(rowCount: rowCount, columnCount: columnCount + 1) + let wrongDims = try MatrixDimensions((rowCount, columnCount + 1)) XCTAssertThrowsError(try PlaintextMatrix( context: context, dimensions: wrongDims, @@ -102,7 +100,7 @@ final class PlaintextMatrixTests: XCTestCase { } // Too many columns do { - let dims = try MatrixDimensions(rowCount: rowCount, columnCount: columnCount + 1) + let dims = try MatrixDimensions((rowCount, columnCount + 1)) XCTAssertThrowsError(try PlaintextMatrix( context: context, dimensions: dims, @@ -292,7 +290,7 @@ final class PlaintextMatrixTests: XCTestCase { securityLevel: .unchecked) let context = try Context(encryptionParameters: encryptionParams) for ((rowCount, columnCount), expected) in kats { - let dimensions = try MatrixDimensions(rowCount: rowCount, columnCount: columnCount) + let dimensions = try MatrixDimensions((rowCount, columnCount)) try runPlaintextMatrixInitTest( context: context, dimensions: dimensions, @@ -374,7 +372,7 @@ final class PlaintextMatrixTests: XCTestCase { let encryptionParams = try EncryptionParameters(from: rlweParams) let context = try Context(encryptionParameters: encryptionParams) for ((rowCount, columnCount), expected) in kats { - let dimensions = try MatrixDimensions(rowCount: rowCount, columnCount: columnCount) + let dimensions = try MatrixDimensions((rowCount, columnCount)) try runPlaintextMatrixInitTest( context: context, dimensions: dimensions, @@ -476,7 +474,7 @@ final class PlaintextMatrixTests: XCTestCase { securityLevel: SecurityLevel.unchecked) let context = try Context(encryptionParameters: encryptionParams) for ((rowCount, columnCount), expected) in kats { - let dimensions = try MatrixDimensions(rowCount: rowCount, columnCount: columnCount) + let dimensions = try MatrixDimensions((rowCount, columnCount)) let bsgs = BabyStepGiantStep(vectorDimension: dimensions.columnCount.nextPowerOfTwo) try runPlaintextMatrixInitTest( context: context, diff --git a/Tests/PrivateNearestNeighborsSearchTests/Utils.swift b/Tests/PrivateNearestNeighborsSearchTests/Utils.swift new file mode 100644 index 00000000..ed61ab16 --- /dev/null +++ b/Tests/PrivateNearestNeighborsSearchTests/Utils.swift @@ -0,0 +1,55 @@ +// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation +import HomomorphicEncryption +import PrivateNearestNeighborsSearch + +struct DatabaseConfig { + let rowCount: Int + let vectorDimension: Int + let metadataCount: Int + + init(rowCount: Int, vectorDimension: Int, metadataCount: Int = 0) { + self.rowCount = rowCount + self.vectorDimension = vectorDimension + self.metadataCount = metadataCount + } +} + +func getDatabaseForTesting(config: DatabaseConfig) -> Database { + let rows = (0.. [BaseElement] { + map { row in + precondition(row.count == vector.count) + return zip(row, vector).reduce(0) { sum, multiplicands in + let product = multiplicands.0.multiplyMod(multiplicands.1, modulus: modulus, variableTime: true) + return sum.addMod(product, modulus: modulus) + } + } + } +} diff --git a/Tests/PrivateNearestNeighborsSearchTests/UtilsTests.swift b/Tests/PrivateNearestNeighborsSearchTests/UtilsTests.swift new file mode 100644 index 00000000..4a8b7a2d --- /dev/null +++ b/Tests/PrivateNearestNeighborsSearchTests/UtilsTests.swift @@ -0,0 +1,60 @@ +// Copyright 2024 Apple Inc. and the Swift Homomorphic Encryption project authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import HomomorphicEncryption +@testable import PrivateNearestNeighborsSearch +import TestUtilities +import XCTest + +final class UtilsTests: XCTestCase { + func testdMatrixMultiplication() throws { + // Int64 + do { + let x = Array2d(data: Array(-3..<3), rowCount: 2, columnCount: 3) + let y = Array2d(data: Array(-6..<6), rowCount: 3, columnCount: 4) + XCTAssertEqual(x.mul(y, modulus: 100), Array2d(data: [[20, 14, 8, 2], [2, 5, 8, 11]])) + // Values in [-floor(modulus/2), floor(modulus-1)/2] + XCTAssertEqual(x.mul(y, modulus: 10), Array2d(data: [[0, 4, -2, 2], [2, -5, -2, 1]])) + } + // Float + do { + let x = Array2d(data: Array(-3..<3).map { Float($0) }, rowCount: 2, columnCount: 3) + let y = Array2d(data: Array(-6..<6).map { Float($0) }, rowCount: 3, columnCount: 4) + XCTAssertEqual(x.mul(y), Array2d(data: [[20.0, 14.0, 8.0, 2.0], [2.0, 5.0, 8.0, 11.0]])) + } + } + + func testFixedPointCosineSimilarity() throws { + let innerDimension = 3 + let x = Array2d(data: Array(-3..<3).map { Float($0) }, rowCount: 2, columnCount: innerDimension) + let y = Array2d(data: Array(-6..<6).map { Float($0) }, rowCount: innerDimension, columnCount: 4) + + let norm = Array2d.Norm.Lp(p: 2.0) + let xNormalized = x.normalizedRows(norm: norm) + let yNormalized = y.transposed().normalizedRows(norm: norm).transposed() + let expected = xNormalized.mul(yNormalized) + + let scalingFactor = 100 + let modulus = UInt32(scalingFactor * scalingFactor * innerDimension + 1) + let z = try x.fixedPointCosineSimilarity(y, modulus: modulus, scalingFactor: Float(scalingFactor)) + + XCTAssertIsClose(fixedPointCosineSimilarityError(innerDimension: 3, scalingFactor: 100), 0.010025) + let absoluteError = fixedPointCosineSimilarityError( + innerDimension: innerDimension, + scalingFactor: scalingFactor) + for (got, expected) in zip(z.data, expected.data) { + XCTAssertIsClose(got, expected, absoluteTolerance: absoluteError) + } + } +}