Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sharding configurable #133

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Sources/HomomorphicEncryption/Keys.swift
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ extension Sequence<EvaluationKeyConfig> {
///
/// > Note: The union can be used to generate an `EvaluationKey` which supports the HE operations of any of the
/// evaluation key configurations.
/// - Returns: The joint evaluation configuration
/// - Returns: The joint evaluation configuration.
public func union() -> EvaluationKeyConfig {
var galoisElements: Set<Int> = []
var hasRelinearizationKey = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,45 @@ leaked to the server. Leakage is determined by the universe size divided by the
number of shards. For example, a universe size of 1 million keywords with two
shards means 500k keywords map to each shard.

#### Sharding function
By default we use a sharding function that look like `truncate(SHA256(keyword)) % shardCount`. However, there are cases,
when you have two or more datasets that all use the same keyword. As an example, consider a database like:

ID | Name | Portrait
-- | ---- | --------
1 | Abe | <3KB blob>
2 | Eva | <5kb blob>
...| ... | ...

Depending on the situation, one might want to query only specific columns. So, you transform this into two PIR datasets:
- ID -> Name
- ID -> Portrait

When both `Name` and `Portrait` columns are required, two PIR requests with the same `ID` are made. A curious server
could associate the requests based on timing and see two shardIndexes calculated from the same `ID`. When the shard
sizes differ, this leaks more information about the `ID` than individual shard sizes suggest.

Let’s assume the universe size for `ID` is 100K. The mapping from `ID` to `Name` is sharded into 10 shards, and the
mapping from `ID` to `Portrait` is sharded into 57 shards. When 100K IDs are divided into 10 shards, knowing which shard
an ID belongs to narrows the possible candidates to 10K. Similarly, for 57 shards, identifying the specific shard
narrows the potential candidates to about 1,755. If shards for both mappings are known, the number of remaining
candidates is reduced to about 176.

Knowing the shard for both mappings significantly narrows down the possible IDs.

To avoid this leakage, we use sharding based on the number of shards in other use case. We call this sharding function
`doubleMod` and it is defined as: `(truncate(SHA256(keyword)) % otherShardCount) % shardCount`. For example, in the `ID
-> Name` mapping, we’d use `doubleMod`: `shard_name = (truncate(SHA256(keyword)) % 57) % 10 = shard_portrait % 10`.
Knowing both `shard_name` and `shard_portrait` doesn’t provide extra information to the server anymore.

To use the `doubleMod` sharding function, add the following to the configuration file. (This example assumes that the
other usecase has 57 shards).
```json
"shardingFunction" : {
"doubleMod" : 57
}
```

#### Symmetric PIR
Some PIR algorithms, such as MulPir, include an optimization which returns multiple keyword-value pairs in the PIR
response, beyond the keyword-value pair requested by the client. However, this may be undesirable, e.g., if the database
Expand Down
25 changes: 12 additions & 13 deletions Sources/PIRProcessDatabase/ProcessDatabase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,6 @@ import Logging
import PrivateInformationRetrieval
import PrivateInformationRetrievalProtobuf

/// Creates a new `KeywordDatabase` from a given path.
/// - Parameters:
/// - path: The path to the `KeywordDatabase` file.
/// - sharding: The sharding strategy to use.
extension KeywordDatabase {
init(from path: String, sharding: Sharding) throws {
let database = try Apple_SwiftHomomorphicEncryption_Pir_V1_KeywordDatabase(from: path)
try self.init(rows: database.native(), sharding: sharding)
}
}

/// The different table sizes that can be used for the PIR database.
enum TableSizeOption: Codable, Equatable, Hashable {
/// An `allowExpansion` option allows the database to grow as needed.
Expand Down Expand Up @@ -134,6 +123,7 @@ struct Arguments: Codable, Equatable, Hashable, Sendable {
let rlweParameters: PredefinedRlweParameters
let outputEvaluationKeyConfig: String?
var sharding: Sharding?
var shardingFunction: ShardingFunction?
var cuckooTableArguments: CuckooTableArguments?
var algorithm: PirAlgorithm?
var keyCompression: PirKeyCompressionStrategy?
Expand Down Expand Up @@ -168,6 +158,7 @@ struct Arguments: Codable, Equatable, Hashable, Sendable {
rlweParameters: resolved.rlweParameters,
outputEvaluationKeyConfig: resolved.outputEvaluationKeyConfig,
sharding: resolved.sharding,
shardingFunction: resolved.shardingFunction,
cuckooTableArguments: cuckooTableArguments,
algorithm: resolved.algorithm,
keyCompression: PirKeyCompressionStrategy.noCompression,
Expand Down Expand Up @@ -212,6 +203,7 @@ struct Arguments: Codable, Equatable, Hashable, Sendable {
outputPirParameters: outputPirParameters,
outputEvaluationKeyConfig: outputEvaluationKeyConfig,
sharding: sharding ?? Sharding.shardCount(1),
shardingFunction: shardingFunction ?? .sha256,
cuckooTableConfig: cuckooTableConfig,
rlweParameters: rlweParameters,
algorithm: algorithm ?? .mulPir,
Expand All @@ -228,6 +220,7 @@ struct ResolvedArguments: CustomStringConvertible, Encodable {
let outputPirParameters: String
let outputEvaluationKeyConfig: String?
let sharding: Sharding
let shardingFunction: ShardingFunction
let cuckooTableConfig: CuckooTableConfig
let rlweParameters: PredefinedRlweParameters
let algorithm: PirAlgorithm
Expand Down Expand Up @@ -260,6 +253,7 @@ struct ResolvedArguments: CustomStringConvertible, Encodable {
outputPirParameters: String,
outputEvaluationKeyConfig: String?,
sharding: Sharding,
shardingFunction: ShardingFunction,
cuckooTableConfig: CuckooTableConfig,
rlweParameters: PredefinedRlweParameters,
algorithm: PirAlgorithm,
Expand All @@ -272,6 +266,7 @@ struct ResolvedArguments: CustomStringConvertible, Encodable {
self.outputPirParameters = outputPirParameters
self.outputEvaluationKeyConfig = outputEvaluationKeyConfig
self.sharding = sharding
self.shardingFunction = shardingFunction
self.cuckooTableConfig = cuckooTableConfig
self.rlweParameters = rlweParameters
self.algorithm = algorithm
Expand Down Expand Up @@ -332,7 +327,8 @@ struct ProcessDatabase: AsyncParsableCommand {
cuckooTableConfig: config.cuckooTableConfig,
unevenDimensions: true,
keyCompression: config.keyCompression,
useMaxSerializedBucketSize: config.useMaxSerializedBucketSize)
useMaxSerializedBucketSize: config.useMaxSerializedBucketSize,
shardingFunction: config.shardingFunction)
let databaseConfig = KeywordDatabaseConfig(
sharding: config.sharding,
keywordPirConfig: keywordConfig)
Expand All @@ -345,7 +341,10 @@ struct ProcessDatabase: AsyncParsableCommand {
trialsPerShard: config.trialsPerShard)

let context = try Context(encryptionParameters: processArgs.encryptionParameters)
let keywordDatabase = try KeywordDatabase(rows: database, sharding: processArgs.databaseConfig.sharding)
let keywordDatabase = try KeywordDatabase(
rows: database,
sharding: processArgs.databaseConfig.sharding,
shardingFunction: config.shardingFunction)
ProcessDatabase.logger.info("Sharded database into \(keywordDatabase.shards.count) shards")

let shards = keywordDatabase.shards.sorted { $0.0.localizedStandardCompare($1.0) == .orderedAscending }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,15 @@ rows {
```
This will generate `floor(100/15) = 6` shards, saved to `database-entry-count-0.txtpb` through `database-entry-count-5.txtpb`.

4. To configure the sharding function one can use the `sharding-function` option. If using the `doubleMod` sharding function, one also has to specify `other-shard-count`. An example for using `doubleMod` follows:
```sh
PIRShardDatabase \
--input-database database.txtpb \
--output-database database-shard-SHARD_ID.txtpb \
--sharding shardCount \
--sharding-count 5 \
--sharding-function doubleMod \
--other-shard-count 10
```

> Note: For a more compact format, use the `.binpb` extension to load the input database, and save the sharded databases in protocol buffer binary format.
27 changes: 26 additions & 1 deletion Sources/PIRShardDatabase/ShardDatabase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,20 @@ enum ShardingOption: String, CaseIterable, ExpressibleByArgument {
case shardCount
}

enum ShardingFunctionOption: String, CaseIterable, ExpressibleByArgument {
case doubleMod
case sha256
}

struct ShardingArguments: ParsableArguments {
@Option var sharding: ShardingOption
@Option(help: "A positive integer")
var shardingCount: Int

@Option var shardingFunction: ShardingFunctionOption = .sha256

@Option(help: "Shards in the other usecase")
var otherShardCount: Int?
}

extension Sharding {
Expand All @@ -44,6 +54,20 @@ extension Sharding {
}
}

extension ShardingFunction {
init(from arguments: ShardingArguments) throws {
switch arguments.shardingFunction {
case .doubleMod:
guard let otherShardCount = arguments.otherShardCount else {
throw ValidationError("Must specify 'otherShardCount' when using 'doubleMod' sharding function.")
}
self = .doubleMod(otherShardCount: otherShardCount)
case .sha256:
self = .sha256
}
}
}

extension String {
func validateProtoFilename(descriptor: String) throws {
guard hasSuffix(".txtpb") || hasSuffix(".binpb") else {
Expand Down Expand Up @@ -86,9 +110,10 @@ struct ProcessCommand: ParsableCommand {
guard let sharding = Sharding(from: sharding) else {
throw ValidationError("Invalid sharding \(sharding)")
}
let shardingFunction = try ShardingFunction(from: self.sharding)
let database: [KeywordValuePair] =
try Apple_SwiftHomomorphicEncryption_Pir_V1_KeywordDatabase(from: inputDatabase).native()
let sharded = try KeywordDatabase(rows: database, sharding: sharding)
let sharded = try KeywordDatabase(rows: database, sharding: sharding, shardingFunction: shardingFunction)
for (shardID, shard) in sharded.shards {
let outputDatabaseFilename = outputDatabase.replacingOccurrences(
of: "SHARD_ID",
Expand Down
100 changes: 97 additions & 3 deletions Sources/PrivateInformationRetrieval/KeywordDatabase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,92 @@ extension KeywordValuePair.Keyword {
}
}

/// Sharding function that determines the shard a keyword should be in.
public struct ShardingFunction: Hashable, Sendable {
/// Internal enumeration with supported cases.
@usableFromInline
package enum Internal: Hashable, Sendable {
case sha256
case doubleMod(otherShardCount: Int)
}

/// SHA256 based sharding.
///
/// The shard is determined by `truncate(SHA256(keyword)) % shardCount`.
public static let sha256: Self = .init(.sha256)

/// Internal representation.
@usableFromInline package var function: Internal

init(_ function: Internal) {
self.function = function
}

/// Sharding is dependent on another usecase.
///
/// The shard is determined by `(truncate(SHA256(keyword)) % otherShardCount) % shardCount`.
/// - Parameter otherShardCount: Number of shards in the other usecase.
/// - Returns: Sharding function that depends also on another usecase.
public static func doubleMod(otherShardCount: Int) -> Self {
.init(.doubleMod(otherShardCount: otherShardCount))
}
}

extension ShardingFunction {
/// Compute the shard index for keyword.
/// - Parameters:
/// - keyword: The keyword.
/// - shardCount: Number of shards.
/// - Returns: An index in the range `0..<shardCount`.
@inlinable
public func shardIndex(keyword: KeywordValuePair.Keyword, shardCount: Int) -> Int {
switch function {
case .sha256:
return keyword.shardIndex(shardCount: shardCount)
case let .doubleMod(otherShardCount):
let otherShardIndex = keyword.shardIndex(shardCount: otherShardCount)
return otherShardIndex % shardCount
}
}
}

// custom implementation
extension ShardingFunction: Codable {
enum CodingKeys: String, CodingKey {
case sha256
case doubleMod
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
var allKeys = ArraySlice(container.allKeys)
guard let onlyKey = allKeys.popFirst(), allKeys.isEmpty else {
throw DecodingError.typeMismatch(
Self.self,
DecodingError.Context(
codingPath: container.codingPath,
debugDescription: "Invalid number of keys found, expected one."))
}
switch onlyKey {
case .sha256:
self = .sha256
case .doubleMod:
let otherShardCount = try container.decode(Int.self, forKey: .doubleMod)
self = .doubleMod(otherShardCount: otherShardCount)
}
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
switch function {
case .sha256:
try container.encodeNil(forKey: .sha256)
case let .doubleMod(otherShardCount):
try container.encode(otherShardCount, forKey: .doubleMod)
}
}
}

/// Different ways to divide a database into disjoint shards.
public enum Sharding: Hashable, Codable, Sendable {
/// Divide database into as many shards as needed to average at least `entryCountPerShard` entries per shard.
Expand Down Expand Up @@ -232,8 +318,13 @@ public struct KeywordDatabase {
/// - Parameters:
/// - rows: Rows in the database.
/// - sharding: How to shard the database.
/// - shardingFunction: What function to use for sharding.
/// - Throws: Error upon failure to initialize the database.
public init(rows: some Collection<KeywordValuePair>, sharding: Sharding) throws {
public init(
rows: some Collection<KeywordValuePair>,
sharding: Sharding,
shardingFunction: ShardingFunction = .sha256) throws
{
let shardCount = switch sharding {
case let .shardCount(shardCount): shardCount
case let .entryCountPerShard(entryCountPerShard):
Expand All @@ -243,7 +334,7 @@ public struct KeywordDatabase {

var shards: [String: KeywordDatabaseShard] = [:]
for row in rows {
let shardID = String(row.keyword.shardIndex(shardCount: shardCount))
let shardID = String(shardingFunction.shardIndex(keyword: row.keyword, shardCount: shardCount))
if let previousValue = shards[shardID, default: KeywordDatabaseShard(shardID: shardID, rows: [])].rows
.updateValue(
row.value,
Expand Down Expand Up @@ -490,7 +581,10 @@ public enum ProcessKeywordDatabase {
let keywordConfig = arguments.databaseConfig.keywordPirConfig

let context = try Context(encryptionParameters: arguments.encryptionParameters)
let keywordDatabase = try KeywordDatabase(rows: rows, sharding: arguments.databaseConfig.sharding)
let keywordDatabase = try KeywordDatabase(
rows: rows,
sharding: arguments.databaseConfig.sharding,
shardingFunction: keywordConfig.shardingFunction)

var processedShards = [String: ProcessedDatabaseWithParameters<Scheme>]()
for (shardID, shardedDatabase) in keywordDatabase.shards where !shardedDatabase.isEmpty {
Expand Down
20 changes: 16 additions & 4 deletions Sources/PrivateInformationRetrieval/KeywordPirProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@ public struct KeywordPirConfig: Hashable, Codable, Sendable {
/// Otherwise the largest serialized bucket size is used instead.
@usableFromInline let useMaxSerializedBucketSize: Bool

/// Sharding function configuration.
@usableFromInline let shardingFunction: ShardingFunction

/// Keyword PIR parameters.
public var parameter: KeywordPirParameter {
KeywordPirParameter(hashFunctionCount: cuckooTableConfig.hashFunctionCount)
KeywordPirParameter(hashFunctionCount: cuckooTableConfig.hashFunctionCount, shardingFunction: shardingFunction)
}

/// Initializes a ``KeywordPirConfig``.
Expand All @@ -48,13 +51,15 @@ public struct KeywordPirConfig: Hashable, Codable, Sendable {
/// - useMaxSerializedBucketSize: Enable this to set the entry size in index PIR layer to
/// ``CuckooTableConfig/maxSerializedBucketSize``. When not enabled, the largest serialized bucket size is used
/// instead.
/// - shardingFunction: The sharding function to use.
/// - Throws: Error upon invalid arguments.
public init(
dimensionCount: Int,
cuckooTableConfig: CuckooTableConfig,
unevenDimensions: Bool,
keyCompression: PirKeyCompressionStrategy,
useMaxSerializedBucketSize: Bool = false) throws
useMaxSerializedBucketSize: Bool = false,
shardingFunction: ShardingFunction = .sha256) throws
{
let validDimensionsCount = [1, 2]
guard validDimensionsCount.contains(dimensionCount) else {
Expand All @@ -68,6 +73,7 @@ public struct KeywordPirConfig: Hashable, Codable, Sendable {
self.unevenDimensions = unevenDimensions
self.keyCompression = keyCompression
self.useMaxSerializedBucketSize = useMaxSerializedBucketSize
self.shardingFunction = shardingFunction
}
}

Expand All @@ -78,10 +84,16 @@ public struct KeywordPirParameter: Hashable, Codable, Sendable {
/// Number of hash functions in the ``CuckooTableConfig``.
public let hashFunctionCount: Int

/// Sharding function used.
public let shardingFunction: ShardingFunction

/// Initializes a ``KeywordPirParameter``.
/// - Parameter hashFunctionCount: Number of hash functions in the ``CuckooTableConfig``.
public init(hashFunctionCount: Int) {
/// - Parameters:
/// - hashFunctionCount: Number of hash functions in the ``CuckooTableConfig``.
/// - shardingFunction: Sharding function that was used for sharding.
public init(hashFunctionCount: Int, shardingFunction: ShardingFunction = .sha256) {
self.hashFunctionCount = hashFunctionCount
self.shardingFunction = shardingFunction
}
}

Expand Down
Loading
Loading