Skip to content

Commit

Permalink
Make decoding more generic (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
fboemer authored Aug 20, 2024
1 parent 534de8c commit 0877744
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 83 deletions.
14 changes: 6 additions & 8 deletions Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,25 @@ extension Bfv {

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decode<V: ScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] {
public static func decodeCoeff<V: ScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] {
try plaintext.context.decode(plaintext: plaintext, format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decode<V: SignedScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] {
public static func decodeCoeff<V: SignedScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] {
try plaintext.context.decode(plaintext: plaintext, format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decode<V: ScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] {
let coeffPlaintext = try plaintext.convertToCoeffFormat()
return try coeffPlaintext.decode(format: format)
public static func decodeEval<V: ScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] {
try plaintext.convertToCoeffFormat().decode(format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decode<V: SignedScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] {
let coeffPlaintext = try plaintext.convertToCoeffFormat()
return try coeffPlaintext.decode(format: format)
public static func decodeEval<V: SignedScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] {
try plaintext.convertToCoeffFormat().decode(format: format)
}
}
12 changes: 1 addition & 11 deletions Sources/HomomorphicEncryption/Ciphertext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,7 @@ public struct Ciphertext<Scheme: HeScheme, Format: PolyFormat>: Equatable, Senda
/// - seelaso: ``Ciphertext/isTransparent()``
@inlinable
public static func zero(context: Context<Scheme>, moduliCount: Int? = nil) throws -> Ciphertext<Scheme, Format> {
if Format.self == Coeff.self {
let coeffCiphertext = try Scheme.zeroCiphertextCoeff(context: context, moduliCount: moduliCount)
// swiftlint:disable:next force_cast
return coeffCiphertext as! Ciphertext<Scheme, Format>
}
if Format.self == Eval.self {
let evalCiphertext = try Scheme.zeroCiphertextEval(context: context, moduliCount: moduliCount)
// swiftlint:disable:next force_cast
return evalCiphertext as! Ciphertext<Scheme, Format>
}
throw HeError.unsupportedHeOperation()
try Scheme.zero(context: context, moduliCount: moduliCount)
}

// MARK: ciphertext += plaintext
Expand Down
24 changes: 5 additions & 19 deletions Sources/HomomorphicEncryption/Encoding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ extension Context {
/// - Returns: The decoded values.
/// - Throws: Error upon failure to decode.
@inlinable
public func decode<T: ScalarType>(plaintext: Plaintext<Scheme, Coeff>,
format: EncodeFormat) throws -> [T]
func decode<T: ScalarType>(plaintext: Plaintext<Scheme, Coeff>,
format: EncodeFormat) throws -> [T]
{
switch format {
case .coefficient:
Expand All @@ -115,27 +115,13 @@ extension Context {
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode.
@inlinable
public func decode<T: SignedScalarType>(plaintext: Plaintext<Scheme, Coeff>, format: EncodeFormat) throws -> [T] {
func decode<T: SignedScalarType>(plaintext: Plaintext<Scheme, Coeff>, format: EncodeFormat) throws -> [T] {
let unsignedValues: [Scheme.Scalar] = try decode(plaintext: plaintext, format: format)
return unsignedValues.map { value in
T(value.remainderToCentered(modulus: plaintextModulus))
}
}

/// Decodes a plaintext with the given format.
///
/// - Parameters:
/// - plaintext: Plaintext to decode.
/// - format: Format to decode with.
/// - Returns: The decoded values.
/// - Throws: Error upon failure to decode.
@inlinable
public func decode<T: ScalarType>(plaintext: Plaintext<Scheme, Eval>,
format: EncodeFormat) throws -> [T]
{
try Scheme.decode(plaintext: plaintext, format: format)
}

/// Decodes a plaintext with the given format, into signed values.
///
/// - Parameters:
Expand All @@ -144,8 +130,8 @@ extension Context {
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode.
@inlinable
public func decode<T: SignedScalarType>(plaintext: Plaintext<Scheme, Eval>, format: EncodeFormat) throws -> [T] {
try Scheme.decode(plaintext: plaintext, format: format)
func decode<T: SignedScalarType>(plaintext: Plaintext<Scheme, Eval>, format: EncodeFormat) throws -> [T] {
try Scheme.decodeEval(plaintext: plaintext, format: format)
}

@inlinable
Expand Down
104 changes: 96 additions & 8 deletions Sources/HomomorphicEncryption/HeScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ public struct SimdEncodingDimensions: Codable, Equatable, Hashable, Sendable {
public protocol HeScheme {
/// Coefficient type for each polynomial.
associatedtype Scalar: ScalarType
/// Coefficient type for signed encoding/decoding.
typealias SignedScalar = Scalar.SignedScalar

/// Polynomial format for the <doc:/documentation/HomomorphicEncryption/HeScheme/CanonicalCiphertext>.
associatedtype CanonicalCiphertextFormat: PolyFormat
Expand Down Expand Up @@ -223,35 +225,35 @@ public protocol HeScheme {
/// - format: Encoding format of the plaintext.
/// - Returns: The decoded values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-9l5kz`` for an alternative API.
static func decode<T: ScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T]
/// - seealso: ``Plaintext/decode(format:)-i0qm`` for an alternative API.
static func decodeCoeff<T: ScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T]

/// Decodes a plaintext in ``Coeff`` format into signed values.
/// - Parameters:
/// - plaintext: Plaintext to decode.
/// - format: Encoding format of the plaintext.
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-9l5kz`` for an alternative API.
static func decode<T: SignedScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T]
/// - seealso: ``Plaintext/decode(format:)-5081e`` for an alternative API.
static func decodeCoeff<T: SignedScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T]

/// Decodes a plaintext in ``Eval`` format.
/// - Parameters:
/// - plaintext: Plaintext to decode.
/// - format: Encoding format of the plaintext.
/// - Returns: The decoded values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-i9hh`` for an alternative API.
static func decode<T: ScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T]
/// - seealso: ``Plaintext/decode(format:)-i0qm`` for an alternative API.
static func decodeEval<T: ScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T]

/// Decodes a plaintext in ``Eval`` format to signed values.
/// - Parameters:
/// - plaintext: Plaintext to decode.
/// - format: Encoding format of the plaintext.
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-i9hh`` for an alternative API.
static func decode<T: SignedScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T]
/// - seealso: ``Plaintext/decode(format:)-5081e`` for an alternative API.
static func decodeEval<T: SignedScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T]

/// Symmetric secret key encryption of a plaintext.
/// - Parameters:
Expand Down Expand Up @@ -883,6 +885,92 @@ extension HeScheme {
fatalError("Unsupported Format \(Format.description)")
// swiftlint:enable force_cast
}

/// Generates a ciphertext of zeros.
///
/// A zero ciphertext may arise from HE computations, e.g., by subtracting a ciphertext from itself, or multiplying
/// a ciphertext with a zero plaintext.
///
/// - Parameters:
/// - context: Context for HE computation.
/// - moduliCount: Number of moduli in the zero ciphertext. If `nil`, the ciphertext will have the ciphertext
/// context with all the coefficient moduli in `context`.
/// - Returns: A zero ciphertext.
/// - Throws: Error upon failure to encode.
/// - Warning: a zero ciphertext is *transparent*, i.e., everyone can see the the underlying plaintext, zero in
/// this case. Transparency can propagate to ciphertexts operating with transparent ciphertexts, e.g.
/// ```
/// transparentCiphertext * ciphertext = transparentCiphertext
/// transparentCiphertext * plaintext = transparentCiphertext
/// transparentCiphertext + plaintext = transparentCiphertext
/// ```
/// - seelaso: ``Ciphertext/isTransparent()``
@inlinable
public static func zero<Format: PolyFormat>(context: Context<Self>,
moduliCount: Int? = nil) throws -> Ciphertext<Self, Format>
{
if Format.self == Coeff.self {
let coeffCiphertext = try zeroCiphertextCoeff(context: context, moduliCount: moduliCount)
// swiftlint:disable:next force_cast
return coeffCiphertext as! Ciphertext<Self, Format>
}
if Format.self == Eval.self {
let evalCiphertext = try zeroCiphertextEval(context: context, moduliCount: moduliCount)
// swiftlint:disable:next force_cast
return evalCiphertext as! Ciphertext<Self, Format>
}
fatalError("Unsupported Format \(Format.description)")
}

/// Decodes a plaintext.
/// - Parameters:
/// - plaintext: Plaintext to decode.
/// - format: Encoding format of the plaintext.
/// - Returns: The decoded values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-i0qm`` for an alternative API.
@inlinable
public static func decode<T: ScalarType, Format: PolyFormat>(
plaintext: Plaintext<Self, Format>,
format: EncodeFormat) throws -> [T]
{
if Format.self == Coeff.self {
// swiftlint:disable:next force_cast
let coeffPlaintext = plaintext as! CoeffPlaintext
return try decodeCoeff(plaintext: coeffPlaintext, format: format)
}
if Format.self == Eval.self {
// swiftlint:disable:next force_cast
let evalPlaintext = plaintext as! EvalPlaintext
return try decodeEval(plaintext: evalPlaintext, format: format)
}
fatalError("Unsupported Format \(Format.description)")
}

/// Decodes a plaintext to signed values.
/// - Parameters:
/// - plaintext: Plaintext to decode.
/// - format: Encoding format of the plaintext.
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-5081e`` for an alternative API.
@inlinable
public static func decode<T: SignedScalarType, Format: PolyFormat>(
plaintext: Plaintext<Self, Format>,
format: EncodeFormat) throws -> [T]
{
if Format.self == Coeff.self {
// swiftlint:disable:next force_cast
let coeffPlaintext = plaintext as! CoeffPlaintext
return try decodeCoeff(plaintext: coeffPlaintext, format: format)
}
if Format.self == Eval.self {
// swiftlint:disable:next force_cast
let evalPlaintext = plaintext as! EvalPlaintext
return try decodeEval(plaintext: evalPlaintext, format: format)
}
fatalError("Unsupported Format \(Format.description)")
}
}

extension HeScheme {
Expand Down
18 changes: 7 additions & 11 deletions Sources/HomomorphicEncryption/NoOpScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -85,24 +85,20 @@ public enum NoOpScheme: HeScheme {
return try EvalPlaintext(context: context, poly: coeffPlaintext.poly.forwardNtt())
}

public static func decode<T>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] where T: ScalarType {
public static func decodeCoeff<T: ScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] {
try plaintext.context.decode(plaintext: plaintext, format: format)
}

public static func decode<T>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T]
where T: SignedScalarType
{
try plaintext.context.decode(plaintext: plaintext, format: format)
public static func decodeEval<T: ScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] {
try plaintext.inverseNtt().decode(format: format)
}

public static func decode<T>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] where T: ScalarType {
try decode(plaintext: plaintext.inverseNtt(), format: format)
public static func decodeCoeff<T: SignedScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] {
try plaintext.context.decode(plaintext: plaintext, format: format)
}

public static func decode<T>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T]
where T: SignedScalarType
{
try decode(plaintext: plaintext.inverseNtt(), format: format)
public static func decodeEval<T: SignedScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] {
try plaintext.inverseNtt().decode(format: format)
}

public static func zeroCiphertextCoeff(context: Context<Self>, moduliCount _: Int?) throws -> CoeffCiphertext {
Expand Down
28 changes: 4 additions & 24 deletions Sources/HomomorphicEncryption/Plaintext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -174,41 +174,21 @@ extension Plaintext {
return Plaintext<Scheme, Coeff>(context: context, poly: coeffPoly)
}

/// Decodes a plaintext in ``Coeff`` format.
/// Decodes a plaintext.
/// - Parameter format: Encoding format of the plaintext.
/// - Returns: The decoded values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``HeScheme/decode(plaintext:format:)-h6vl`` for an alternative API.
@inlinable
public func decode<T: ScalarType>(format: EncodeFormat) throws -> [T] where Format == Coeff {
public func decode<T: ScalarType>(format: EncodeFormat) throws -> [T] {
try Scheme.decode(plaintext: self, format: format)
}

/// Decodes a plaintext in ``Coeff`` format to signed values.
/// Decodes a plaintext to signed values.
/// - Parameter format: Encoding format of the plaintext.
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode the plaintext.
@inlinable
public func decode<T: SignedScalarType>(format: EncodeFormat) throws -> [T] where Format == Coeff {
try Scheme.decode(plaintext: self, format: format)
}

/// Decodes a plaintext in ``Eval`` format.
/// - Parameter format: Encoding format of the plaintext.
/// - Returns: The decoded values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``HeScheme/decode(plaintext:format:)-663x4`` for an alternative API.
@inlinable
public func decode<T: ScalarType>(format: EncodeFormat) throws -> [T] where Format == Eval {
try Scheme.decode(plaintext: self, format: format)
}

/// Decodes a plaintext in ``Eval`` format to signed values.
/// - Parameter format: Encoding format of the plaintext.
/// - Returns: Error upon failure to decode the plaintext.
/// - Throws: Error upon failure to decode the plaintext.
@inlinable
public func decode<T: SignedScalarType>(format: EncodeFormat) throws -> [T] where Format == Eval {
public func decode<T: SignedScalarType>(format: EncodeFormat) throws -> [T] {
try Scheme.decode(plaintext: self, format: format)
}

Expand Down
11 changes: 9 additions & 2 deletions Tests/HomomorphicEncryptionTests/HeAPITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ class HeAPITests: XCTestCase {
let plaintextCoeffSigned: Plaintext<Scheme, Coeff> = try context.encode(
signedValues: signedData,
format: encodeFormat)
let roundTrip: [Scheme.Scalar.SignedScalar] = try context.decode(
plaintext: plaintextCoeffSigned,
let roundTrip: [Scheme.Scalar.SignedScalar] = try plaintextCoeffSigned.decode(
format: encodeFormat)
XCTAssertEqual(roundTrip, signedData)
case is Eval.Type:
Expand Down Expand Up @@ -1084,3 +1083,11 @@ class HeAPITests: XCTestCase {
try runBfvTests(UInt64.self)
}
}

/// This will compile if Plaintext.decode is generic across PolyFormat.
extension Plaintext {
private func checkDecodeIsGeneric() throws {
let _: [Scheme.Scalar] = try decode(format: .coefficient)
let _: [Scheme.SignedScalar] = try decode(format: .coefficient)
}
}

0 comments on commit 0877744

Please sign in to comment.