Skip to content

Commit

Permalink
Adds PNNS server (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
fboemer authored Aug 26, 2024
1 parent d639913 commit 79684bb
Show file tree
Hide file tree
Showing 20 changed files with 879 additions and 219 deletions.
2 changes: 1 addition & 1 deletion Sources/PrivateInformationRetrieval/MulPir.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public enum MulPir<Scheme: HeScheme>: IndexPirProtocol {
static func evaluationKeyConfiguration(
expandedQueryCount: Int,
degree: Int,
keyCompression: PirKeyCompressionStrategy) -> HomomorphicEncryption.EvaluationKeyConfiguration
keyCompression: PirKeyCompressionStrategy) -> EvaluationKeyConfiguration
{
let maxExpansionDepth = min(expandedQueryCount, degree).ceilLog2
let smallestPower = degree.log2 - maxExpansionDepth + 1
Expand Down
48 changes: 44 additions & 4 deletions Sources/PrivateNearestNeighborsSearch/CiphertextMatrix.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import HomomorphicEncryption
/// Stores a matrix of scalars as ciphertexts.
public struct CiphertextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendable {
/// Dimensions of the matrix.
@usableFromInline let dimensions: MatrixDimensions
@usableFromInline var dimensions: MatrixDimensions

/// Dimensions of the scalar matrix in a SIMD-encoded plaintext.
@usableFromInline let simdDimensions: SimdEncodingDimensions
Expand All @@ -26,7 +26,7 @@ public struct CiphertextMatrix<Scheme: HeScheme, Format: PolyFormat>: Equatable,
@usableFromInline let packing: MatrixPacking

/// Encrypted data.
@usableFromInline let ciphertexts: [Ciphertext<Scheme, Format>]
@usableFromInline package var ciphertexts: [Ciphertext<Scheme, Format>]

/// The parameter context.
@usableFromInline var context: Context<Scheme> {
Expand Down Expand Up @@ -111,8 +111,8 @@ extension CiphertextMatrix {
ciphertexts: evalCiphertexts)
}

/// Converts the plaintext matrix to ``Coeff`` format.
/// - Returns: The converted plaintext ciphertext.
/// Converts the ciphertext matrix to ``Coeff`` format.
/// - Returns: The converted ciphertext matrix.
/// - Throws: Error upon failure to convert the ciphertext matrix.
@inlinable
public func convertToCoeffFormat() throws -> CiphertextMatrix<Scheme, Coeff> {
Expand All @@ -126,6 +126,33 @@ extension CiphertextMatrix {
packing: packing,
ciphertexts: coeffCiphertexts)
}

/// Converts the ciphertext matrix to canonical format.
/// - Returns: The converted ciphertext matrix.
/// - Throws: Error upon failure to convert the ciphertext matrix.
@inlinable
public func convertToCanonicalFormat() throws -> CiphertextMatrix<Scheme, Scheme.CanonicalCiphertextFormat> {
if Scheme.CanonicalCiphertextFormat.self == Coeff.self {
// swiftlint:disable:next force_cast
return try convertToCoeffFormat() as! CiphertextMatrix<Scheme, Scheme.CanonicalCiphertextFormat>
}
if Scheme.CanonicalCiphertextFormat.self == Eval.self {
// swiftlint:disable:next force_cast
return try convertToEvalFormat() as! CiphertextMatrix<Scheme, Scheme.CanonicalCiphertextFormat>
}
fatalError("Unsupported Format \(Format.description)")
}

/// Performs modulus switching to a single modulus.
///
/// If the ciphertexts already have a single modulus, this is a no-op.
/// - Throws: Error upon failure to modulus switch.
@inlinable
public mutating func modSwitchDownToSingle() throws where Format == Scheme.CanonicalCiphertextFormat {
for index in 0..<ciphertexts.count {
try ciphertexts[index].modSwitchDownToSingle()
}
}
}

extension CiphertextMatrix {
Expand Down Expand Up @@ -278,4 +305,17 @@ extension CiphertextMatrix {
packing: packing,
ciphertexts: [ciphertext])
}

/// Returns the noise budget.
/// - Parameters:
/// - secretKey: Secret key.
/// - variableTime: If `true`, indicates the secret key coefficients may be leaked through timing.
/// - Returns: The noise budget.
/// - Throws: Error upon failure to compute the noise budget.
@inlinable
func noiseBudget(using secretKey: Scheme.SecretKey, variableTime: Bool) throws -> Double {
try ciphertexts.map { ciphertext in
try ciphertext.noiseBudget(using: secretKey, variableTime: variableTime)
}.min() ?? -Double.infinity
}
}
125 changes: 44 additions & 81 deletions Sources/PrivateNearestNeighborsSearch/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,51 +12,57 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Algorithms
import Foundation
import HomomorphicEncryption

/// Private nearest neighbors client.
struct Client<Scheme: HeScheme> {
public struct Client<Scheme: HeScheme> {
/// Configuration.
let config: ClientConfig<Scheme>
public let config: ClientConfig<Scheme>

/// One context per plaintext modulus.
let contexts: [Context<Scheme>]
@usableFromInline let contexts: [Context<Scheme>]

/// Performs composition of the plaintext CRT responses.
let crtComposer: CrtComposer<Scheme.Scalar>
@usableFromInline let crtComposer: CrtComposer<Scheme.Scalar>

/// Context for the plaintext CRT moduli.
let plaintextContext: PolyContext<Scheme.Scalar>
@usableFromInline let plaintextContext: PolyContext<Scheme.Scalar>

var evaluationKeyConfiguration: HomomorphicEncryption.EvaluationKeyConfiguration {
/// The evaluation key configuration used by the ``Server``.
public var evaluationKeyConfiguration: EvaluationKeyConfiguration {
config.evaluationKeyConfig
}

/// Creates a new ``Client``.
/// - Parameter config: Client configuration.
/// - Throws: Error upon failure to create a new client.
/// - Parameters:
/// - config: Client configuration.
/// - contexts: Contexts for HE computation, one per plaintext modulus.
/// - Throws: Error upon failure to create the client.
@inlinable
init(config: ClientConfig<Scheme>) throws {
public init(config: ClientConfig<Scheme>, contexts: [Context<Scheme>] = []) throws {
guard config.distanceMetric == .cosineSimilarity else {
throw PnnsError.wrongDistanceMetric(got: config.distanceMetric, expected: .cosineSimilarity)
}
self.config = config
let extraEncryptionParams = try config.extraPlaintextModuli.map { plaintextModulus in
try EncryptionParameters<Scheme>(
polyDegree: config.encryptionParams.polyDegree,
plaintextModulus: plaintextModulus,
coefficientModuli: config.encryptionParams.coefficientModuli,
errorStdDev: config.encryptionParams.errorStdDev,
securityLevel: config.encryptionParams.securityLevel)
}
let encryptionParams = [config.encryptionParams] + extraEncryptionParams
self.contexts = try encryptionParams.map { encryptionParams in
try Context(encryptionParameters: encryptionParams)

if !contexts.isEmpty {
precondition(contexts.count == config.encryptionParameters.count)
for (context, encryptionParameters) in zip(contexts, config.encryptionParameters) {
guard context.encryptionParameters == encryptionParameters else {
throw PnnsError.wrongEncryptionParameters(
got: context.encryptionParameters,
expected: encryptionParameters)
}
}
self.contexts = contexts
} else {
self.contexts = try config.encryptionParameters.map { encryptionParams in
try Context(encryptionParameters: encryptionParams)
}
}
self.plaintextContext = try PolyContext(
degree: config.encryptionParams.polyDegree,
degree: config.encryptionParameters[0].polyDegree,
moduli: config.plaintextModuli)
self.crtComposer = try CrtComposer(polyContext: plaintextContext)
}
Expand All @@ -68,17 +74,17 @@ struct Client<Scheme: HeScheme> {
/// - Returns: The query.
/// - Throws: Error upon failure to generate the query.
@inlinable
func generateQuery(vectors: Array2d<Float>, using secretKey: SecretKey<Scheme>) throws -> Query<Scheme> {
let scaledVectors: Array2d<Scheme.SignedScalar> = vectors.normalizedRows(norm: Array2d<Float>.Norm.Lp(p: 2.0))
.scaled(by: Float(config.scalingFactor)).rounded()
let dimensions = try MatrixDimensions(rowCount: vectors.rowCount, columnCount: vectors.columnCount)

public func generateQuery(for vectors: Array2d<Float>,
using secretKey: SecretKey<Scheme>) throws -> Query<Scheme>
{
let scaledVectors: Array2d<Scheme.SignedScalar> = vectors
.normalizedScaledAndRounded(scalingFactor: Float(config.scalingFactor))
let matrices = try contexts.map { context in
// For a single plaintext modulus, reduction isn't necessary
let shouldReduce = contexts.count > 1
let plaintextMatrix = try PlaintextMatrix(
context: context,
dimensions: dimensions,
dimensions: MatrixDimensions(vectors.shape),
packing: config.queryPacking,
signedValues: scaledVectors.data,
reduce: shouldReduce)
Expand All @@ -94,7 +100,7 @@ struct Client<Scheme: HeScheme> {
/// - Returns: The distances from the query vectors to the database rows.
/// - Throws: Error upon failure to decrypt the response.
@inlinable
func decrypt(response: Response<Scheme>, using secretKey: SecretKey<Scheme>) throws -> DatabaseDistances {
public func decrypt(response: Response<Scheme>, using secretKey: SecretKey<Scheme>) throws -> DatabaseDistances {
guard let dimensions = response.ciphertextMatrices.first?.dimensions else {
throw PnnsError.emptyCiphertextArray
}
Expand Down Expand Up @@ -123,65 +129,22 @@ struct Client<Scheme: HeScheme> {
entryMetadatas: response.entryMetadatas)
}

/// Generates a secret key for query encryption and response decryption.
/// - Returns: A freshly generated secret key.
/// - Throws: Error upon failure to generate a secret key.
@inlinable
public func generateSecretKey() throws -> SecretKey<Scheme> {
try contexts[0].generateSecretKey()
}

/// Generates an ``EvaluationKey`` for use in nearest neighbors search.
/// - Parameter secretKey: Secret key used to generate the evaluation key.
/// - Returns: The evaluation key.
/// - Throws: Error upon failure to generate the evaluation key.
/// - Warning: Uses the first context to generate the evaluation key. So either the HE scheme should generate
/// evaluation keys independent of the plaintext modulus (as in BFV), or there should be just one plaintext modulus.
@inlinable
func generateEvaluationKey(using secretKey: SecretKey<Scheme>) throws -> EvaluationKey<Scheme> {
public func generateEvaluationKey(using secretKey: SecretKey<Scheme>) throws -> EvaluationKey<Scheme> {
try contexts[0].generateEvaluationKey(configuration: evaluationKeyConfiguration, using: secretKey)
}
}

extension Array2d where T == Float {
/// A mapping from vectors to non-negative real numbers.
@usableFromInline
enum Norm {
case Lp(p: Float) // sum_i (|x_i|^p)^{1/p}
}

/// Normalizes each row in the matrix.
@inlinable
func normalizedRows(norm: Norm) -> Array2d<Float> {
switch norm {
case let Norm.Lp(p):
let normalizedValues = data.chunks(ofCount: columnCount).flatMap { row in
let sumOfPowers = row.map { pow($0, p) }.reduce(0, +)
let norm = pow(sumOfPowers, 1 / p)
return row.map { value in
if sumOfPowers.isZero {
Float.zero
} else {
value / norm
}
}
}
return Array2d<Float>(
data: normalizedValues,
rowCount: rowCount,
columnCount: columnCount)
}
}

/// Returns the matrix where each entry is rounded to the closest integer.
@inlinable
func rounded<V: FixedWidthInteger & SignedInteger>() -> Array2d<V> {
Array2d<V>(
data: data.map { value in V(value.rounded()) },
rowCount: rowCount,
columnCount: columnCount)
}

/// Returns the matrix where each entry has been multiplied by a scaling factor.
/// - Parameter scalingFactor: The factor to multiply each entry by.
/// - Returns: The scaled matrix.
@inlinable
func scaled(by scalingFactor: Float) -> Array2d<Float> {
Array2d<Float>(
data: data.map { value in value * scalingFactor },
rowCount: rowCount,
columnCount: columnCount)
}
}
52 changes: 20 additions & 32 deletions Sources/PrivateNearestNeighborsSearch/Config.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public enum DistanceMetric: CaseIterable, Codable, Equatable, Hashable, Sendable
/// Client configuration.
public struct ClientConfig<Scheme: HeScheme>: Codable, Equatable, Hashable, Sendable {
/// Encryption parameters.
public let encryptionParams: EncryptionParameters<Scheme>
public let encryptionParameters: [EncryptionParameters<Scheme>]
/// Factor by which to scale floating-point entries before rounding to integers.
public let scalingFactor: Int
/// Packing for the query.
Expand All @@ -40,9 +40,7 @@ public struct ClientConfig<Scheme: HeScheme>: Codable, Equatable, Hashable, Send
public let extraPlaintextModuli: [Scheme.Scalar]

/// The plaintext CRT moduli.
var plaintextModuli: [Scheme.Scalar] {
[encryptionParams.plaintextModulus] + extraPlaintextModuli
}
public var plaintextModuli: [Scheme.Scalar] { encryptionParameters.map(\.plaintextModulus) }

/// Creates a new ``ClientConfig``.
/// - Parameters:
Expand All @@ -54,16 +52,25 @@ public struct ClientConfig<Scheme: HeScheme>: Codable, Equatable, Hashable, Send
/// - distanceMetric: Metric for nearest neighbors computation
/// - extraPlaintextModuli: For plaintext CRT, the list of extra plaintext moduli. The first plaintext modulus
/// will be the one in ``ClientConfig/encryptionParams``.
/// - Throws: Error upon failure to create a new client config.
public init(
encryptionParams: EncryptionParameters<Scheme>,
scalingFactor: Int,
queryPacking: MatrixPacking,
vectorDimension: Int,
evaluationKeyConfig: EvaluationKeyConfiguration,
distanceMetric: DistanceMetric,
extraPlaintextModuli: [Scheme.Scalar] = [])
extraPlaintextModuli: [Scheme.Scalar] = []) throws
{
self.encryptionParams = encryptionParams
let extraEncryptionParams = try extraPlaintextModuli.map { plaintextModulus in
try EncryptionParameters<Scheme>(
polyDegree: encryptionParams.polyDegree,
plaintextModulus: plaintextModulus,
coefficientModuli: encryptionParams.coefficientModuli,
errorStdDev: encryptionParams.errorStdDev,
securityLevel: encryptionParams.securityLevel)
}
self.encryptionParameters = [encryptionParams] + extraEncryptionParams
self.scalingFactor = scalingFactor
self.queryPacking = queryPacking
self.vectorDimension = vectorDimension
Expand All @@ -72,30 +79,15 @@ public struct ClientConfig<Scheme: HeScheme>: Codable, Equatable, Hashable, Send
self.extraPlaintextModuli = extraPlaintextModuli
}

static func maxScalingFactor(vectorDimension: Int, distanceMetric: DistanceMetric,
plaintextModuli: [Scheme.Scalar]) -> Int
@inlinable
package static func maxScalingFactor(distanceMetric: DistanceMetric, vectorDimension: Int,
plaintextModuli: [Scheme.Scalar]) -> Int
{
precondition(distanceMetric == .cosineSimilarity)
let t = plaintextModuli.map { Float($0) }.reduce(1, *)
let scalingFactor = (((t - 1) / 2).squareRoot() - Float(vectorDimension).squareRoot() / 2).rounded(.down)
return Int(scalingFactor)
}

/// Computes the encryption parameters, one per plaintext modulus.
///
/// - Returns: The encryption parameters
/// - Throws: Error upon failure to generate the encryption parameters.
func encryptionParameters() throws -> [EncryptionParameters<Scheme>] {
let extraEncryptionParams = try extraPlaintextModuli.map { plaintextModulus in
try EncryptionParameters<Scheme>(
polyDegree: encryptionParams.polyDegree,
plaintextModulus: plaintextModulus,
coefficientModuli: encryptionParams.coefficientModuli,
errorStdDev: encryptionParams.errorStdDev,
securityLevel: encryptionParams.securityLevel)
}
return [encryptionParams] + extraEncryptionParams
}
}

/// Server configuration.
Expand All @@ -110,6 +102,10 @@ public struct ServerConfig<Scheme: HeScheme>: Codable, Equatable, Hashable, Send
public var plaintextModuli: [Scheme.Scalar] { clientConfig.plaintextModuli }
/// Distance metric.
public var distanceMetric: DistanceMetric { clientConfig.distanceMetric }
/// The encryption parameters, one per plaintext modulus.
public var encryptionParameters: [EncryptionParameters<Scheme>] {
clientConfig.encryptionParameters
}

/// Creates a new ``ServerConfig``.
/// - Parameters:
Expand All @@ -122,12 +118,4 @@ public struct ServerConfig<Scheme: HeScheme>: Codable, Equatable, Hashable, Send
self.clientConfig = clientConfig
self.databasePacking = databasePacking
}

/// Computes the encryption parameters, one per plaintext modulus.
///
/// - Returns: The encryption parameters
/// - Throws: Error upon failure to generate the encryption parameters.
public func encryptionParameters() throws -> [EncryptionParameters<Scheme>] {
try clientConfig.encryptionParameters()
}
}
Loading

0 comments on commit 79684bb

Please sign in to comment.