diff --git a/WORKSPACE b/WORKSPACE index f320e9064..3fd244873 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -364,14 +364,14 @@ git_repository( # AnySketch. git_repository( name = "any_sketch", - commit = "5415eec38253c3bd3f250cb12fdc24242743e426", + commit = "2691dcd099b5f63a2eacd571b010d9dbe832f14d", remote = "sso://team/ads-xmedia-open-measurement-team/any-sketch", shallow_since = "1605572796 -0500", ) git_repository( name = "any_sketch_java", - commit = "d4a8369630667880026b7bf927e405508fbee381", + commit = "9b430b5ab6dc21f89841eb22451b4c93811d1cb6", remote = "sso://team/ads-xmedia-open-measurement-team/any-sketch-java", shallow_since = "1605573394 -0500", ) diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/CorrectnessImpl.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/CorrectnessImpl.kt index 880d4aff2..3217dd565 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/CorrectnessImpl.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/CorrectnessImpl.kt @@ -36,6 +36,7 @@ import kotlinx.coroutines.flow.transform import org.wfanet.anysketch.AnySketch import org.wfanet.anysketch.SketchProtos import org.wfanet.anysketch.crypto.EncryptSketchRequest +import org.wfanet.anysketch.crypto.EncryptSketchRequest.DestroyedRegisterStrategy.CONFLICTING_KEYS import org.wfanet.anysketch.crypto.EncryptSketchResponse import org.wfanet.anysketch.crypto.SketchEncrypterAdapter import org.wfanet.estimation.Estimators @@ -329,6 +330,7 @@ class CorrectnessImpl( curveId = combinedPublicKey.ellipticCurveId.toLong() elGamalKeysBuilder.elGamalG = combinedPublicKey.generator elGamalKeysBuilder.elGamalY = combinedPublicKey.element + destroyedRegisterStrategy = CONFLICTING_KEYS }.build() val response = EncryptSketchResponse.parseFrom( SketchEncrypterAdapter.EncryptSketch(request.toByteArray()) diff --git a/src/test/cc/wfa/measurement/common/crypto/protocol_encryption_utility_test.cc b/src/test/cc/wfa/measurement/common/crypto/protocol_encryption_utility_test.cc index c9d255a09..14ca02c06 100644 --- a/src/test/cc/wfa/measurement/common/crypto/protocol_encryption_utility_test.cc +++ b/src/test/cc/wfa/measurement/common/crypto/protocol_encryption_utility_test.cc @@ -124,8 +124,8 @@ class TestData { std::string duchy_2_p_h_key_; ElGamalKeyPair duchy_3_el_gamal_key_pair_; std::string duchy_3_p_h_key_; - ElGamalPublicKey client_el_gamal_public_key; // combined from 3 duchy keys; - std::unique_ptr sketch_encrypter; + ElGamalPublicKey client_el_gamal_public_key_; // combined from 3 duchy keys; + std::unique_ptr sketch_encrypter_; TestData() { duchy_1_el_gamal_key_pair_ = GenerateRandomElGamalKeyPair(kTestCurveId); @@ -153,22 +153,26 @@ class TestData { .value() .Add(duchy_3_public_el_gamal_y_ec) .value(); - client_el_gamal_public_key.set_generator( + client_el_gamal_public_key_.set_generator( duchy_1_el_gamal_key_pair_.public_key().generator()); - client_el_gamal_public_key.set_element( + client_el_gamal_public_key_.set_element( client_public_el_gamal_y_ec.ToBytesCompressed().value()); any_sketch::crypto::CiphertextString client_public_key = { - .u = client_el_gamal_public_key.generator(), - .e = client_el_gamal_public_key.element(), + .u = client_el_gamal_public_key_.generator(), + .e = client_el_gamal_public_key_.element(), }; // Create a sketch_encryter for encrypting plaintext any_sketch data. - sketch_encrypter = any_sketch::crypto::CreateWithPublicKey( - kTestCurveId, kMaxFrequency, client_public_key) - .value(); + sketch_encrypter_ = any_sketch::crypto::CreateWithPublicKey( + kTestCurveId, kMaxFrequency, client_public_key) + .value(); } + absl::StatusOr EncryptWithConflictingKeys(const Sketch& sketch) { + return sketch_encrypter_->Encrypt( + sketch, any_sketch::crypto::EncryptSketchRequest::CONFLICTING_KEYS); + } // Helper function to go through the entire MPC protocol using the input data. // The final (flag, count) lists are returned. DecryptLastLayerFlagAndCountResponse GoThroughEntireMpcProtocol( @@ -186,7 +190,7 @@ class TestData { *blind_one_layer_register_index_request_1 .mutable_local_el_gamal_key_pair() = duchy_1_el_gamal_key_pair_; *blind_one_layer_register_index_request_1 - .mutable_composite_el_gamal_public_key() = client_el_gamal_public_key; + .mutable_composite_el_gamal_public_key() = client_el_gamal_public_key_; blind_one_layer_register_index_request_1.set_curve_id(kTestCurveId); *blind_one_layer_register_index_request_1.mutable_sketch() = add_noise_to_sketch_response.sketch(); @@ -202,7 +206,7 @@ class TestData { *blind_one_layer_register_index_request_2 .mutable_local_el_gamal_key_pair() = duchy_2_el_gamal_key_pair_; *blind_one_layer_register_index_request_2 - .mutable_composite_el_gamal_public_key() = client_el_gamal_public_key; + .mutable_composite_el_gamal_public_key() = client_el_gamal_public_key_; blind_one_layer_register_index_request_2.set_curve_id(kTestCurveId); blind_one_layer_register_index_request_2.mutable_sketch()->append( blind_one_layer_register_index_response_1.sketch().begin(), @@ -220,7 +224,7 @@ class TestData { *blind_last_layer_index_then_join_registers_request .mutable_local_el_gamal_key_pair() = duchy_3_el_gamal_key_pair_; *blind_last_layer_index_then_join_registers_request - .mutable_composite_el_gamal_public_key() = client_el_gamal_public_key; + .mutable_composite_el_gamal_public_key() = client_el_gamal_public_key_; blind_last_layer_index_then_join_registers_request.set_curve_id( kTestCurveId); blind_last_layer_index_then_join_registers_request.mutable_sketch()->append( @@ -294,14 +298,14 @@ TEST(BlindOneLayerRegisterIndex, keyAndCountShouldBeReRandomized) { Sketch plain_sketch = CreateEmptyClceSketch(); AddRegister(&plain_sketch, /* index = */ 1, /* key = */ 111, /* count = */ 2); std::string encrypted_sketch = - test_data.sketch_encrypter->Encrypt(plain_sketch).value(); + test_data.EncryptWithConflictingKeys(plain_sketch).value(); // Blind register indexes at duchy 1 BlindOneLayerRegisterIndexRequest request; *request.mutable_local_el_gamal_key_pair() = test_data.duchy_1_el_gamal_key_pair_; *request.mutable_composite_el_gamal_public_key() = - test_data.client_el_gamal_public_key; + test_data.client_el_gamal_public_key_; request.set_curve_id(kTestCurveId); request.mutable_sketch()->append(encrypted_sketch.begin(), encrypted_sketch.end()); @@ -397,7 +401,7 @@ TEST(EndToEnd, SumOfCountsShouldBeCorrect) { AddRegister(&plain_sketch, /* index = */ 1, /* key = */ 111, /* count = */ 3); AddRegister(&plain_sketch, /* index = */ 1, /* key = */ 111, /* count = */ 4); std::string encrypted_sketch = - test_data.sketch_encrypter->Encrypt(plain_sketch).value(); + test_data.EncryptWithConflictingKeys(plain_sketch).value(); DecryptLastLayerFlagAndCountResponse final_response = test_data.GoThroughEntireMpcProtocol(encrypted_sketch); @@ -417,7 +421,7 @@ TEST(EndToEnd, SumOfCoutsShouldBeCappedByMaximumFrequency) { /* count = */ kMaxFrequency - 2); AddRegister(&plain_sketch, /* index = */ 1, /* key = */ 111, /* count = */ 3); std::string encrypted_sketch = - test_data.sketch_encrypter->Encrypt(plain_sketch).value(); + test_data.EncryptWithConflictingKeys(plain_sketch).value(); DecryptLastLayerFlagAndCountResponse final_response = test_data.GoThroughEntireMpcProtocol(encrypted_sketch); @@ -436,7 +440,7 @@ TEST(EndToEnd, KeyCollisionShouldDestroyCount) { AddRegister(&plain_sketch, /* index = */ 1, /* key = */ 111, /* count = */ 2); AddRegister(&plain_sketch, /* index = */ 1, /* key = */ 222, /* count = */ 2); std::string encrypted_sketch = - test_data.sketch_encrypter->Encrypt(plain_sketch).value(); + test_data.EncryptWithConflictingKeys(plain_sketch).value(); DecryptLastLayerFlagAndCountResponse final_response = test_data.GoThroughEntireMpcProtocol(encrypted_sketch); @@ -454,7 +458,7 @@ TEST(EndToEnd, ZeroCountShouldBeSkipped) { Sketch plain_sketch = CreateEmptyClceSketch(); AddRegister(&plain_sketch, /* index = */ 2, /* key = */ 222, /* count = */ 0); std::string encrypted_sketch = - test_data.sketch_encrypter->Encrypt(plain_sketch).value(); + test_data.EncryptWithConflictingKeys(plain_sketch).value(); DecryptLastLayerFlagAndCountResponse final_response = test_data.GoThroughEntireMpcProtocol(encrypted_sketch); @@ -482,7 +486,7 @@ TEST(EndToEnd, CombinedCases) { AddRegister(&plain_sketch, /* index = */ 4, /* key = */ 400, /* count = */ 0); std::string encrypted_sketch = - test_data.sketch_encrypter->Encrypt(plain_sketch).value(); + test_data.EncryptWithConflictingKeys(plain_sketch).value(); DecryptLastLayerFlagAndCountResponse final_response = test_data.GoThroughEntireMpcProtocol(encrypted_sketch); diff --git a/src/test/kotlin/org/wfanet/measurement/integration/common/FakeDataProviderRule.kt b/src/test/kotlin/org/wfanet/measurement/integration/common/FakeDataProviderRule.kt index 941338830..dabe3c637 100644 --- a/src/test/kotlin/org/wfanet/measurement/integration/common/FakeDataProviderRule.kt +++ b/src/test/kotlin/org/wfanet/measurement/integration/common/FakeDataProviderRule.kt @@ -29,6 +29,7 @@ import org.junit.rules.TestRule import org.junit.runner.Description import org.junit.runners.model.Statement import org.wfanet.anysketch.crypto.EncryptSketchRequest +import org.wfanet.anysketch.crypto.EncryptSketchRequest.DestroyedRegisterStrategy.CONFLICTING_KEYS import org.wfanet.anysketch.crypto.EncryptSketchResponse import org.wfanet.anysketch.crypto.SketchEncrypterAdapter import org.wfanet.measurement.api.v1alpha.CombinedPublicKey @@ -119,13 +120,20 @@ class FakeDataProviderRule : TestRule { private fun generateFakeEncryptedSketch(combinedElGamalKey: ElGamalPublicKey): ByteString { val sketch = Sketch.newBuilder().apply { config = sketchConfig - for (i in 1L..10L) { + // Adds nine normal registers. + for (i in 1L..9L) { addRegistersBuilder().apply { index = i addValues(i) addValues(1) } } + // Adds another locally destroyed register. + addRegistersBuilder().apply { + index = 10 + addValues(-1) + addValues(1) + } }.build() val request = EncryptSketchRequest.newBuilder().apply { this.sketch = sketch @@ -135,6 +143,7 @@ class FakeDataProviderRule : TestRule { elGamalG = combinedElGamalKey.generator elGamalY = combinedElGamalKey.element } + destroyedRegisterStrategy = CONFLICTING_KEYS }.build() val response = EncryptSketchResponse.parseFrom( SketchEncrypterAdapter.EncryptSketch(request.toByteArray())