From 08777448530937ec04dbb810c9612782ceb00796 Mon Sep 17 00:00:00 2001 From: Fabian Boemer Date: Tue, 20 Aug 2024 15:33:36 -0700 Subject: [PATCH] Make decoding more generic (#68) --- .../Bfv/Bfv+Encode.swift | 14 +-- .../HomomorphicEncryption/Ciphertext.swift | 12 +- Sources/HomomorphicEncryption/Encoding.swift | 24 +--- Sources/HomomorphicEncryption/HeScheme.swift | 104 ++++++++++++++++-- .../HomomorphicEncryption/NoOpScheme.swift | 18 ++- Sources/HomomorphicEncryption/Plaintext.swift | 28 +---- .../HeAPITests.swift | 11 +- 7 files changed, 128 insertions(+), 83 deletions(-) diff --git a/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift b/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift index f2a1b723..9340f07b 100644 --- a/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift +++ b/Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift @@ -63,27 +63,25 @@ extension Bfv { @inlinable // swiftlint:disable:next missing_docs attributes - public static func decode(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] { + public static func decodeCoeff(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] { try plaintext.context.decode(plaintext: plaintext, format: format) } @inlinable // swiftlint:disable:next missing_docs attributes - public static func decode(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] { + public static func decodeCoeff(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] { try plaintext.context.decode(plaintext: plaintext, format: format) } @inlinable // swiftlint:disable:next missing_docs attributes - public static func decode(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] { - let coeffPlaintext = try plaintext.convertToCoeffFormat() - return try coeffPlaintext.decode(format: format) + public static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] { + try plaintext.convertToCoeffFormat().decode(format: format) } @inlinable // swiftlint:disable:next missing_docs attributes - public static func decode(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] { - let coeffPlaintext = try plaintext.convertToCoeffFormat() - return try coeffPlaintext.decode(format: format) + public static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] { + try plaintext.convertToCoeffFormat().decode(format: format) } } diff --git a/Sources/HomomorphicEncryption/Ciphertext.swift b/Sources/HomomorphicEncryption/Ciphertext.swift index e7cec143..dded0f85 100644 --- a/Sources/HomomorphicEncryption/Ciphertext.swift +++ b/Sources/HomomorphicEncryption/Ciphertext.swift @@ -63,17 +63,7 @@ public struct Ciphertext: Equatable, Senda /// - seelaso: ``Ciphertext/isTransparent()`` @inlinable public static func zero(context: Context, moduliCount: Int? = nil) throws -> Ciphertext { - if Format.self == Coeff.self { - let coeffCiphertext = try Scheme.zeroCiphertextCoeff(context: context, moduliCount: moduliCount) - // swiftlint:disable:next force_cast - return coeffCiphertext as! Ciphertext - } - if Format.self == Eval.self { - let evalCiphertext = try Scheme.zeroCiphertextEval(context: context, moduliCount: moduliCount) - // swiftlint:disable:next force_cast - return evalCiphertext as! Ciphertext - } - throw HeError.unsupportedHeOperation() + try Scheme.zero(context: context, moduliCount: moduliCount) } // MARK: ciphertext += plaintext diff --git a/Sources/HomomorphicEncryption/Encoding.swift b/Sources/HomomorphicEncryption/Encoding.swift index 15f8fb17..22e47884 100644 --- a/Sources/HomomorphicEncryption/Encoding.swift +++ b/Sources/HomomorphicEncryption/Encoding.swift @@ -96,8 +96,8 @@ extension Context { /// - Returns: The decoded values. /// - Throws: Error upon failure to decode. @inlinable - public func decode(plaintext: Plaintext, - format: EncodeFormat) throws -> [T] + func decode(plaintext: Plaintext, + format: EncodeFormat) throws -> [T] { switch format { case .coefficient: @@ -115,27 +115,13 @@ extension Context { /// - Returns: The decoded signed values. /// - Throws: Error upon failure to decode. @inlinable - public func decode(plaintext: Plaintext, format: EncodeFormat) throws -> [T] { + func decode(plaintext: Plaintext, format: EncodeFormat) throws -> [T] { let unsignedValues: [Scheme.Scalar] = try decode(plaintext: plaintext, format: format) return unsignedValues.map { value in T(value.remainderToCentered(modulus: plaintextModulus)) } } - /// Decodes a plaintext with the given format. - /// - /// - Parameters: - /// - plaintext: Plaintext to decode. - /// - format: Format to decode with. - /// - Returns: The decoded values. - /// - Throws: Error upon failure to decode. - @inlinable - public func decode(plaintext: Plaintext, - format: EncodeFormat) throws -> [T] - { - try Scheme.decode(plaintext: plaintext, format: format) - } - /// Decodes a plaintext with the given format, into signed values. /// /// - Parameters: @@ -144,8 +130,8 @@ extension Context { /// - Returns: The decoded signed values. /// - Throws: Error upon failure to decode. @inlinable - public func decode(plaintext: Plaintext, format: EncodeFormat) throws -> [T] { - try Scheme.decode(plaintext: plaintext, format: format) + func decode(plaintext: Plaintext, format: EncodeFormat) throws -> [T] { + try Scheme.decodeEval(plaintext: plaintext, format: format) } @inlinable diff --git a/Sources/HomomorphicEncryption/HeScheme.swift b/Sources/HomomorphicEncryption/HeScheme.swift index 56e3d1fa..ec68040d 100644 --- a/Sources/HomomorphicEncryption/HeScheme.swift +++ b/Sources/HomomorphicEncryption/HeScheme.swift @@ -90,6 +90,8 @@ public struct SimdEncodingDimensions: Codable, Equatable, Hashable, Sendable { public protocol HeScheme { /// Coefficient type for each polynomial. associatedtype Scalar: ScalarType + /// Coefficient type for signed encoding/decoding. + typealias SignedScalar = Scalar.SignedScalar /// Polynomial format for the . associatedtype CanonicalCiphertextFormat: PolyFormat @@ -223,8 +225,8 @@ public protocol HeScheme { /// - format: Encoding format of the plaintext. /// - Returns: The decoded values. /// - Throws: Error upon failure to decode the plaintext. - /// - seealso: ``Plaintext/decode(format:)-9l5kz`` for an alternative API. - static func decode(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] + /// - seealso: ``Plaintext/decode(format:)-i0qm`` for an alternative API. + static func decodeCoeff(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] /// Decodes a plaintext in ``Coeff`` format into signed values. /// - Parameters: @@ -232,8 +234,8 @@ public protocol HeScheme { /// - format: Encoding format of the plaintext. /// - Returns: The decoded signed values. /// - Throws: Error upon failure to decode the plaintext. - /// - seealso: ``Plaintext/decode(format:)-9l5kz`` for an alternative API. - static func decode(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] + /// - seealso: ``Plaintext/decode(format:)-5081e`` for an alternative API. + static func decodeCoeff(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] /// Decodes a plaintext in ``Eval`` format. /// - Parameters: @@ -241,8 +243,8 @@ public protocol HeScheme { /// - format: Encoding format of the plaintext. /// - Returns: The decoded values. /// - Throws: Error upon failure to decode the plaintext. - /// - seealso: ``Plaintext/decode(format:)-i9hh`` for an alternative API. - static func decode(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] + /// - seealso: ``Plaintext/decode(format:)-i0qm`` for an alternative API. + static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] /// Decodes a plaintext in ``Eval`` format to signed values. /// - Parameters: @@ -250,8 +252,8 @@ public protocol HeScheme { /// - format: Encoding format of the plaintext. /// - Returns: The decoded signed values. /// - Throws: Error upon failure to decode the plaintext. - /// - seealso: ``Plaintext/decode(format:)-i9hh`` for an alternative API. - static func decode(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] + /// - seealso: ``Plaintext/decode(format:)-5081e`` for an alternative API. + static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] /// Symmetric secret key encryption of a plaintext. /// - Parameters: @@ -883,6 +885,92 @@ extension HeScheme { fatalError("Unsupported Format \(Format.description)") // swiftlint:enable force_cast } + + /// Generates a ciphertext of zeros. + /// + /// A zero ciphertext may arise from HE computations, e.g., by subtracting a ciphertext from itself, or multiplying + /// a ciphertext with a zero plaintext. + /// + /// - Parameters: + /// - context: Context for HE computation. + /// - moduliCount: Number of moduli in the zero ciphertext. If `nil`, the ciphertext will have the ciphertext + /// context with all the coefficient moduli in `context`. + /// - Returns: A zero ciphertext. + /// - Throws: Error upon failure to encode. + /// - Warning: a zero ciphertext is *transparent*, i.e., everyone can see the the underlying plaintext, zero in + /// this case. Transparency can propagate to ciphertexts operating with transparent ciphertexts, e.g. + /// ``` + /// transparentCiphertext * ciphertext = transparentCiphertext + /// transparentCiphertext * plaintext = transparentCiphertext + /// transparentCiphertext + plaintext = transparentCiphertext + /// ``` + /// - seelaso: ``Ciphertext/isTransparent()`` + @inlinable + public static func zero(context: Context, + moduliCount: Int? = nil) throws -> Ciphertext + { + if Format.self == Coeff.self { + let coeffCiphertext = try zeroCiphertextCoeff(context: context, moduliCount: moduliCount) + // swiftlint:disable:next force_cast + return coeffCiphertext as! Ciphertext + } + if Format.self == Eval.self { + let evalCiphertext = try zeroCiphertextEval(context: context, moduliCount: moduliCount) + // swiftlint:disable:next force_cast + return evalCiphertext as! Ciphertext + } + fatalError("Unsupported Format \(Format.description)") + } + + /// Decodes a plaintext. + /// - Parameters: + /// - plaintext: Plaintext to decode. + /// - format: Encoding format of the plaintext. + /// - Returns: The decoded values. + /// - Throws: Error upon failure to decode the plaintext. + /// - seealso: ``Plaintext/decode(format:)-i0qm`` for an alternative API. + @inlinable + public static func decode( + plaintext: Plaintext, + format: EncodeFormat) throws -> [T] + { + if Format.self == Coeff.self { + // swiftlint:disable:next force_cast + let coeffPlaintext = plaintext as! CoeffPlaintext + return try decodeCoeff(plaintext: coeffPlaintext, format: format) + } + if Format.self == Eval.self { + // swiftlint:disable:next force_cast + let evalPlaintext = plaintext as! EvalPlaintext + return try decodeEval(plaintext: evalPlaintext, format: format) + } + fatalError("Unsupported Format \(Format.description)") + } + + /// Decodes a plaintext to signed values. + /// - Parameters: + /// - plaintext: Plaintext to decode. + /// - format: Encoding format of the plaintext. + /// - Returns: The decoded signed values. + /// - Throws: Error upon failure to decode the plaintext. + /// - seealso: ``Plaintext/decode(format:)-5081e`` for an alternative API. + @inlinable + public static func decode( + plaintext: Plaintext, + format: EncodeFormat) throws -> [T] + { + if Format.self == Coeff.self { + // swiftlint:disable:next force_cast + let coeffPlaintext = plaintext as! CoeffPlaintext + return try decodeCoeff(plaintext: coeffPlaintext, format: format) + } + if Format.self == Eval.self { + // swiftlint:disable:next force_cast + let evalPlaintext = plaintext as! EvalPlaintext + return try decodeEval(plaintext: evalPlaintext, format: format) + } + fatalError("Unsupported Format \(Format.description)") + } } extension HeScheme { diff --git a/Sources/HomomorphicEncryption/NoOpScheme.swift b/Sources/HomomorphicEncryption/NoOpScheme.swift index 227eff00..5fa98d3e 100644 --- a/Sources/HomomorphicEncryption/NoOpScheme.swift +++ b/Sources/HomomorphicEncryption/NoOpScheme.swift @@ -85,24 +85,20 @@ public enum NoOpScheme: HeScheme { return try EvalPlaintext(context: context, poly: coeffPlaintext.poly.forwardNtt()) } - public static func decode(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] where T: ScalarType { + public static func decodeCoeff(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] { try plaintext.context.decode(plaintext: plaintext, format: format) } - public static func decode(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] - where T: SignedScalarType - { - try plaintext.context.decode(plaintext: plaintext, format: format) + public static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] { + try plaintext.inverseNtt().decode(format: format) } - public static func decode(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] where T: ScalarType { - try decode(plaintext: plaintext.inverseNtt(), format: format) + public static func decodeCoeff(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] { + try plaintext.context.decode(plaintext: plaintext, format: format) } - public static func decode(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] - where T: SignedScalarType - { - try decode(plaintext: plaintext.inverseNtt(), format: format) + public static func decodeEval(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] { + try plaintext.inverseNtt().decode(format: format) } public static func zeroCiphertextCoeff(context: Context, moduliCount _: Int?) throws -> CoeffCiphertext { diff --git a/Sources/HomomorphicEncryption/Plaintext.swift b/Sources/HomomorphicEncryption/Plaintext.swift index d8056af9..0fade8f4 100644 --- a/Sources/HomomorphicEncryption/Plaintext.swift +++ b/Sources/HomomorphicEncryption/Plaintext.swift @@ -174,41 +174,21 @@ extension Plaintext { return Plaintext(context: context, poly: coeffPoly) } - /// Decodes a plaintext in ``Coeff`` format. + /// Decodes a plaintext. /// - Parameter format: Encoding format of the plaintext. /// - Returns: The decoded values. /// - Throws: Error upon failure to decode the plaintext. - /// - seealso: ``HeScheme/decode(plaintext:format:)-h6vl`` for an alternative API. @inlinable - public func decode(format: EncodeFormat) throws -> [T] where Format == Coeff { + public func decode(format: EncodeFormat) throws -> [T] { try Scheme.decode(plaintext: self, format: format) } - /// Decodes a plaintext in ``Coeff`` format to signed values. + /// Decodes a plaintext to signed values. /// - Parameter format: Encoding format of the plaintext. /// - Returns: The decoded signed values. /// - Throws: Error upon failure to decode the plaintext. @inlinable - public func decode(format: EncodeFormat) throws -> [T] where Format == Coeff { - try Scheme.decode(plaintext: self, format: format) - } - - /// Decodes a plaintext in ``Eval`` format. - /// - Parameter format: Encoding format of the plaintext. - /// - Returns: The decoded values. - /// - Throws: Error upon failure to decode the plaintext. - /// - seealso: ``HeScheme/decode(plaintext:format:)-663x4`` for an alternative API. - @inlinable - public func decode(format: EncodeFormat) throws -> [T] where Format == Eval { - try Scheme.decode(plaintext: self, format: format) - } - - /// Decodes a plaintext in ``Eval`` format to signed values. - /// - Parameter format: Encoding format of the plaintext. - /// - Returns: Error upon failure to decode the plaintext. - /// - Throws: Error upon failure to decode the plaintext. - @inlinable - public func decode(format: EncodeFormat) throws -> [T] where Format == Eval { + public func decode(format: EncodeFormat) throws -> [T] { try Scheme.decode(plaintext: self, format: format) } diff --git a/Tests/HomomorphicEncryptionTests/HeAPITests.swift b/Tests/HomomorphicEncryptionTests/HeAPITests.swift index da18f570..4b6de5de 100644 --- a/Tests/HomomorphicEncryptionTests/HeAPITests.swift +++ b/Tests/HomomorphicEncryptionTests/HeAPITests.swift @@ -157,8 +157,7 @@ class HeAPITests: XCTestCase { let plaintextCoeffSigned: Plaintext = try context.encode( signedValues: signedData, format: encodeFormat) - let roundTrip: [Scheme.Scalar.SignedScalar] = try context.decode( - plaintext: plaintextCoeffSigned, + let roundTrip: [Scheme.Scalar.SignedScalar] = try plaintextCoeffSigned.decode( format: encodeFormat) XCTAssertEqual(roundTrip, signedData) case is Eval.Type: @@ -1084,3 +1083,11 @@ class HeAPITests: XCTestCase { try runBfvTests(UInt64.self) } } + +/// This will compile if Plaintext.decode is generic across PolyFormat. +extension Plaintext { + private func checkDecodeIsGeneric() throws { + let _: [Scheme.Scalar] = try decode(format: .coefficient) + let _: [Scheme.SignedScalar] = try decode(format: .coefficient) + } +}