Skip to content

Commit

Permalink
Adding Signed Encoding and Decoding (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
akshaywadia authored Aug 20, 2024
1 parent 1a3cca2 commit 534de8c
Show file tree
Hide file tree
Showing 9 changed files with 366 additions and 8 deletions.
33 changes: 33 additions & 0 deletions Sources/HomomorphicEncryption/Bfv/Bfv+Encode.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ extension Bfv {
try context.encode(values: values, format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func encode(context: Context<Bfv<T>>, signedValues: [some SignedScalarType],
format: EncodeFormat) throws -> CoeffPlaintext
{
try context.encode(signedValues: signedValues, format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func encode(context: Context<Bfv<T>>, values: [some ScalarType], format: EncodeFormat,
Expand All @@ -41,16 +49,41 @@ extension Bfv {
return try coeffPlaintext.convertToEvalFormat(moduliCount: moduliCount)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func encode(
context: Context<Bfv<T>>,
signedValues: [some SignedScalarType],
format: EncodeFormat,
moduliCount: Int?) throws -> EvalPlaintext
{
let coeffPlaintext = try Self.encode(context: context, signedValues: signedValues, format: format)
return try coeffPlaintext.convertToEvalFormat(moduliCount: moduliCount)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decode<V: ScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] {
try plaintext.context.decode(plaintext: plaintext, format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decode<V: SignedScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [V] {
try plaintext.context.decode(plaintext: plaintext, format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decode<V: ScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] {
let coeffPlaintext = try plaintext.convertToCoeffFormat()
return try coeffPlaintext.decode(format: format)
}

@inlinable
// swiftlint:disable:next missing_docs attributes
public static func decode<V: SignedScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [V] {
let coeffPlaintext = try plaintext.convertToCoeffFormat()
return try coeffPlaintext.decode(format: format)
}
}
66 changes: 64 additions & 2 deletions Sources/HomomorphicEncryption/Encoding.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,27 @@ extension Context {
}
}

/// Encodes `signedValues` in the given format.
///
/// Encoding will use the top-level ciphertext context with all moduli.
/// - Parameters:
/// - signedValues: Signed values to encode.
/// - format: Encoding format.
/// - Returns: The plaintext encoding `signedValues`.
/// - Throws: Error upon failure to encode.
@inlinable
public func encode(signedValues: [some SignedScalarType], format: EncodeFormat) throws -> Plaintext<Scheme, Coeff> {
let signedModulus = Scheme.Scalar.SignedScalar(plaintextModulus)
let bounds = -(signedModulus >> 1)...((signedModulus - 1) >> 1)
let centeredValues = try signedValues.map { value in
guard bounds.contains(Scheme.Scalar.SignedScalar(value)) else {
throw HeError.encodingDataOutOfBounds(for: bounds)
}
return try Scheme.Scalar(value.centeredToRemainder(modulus: plaintextModulus))
}
return try encode(values: centeredValues, format: format)
}

/// Encodes `values` in the given format.
/// - Parameters:
/// - values: Values to encode.
Expand All @@ -52,6 +73,21 @@ extension Context {
try Scheme.encode(context: self, values: values, format: format, moduliCount: moduliCount)
}

/// Encodes `signedValues` in the given format.
/// - Parameters:
/// - signedValues: Signed values to encode.
/// - format: Encoding format.
/// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the
/// moduli.
/// - Returns: The plaintext encoding `signedValues`.
/// - Throws: Error upon failure to encode.
@inlinable
public func encode(signedValues: [some SignedScalarType], format: EncodeFormat,
moduliCount: Int? = nil) throws -> Plaintext<Scheme, Eval>
{
try Scheme.encode(context: self, signedValues: signedValues, format: format, moduliCount: moduliCount)
}

/// Decodes a plaintext with the given format.
///
/// - Parameters:
Expand All @@ -71,6 +107,21 @@ extension Context {
}
}

/// Decodes a plaintext with the given format, into signed values.
///
/// - Parameters:
/// - plaintext: Plaintext to decode.
/// - format: Format to decode with.
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode.
@inlinable
public func decode<T: SignedScalarType>(plaintext: Plaintext<Scheme, Coeff>, format: EncodeFormat) throws -> [T] {
let unsignedValues: [Scheme.Scalar] = try decode(plaintext: plaintext, format: format)
return unsignedValues.map { value in
T(value.remainderToCentered(modulus: plaintextModulus))
}
}

/// Decodes a plaintext with the given format.
///
/// - Parameters:
Expand All @@ -85,15 +136,26 @@ extension Context {
try Scheme.decode(plaintext: plaintext, format: format)
}

/// Decodes a plaintext with the given format, into signed values.
///
/// - Parameters:
/// - plaintext: Plaintext to decode.
/// - format: Format to decode with.
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode.
@inlinable
public func decode<T: SignedScalarType>(plaintext: Plaintext<Scheme, Eval>, format: EncodeFormat) throws -> [T] {
try Scheme.decode(plaintext: plaintext, format: format)
}

@inlinable
func validDataForEncoding(values: [some ScalarType]) throws {
guard values.count <= encryptionParameters.polyDegree else {
throw HeError.encodingDataCountExceedsLimit(count: values.count, limit: encryptionParameters.polyDegree)
}
for value in values {
guard value < encryptionParameters.plaintextModulus else {
throw HeError.encodingDataExceedsLimit(
limit: Int(encryptionParameters.plaintextModulus))
throw HeError.encodingDataOutOfBounds(for: 0..<encryptionParameters.plaintextModulus)
}
}
}
Expand Down
17 changes: 14 additions & 3 deletions Sources/HomomorphicEncryption/Error.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ public enum HeError: Error, Equatable {
case coprimeModuli(moduli: [Int64])
case emptyModulus
case encodingDataCountExceedsLimit(count: Int, limit: Int)
case encodingDataExceedsLimit(limit: Int)
/// The actual encoding data might be sensitive, so we omit it.
case encodingDataOutOfBounds(_ closedRange: ClosedRange<Int64>)
case errorCastingPolyFormat(_ description: String)
case incompatibleCiphertextAndPlaintext(_ description: String)
case incompatibleCiphertextCount(_ description: String)
Expand Down Expand Up @@ -52,6 +53,16 @@ public enum HeError: Error, Equatable {
}

extension HeError {
@inlinable
static func encodingDataOutOfBounds(for bounds: ClosedRange<some SignedScalarType>) -> Self {
.encodingDataOutOfBounds(Int64(bounds.lowerBound)...Int64(bounds.upperBound))
}

@inlinable
static func encodingDataOutOfBounds(for bounds: Range<some ScalarType>) -> Self {
.encodingDataOutOfBounds(Int64(bounds.lowerBound)...(Int64(bounds.upperBound) - 1))
}

@inlinable
static func errorCastingPolyFormat(from t1: (some PolyFormat).Type, to t2: (some PolyFormat).Type) -> Self {
.errorCastingPolyFormat("Error casting poly format from: \(t1.description) to: \(t2.description)")
Expand Down Expand Up @@ -177,8 +188,8 @@ extension HeError: LocalizedError {
"Empty modulus"
case let .encodingDataCountExceedsLimit(count, limit):
"Actual number of data \(count) exceeds limit \(limit)"
case let .encodingDataExceedsLimit(limit):
"Actual data exceeds limit \(limit)"
case let .encodingDataOutOfBounds(closedRange):
"Values not in encoding bounds \(closedRange)"
case let .errorCastingPolyFormat(description):
"\(description) "
case let .incompatibleCiphertextCount(description):
Expand Down
49 changes: 49 additions & 0 deletions Sources/HomomorphicEncryption/HeScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,22 @@ public protocol HeScheme {
/// - Returns: A plaintext encoding `values`.
/// - Throws: Error upon failure to encode.
/// - seealso: ``Context/encode(values:format:)`` for an alternative API.
/// - seealso: ``HeScheme/encode(context:signedValues:format:)`` to encode signed values.
static func encode(context: Context<Self>, values: [some ScalarType], format: EncodeFormat) throws -> CoeffPlaintext

/// Encodes signed values into a plaintext with coefficient format.
///
/// - Parameters:
/// - context: Context for HE computation.
/// - signedValues: Signed values to encode.
/// - format: Encoding format.
/// - Returns: A plaintext encoding `signedValues`.
/// - Throws: Error upon failure to encode.
/// - seealso: ``Context/encode(signedValues:format:)`` for an alternative API.
/// - seealso: ``HeScheme/encode(context:values:format)`` to encode unsigned values.
static func encode(context: Context<Self>, signedValues: [some SignedScalarType], format: EncodeFormat) throws
-> CoeffPlaintext

/// Encodes values into a plaintext with evaluation format.
///
/// The encoded plaintext will have a ``Plaintext/polyContext()`` with the `moduliCount` first ciphertext moduli.
Expand All @@ -183,9 +197,26 @@ public protocol HeScheme {
/// - Returns: A plaintext encoding `values`.
/// - Throws: Error upon failure to encode.
/// - seealso: ``Context/encode(values:format:moduliCount:)`` for an alternative API.
/// - seealso: ``HeScheme/encode(context:signedValues:format:moduliCount:)`` to encode signed values.
static func encode(context: Context<Self>, values: [some ScalarType], format: EncodeFormat,
moduliCount: Int?) throws -> EvalPlaintext

/// Encodes signed values into a plaintext with evaluation format.
///
/// The encoded plaintext will have a ``Plaintext/polyContext()`` with the `moduliCount` first ciphertext moduli.
/// - Parameters:
/// - context: Context for HE computation.
/// - signedValues: Signed values to encode.
/// - format: Encoding format.
/// - moduliCount: Optional number of moduli. If not set, encoding will use the top-level ciphertext with all the
/// moduli.
/// - Returns: A plaintext encoding `signedValues`.
/// - Throws: Error upon failure to encode.
/// - seealso: ``Context/encode(signedValues:format:moduliCount:)`` for an alternative API.
/// - seealso: ``HeScheme/encode(context:values:format:moduliCount:)`` to encode unsigned values.
static func encode(context: Context<Self>, signedValues: [some SignedScalarType], format: EncodeFormat,
moduliCount: Int?) throws -> EvalPlaintext

/// Decodes a plaintext in ``Coeff`` format.
/// - Parameters:
/// - plaintext: Plaintext to decode.
Expand All @@ -195,6 +226,15 @@ public protocol HeScheme {
/// - seealso: ``Plaintext/decode(format:)-9l5kz`` for an alternative API.
static func decode<T: ScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T]

/// Decodes a plaintext in ``Coeff`` format into signed values.
/// - Parameters:
/// - plaintext: Plaintext to decode.
/// - format: Encoding format of the plaintext.
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-9l5kz`` for an alternative API.
static func decode<T: SignedScalarType>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T]

/// Decodes a plaintext in ``Eval`` format.
/// - Parameters:
/// - plaintext: Plaintext to decode.
Expand All @@ -204,6 +244,15 @@ public protocol HeScheme {
/// - seealso: ``Plaintext/decode(format:)-i9hh`` for an alternative API.
static func decode<T: ScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T]

/// Decodes a plaintext in ``Eval`` format to signed values.
/// - Parameters:
/// - plaintext: Plaintext to decode.
/// - format: Encoding format of the plaintext.
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode the plaintext.
/// - seealso: ``Plaintext/decode(format:)-i9hh`` for an alternative API.
static func decode<T: SignedScalarType>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T]

/// Symmetric secret key encryption of a plaintext.
/// - Parameters:
/// - plaintext: Plaintext to encrypt.
Expand Down
28 changes: 28 additions & 0 deletions Sources/HomomorphicEncryption/NoOpScheme.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,21 +62,49 @@ public enum NoOpScheme: HeScheme {
try context.encode(values: values, format: format)
}

public static func encode(context: Context<NoOpScheme>, signedValues: [some SignedScalarType],
format: EncodeFormat) throws -> CoeffPlaintext
{
try context.encode(signedValues: signedValues, format: format)
}

public static func encode(context: Context<NoOpScheme>, values: [some ScalarType],
format: EncodeFormat, moduliCount _: Int?) throws -> EvalPlaintext
{
let coeffPlaintext = try Self.encode(context: context, values: values, format: format)
return try EvalPlaintext(context: context, poly: coeffPlaintext.poly.forwardNtt())
}

public static func encode(
context: Context<NoOpScheme>,
signedValues: [some SignedScalarType],
format: EncodeFormat,
moduliCount _: Int?) throws -> EvalPlaintext
{
let coeffPlaintext = try Self.encode(context: context, signedValues: signedValues, format: format)
return try EvalPlaintext(context: context, poly: coeffPlaintext.poly.forwardNtt())
}

public static func decode<T>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T] where T: ScalarType {
try plaintext.context.decode(plaintext: plaintext, format: format)
}

public static func decode<T>(plaintext: CoeffPlaintext, format: EncodeFormat) throws -> [T]
where T: SignedScalarType
{
try plaintext.context.decode(plaintext: plaintext, format: format)
}

public static func decode<T>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T] where T: ScalarType {
try decode(plaintext: plaintext.inverseNtt(), format: format)
}

public static func decode<T>(plaintext: EvalPlaintext, format: EncodeFormat) throws -> [T]
where T: SignedScalarType
{
try decode(plaintext: plaintext.inverseNtt(), format: format)
}

public static func zeroCiphertextCoeff(context: Context<Self>, moduliCount _: Int?) throws -> CoeffCiphertext {
NoOpScheme
.CoeffCiphertext(
Expand Down
18 changes: 18 additions & 0 deletions Sources/HomomorphicEncryption/Plaintext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,15 @@ extension Plaintext {
try Scheme.decode(plaintext: self, format: format)
}

/// Decodes a plaintext in ``Coeff`` format to signed values.
/// - Parameter format: Encoding format of the plaintext.
/// - Returns: The decoded signed values.
/// - Throws: Error upon failure to decode the plaintext.
@inlinable
public func decode<T: SignedScalarType>(format: EncodeFormat) throws -> [T] where Format == Coeff {
try Scheme.decode(plaintext: self, format: format)
}

/// Decodes a plaintext in ``Eval`` format.
/// - Parameter format: Encoding format of the plaintext.
/// - Returns: The decoded values.
Expand All @@ -194,6 +203,15 @@ extension Plaintext {
try Scheme.decode(plaintext: self, format: format)
}

/// Decodes a plaintext in ``Eval`` format to signed values.
/// - Parameter format: Encoding format of the plaintext.
/// - Returns: Error upon failure to decode the plaintext.
/// - Throws: Error upon failure to decode the plaintext.
@inlinable
public func decode<T: SignedScalarType>(format: EncodeFormat) throws -> [T] where Format == Eval {
try Scheme.decode(plaintext: self, format: format)
}

/// Symmetric secret key encryption of the plaintext.
/// - Parameter secretKey: Secret key to encrypt with.
/// - Returns: A ciphertext encrypting the plaintext.
Expand Down
Loading

0 comments on commit 534de8c

Please sign in to comment.