From 89a8b91b578f9a0aa7ecf0e277f1aa6c1ad17d0f Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Sat, 22 Jul 2023 02:38:39 -0400 Subject: [PATCH 01/15] Implementation of the initialization phase, setup phase, and execution phase of the reach only protocol. --- .../crypto/encryption_utility_helper.cc | 58 ++ .../common/crypto/encryption_utility_helper.h | 27 + .../common/crypto/protocol_cryptor.cc | 37 + .../common/crypto/protocol_cryptor.h | 3 + .../protocol/liquid_legions_v2/BUILD.bazel | 46 ++ .../liquid_legions_v2_encryption_utility.cc | 33 - ...id_legions_v2_encryption_utility_helper.cc | 82 +++ ...uid_legions_v2_encryption_utility_helper.h | 48 ++ ...ly_liquid_legions_v2_encryption_utility.cc | 629 +++++++++++++++++ ...nly_liquid_legions_v2_encryption_utility.h | 93 +++ .../measurement/internal/duchy/crypto.proto | 22 - ...liquid_legions_sketch_aggregation_v2.proto | 10 - ...liquid_legions_v2_encryption_methods.proto | 75 +- .../protocol/liquid_legions_v2/BUILD.bazel | 95 ++- ...quid_legions_v2_encryption_utility_test.cc | 34 +- ...quid_legions_v2_encryption_utility_test.cc | 654 ++++++++++++++++++ 16 files changed, 1772 insertions(+), 174 deletions(-) create mode 100644 src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc create mode 100644 src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h create mode 100644 src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.cc create mode 100644 src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h create mode 100644 src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_test.cc diff --git a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc index 0ddf2cd6ecd..43a22d736d6 100644 --- a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc +++ b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc @@ -67,6 +67,26 @@ absl::StatusOr> GetBlindedRegisterIndexes( return blinded_register_indexes; } +absl::StatusOr> GetRollv2BlindedRegisterIndexes( + absl::string_view data, ProtocolCryptor& protocol_cryptor) { + ASSIGN_OR_RETURN(size_t register_count, + GetNumberOfBlocks(data, kBytesPerCipherText)); + std::vector blinded_register_indexes; + blinded_register_indexes.reserve(register_count); + for (size_t index = 0; index < register_count; ++index) { + // The size of data_block is guaranteed to be equal to + // kBytesPerCipherText + absl::string_view data_block = + data.substr(index * kBytesPerCipherText, kBytesPerCipherText); + ASSIGN_OR_RETURN(ElGamalCiphertext ciphertext, + ExtractElGamalCiphertextFromString(data_block)); + ASSIGN_OR_RETURN(std::string decrypted_el_gamal, + protocol_cryptor.DecryptLocalElGamal(ciphertext)); + blinded_register_indexes.push_back(std::move(decrypted_el_gamal)); + } + return blinded_register_indexes; +} + absl::StatusOr ExtractKeyCountPairFromSubstring( absl::string_view str) { if (str.size() != kBytesPerCipherText * 2) { @@ -121,6 +141,17 @@ absl::Status WriteEcPointPairToString(const ElGamalEcPointPair& ec_point_pair, return absl::OkStatus(); } +absl::StatusOr GetEcPointPairFromString( + absl::string_view str, int curve_id) { + std::unique_ptr context(new Context); + ASSIGN_OR_RETURN(ECGroup ec_group, ECGroup::Create(curve_id, context.get())); + ASSIGN_OR_RETURN(ElGamalCiphertext ciphertext, + ExtractElGamalCiphertextFromString(str)); + ASSIGN_OR_RETURN(ElGamalEcPointPair ec_point, + GetElGamalEcPoints(ciphertext, ec_group)); + return ec_point; +} + absl::StatusOr> GetCountValuesPlaintext( int maximum_value, int curve_id) { if (maximum_value < 1) { @@ -142,4 +173,31 @@ absl::StatusOr> GetCountValuesPlaintext( return result; } +absl::Status EncryptCompositeElGamalAndAppendToString( + ProtocolCryptor& protocol_cryptor, CompositeType composite_type, + absl::string_view plaintext_ec, std::string& data) { + ASSIGN_OR_RETURN( + ElGamalCiphertext key, + protocol_cryptor.EncryptCompositeElGamal(plaintext_ec, composite_type)); + data.append(key.first); + data.append(key.second); + return absl::OkStatus(); +} + +absl::Status EncryptCompositeElGamalAndWriteToString( + ProtocolCryptor& protocol_cryptor, CompositeType composite_type, + absl::string_view plaintext_ec, size_t pos, std::string& result) { + if (pos + kBytesPerCipherText > result.size()) { + return absl::InvalidArgumentError("result is not long enough to write."); + } + ASSIGN_OR_RETURN( + ElGamalCiphertext key, + protocol_cryptor.EncryptCompositeElGamal(plaintext_ec, composite_type)); + + result.replace(pos, kBytesPerEcPoint, key.first); + result.replace(pos + kBytesPerEcPoint, kBytesPerEcPoint, key.second); + + return absl::OkStatus(); +} + } // namespace wfa::measurement::common::crypto diff --git a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.h b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.h index 90ae46fe235..e55e91e1499 100644 --- a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.h +++ b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.h @@ -26,6 +26,8 @@ namespace wfa::measurement::common::crypto { +using ::wfa::measurement::common::crypto::CompositeType; + // A pair of ciphertexts which store the key and count values of a liquidlegions // register. struct KeyCountPairCipherText { @@ -46,6 +48,11 @@ absl::StatusOr ExtractElGamalCiphertextFromString( absl::StatusOr> GetBlindedRegisterIndexes( absl::string_view data, ProtocolCryptor& protocol_cryptor); +// Blinds the last layer of ElGamal Encryption of register indexes, and return +// the deterministically encrypted results. +absl::StatusOr> GetRollv2BlindedRegisterIndexes( + absl::string_view data, ProtocolCryptor& protocol_cryptor); + // Extracts a KeyCountPairCipherText from a string_view. absl::StatusOr ExtractKeyCountPairFromSubstring( absl::string_view str); @@ -66,10 +73,30 @@ absl::Status AppendEcPointPairToString(const ElGamalEcPointPair& ec_point_pair, absl::Status WriteEcPointPairToString(const ElGamalEcPointPair& ec_point_pair, size_t pos, std::string& result); +// Extract a ElGamalEcPointPair from a string_view. +absl::StatusOr GetEcPointPairFromString( + absl::string_view str, int curve_id); + // Returns the vector of ECPoints for count values from 1 to maximum_value. absl::StatusOr> GetCountValuesPlaintext( int maximum_value, int curve_id); +// Encrypts plaintext and appends bytes of the cipher text to a target string. +// The length of bytes appened is kBytesPerCipherText = kBytesPerEcPoint * 2. +absl::Status EncryptCompositeElGamalAndAppendToString( + ProtocolCryptor& protocol_cryptor, CompositeType composite_type, + absl::string_view plaintext_ec, std::string& data); + +// Encrypts plaintext and writes bytes of the cipher text to a target string at +// a certain position. +// Bytes are written by replacing content of the string starting at pos. The +// length of bytes written is kBytesPerCipherText = kBytesPerEcPoint * 2. +// Returns a Status with code `INVALID_ARGUMENT` when the result string is not +// long enough. +absl::Status EncryptCompositeElGamalAndWriteToString( + ProtocolCryptor& protocol_cryptor, CompositeType composite_type, + absl::string_view plaintext_ec, size_t pos, std::string& result); + } // namespace wfa::measurement::common::crypto #endif // SRC_MAIN_CC_WFA_MEASUREMENT_COMMON_CRYPTO_ENCRYPTION_UTILITY_HELPER_H_ diff --git a/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.cc b/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.cc index 9a4489453b5..de677943496 100644 --- a/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.cc +++ b/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.cc @@ -66,6 +66,8 @@ class ProtocolCryptorImpl : public ProtocolCryptor { CompositeType composite_type) override; absl::StatusOr EncryptCompositeElGamal( absl::string_view plain_ec_point, CompositeType composite_type) override; + absl::StatusOr EncryptIntegerWithCompositElGamalAndWriteToString( + int64_t value) override; absl::StatusOr ReRandomize( const ElGamalCiphertext& ciphertext, CompositeType composite_type) override; @@ -173,6 +175,41 @@ absl::StatusOr ProtocolCryptorImpl::EncryptCompositeElGamal( : partial_composite_el_gamal_cipher_->Encrypt(plain_ec_point); } +absl::StatusOr +ProtocolCryptorImpl::EncryptIntegerWithCompositElGamalAndWriteToString( + int64_t value) { + Context ctx; + std::string ciphertext; + ciphertext.resize(kBytesPerCipherText); + if (value < 0) { + return absl::InvalidArgumentError( + absl::StrCat("The value should be non-negative, but is ", value)); + } + if (value == 0) { + ASSIGN_OR_RETURN( + ElGamalEcPointPair zero_ec, + EncryptIdentityElementToEcPointsCompositeElGamal(CompositeType::kFull)); + std::string temp; + ASSIGN_OR_RETURN(temp, zero_ec.u.ToBytesCompressed()); + ciphertext.replace(0, kBytesPerEcPoint, temp); + ASSIGN_OR_RETURN(temp, zero_ec.e.ToBytesCompressed()); + ciphertext.replace(kBytesPerEcPoint, kBytesPerEcPoint, temp); + } else { + ASSIGN_OR_RETURN(ElGamalEcPointPair one_ec, + EncryptPlaintextToEcPointsCompositeElGamal( + kUnitECPointSeed, CompositeType::kFull)); + ASSIGN_OR_RETURN( + ElGamalEcPointPair point_ec, + MultiplyEcPointPairByScalar(one_ec, ctx.CreateBigNum(value))); + std::string temp; + ASSIGN_OR_RETURN(temp, point_ec.u.ToBytesCompressed()); + ciphertext.replace(0, kBytesPerEcPoint, temp); + ASSIGN_OR_RETURN(temp, point_ec.e.ToBytesCompressed()); + ciphertext.replace(kBytesPerEcPoint, kBytesPerEcPoint, temp); + } + return ciphertext; +} + absl::StatusOr ProtocolCryptorImpl::ReRandomize( const ElGamalCiphertext& ciphertext, CompositeType composite_type) { ASSIGN_OR_RETURN( diff --git a/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.h b/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.h index 1809619e1be..f1c2c844955 100644 --- a/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.h +++ b/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.h @@ -70,6 +70,9 @@ class ProtocolCryptor { // Encrypts the plain EcPoint using the full or partial composite ElGamal Key. virtual absl::StatusOr EncryptCompositeElGamal( absl::string_view plain_ec_point, CompositeType composite_type) = 0; + // Encrypts an integer with the full composite ElGamal Key. + virtual absl::StatusOr + EncryptIntegerWithCompositElGamalAndWriteToString(int64_t value) = 0; // Encrypts the Identity Element using the full or partial composite ElGamal // Key, returns the result as an ElGamalEcPointPair. virtual absl::StatusOr diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel index 1932eb30b7a..2bd3a0af4cb 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel @@ -9,6 +9,24 @@ package(default_visibility = [ _INCLUDE_PREFIX = "/src/main/cc" +cc_library( + name = "liquid_legions_v2_encryption_utility_helper", + srcs = [ + "liquid_legions_v2_encryption_utility_helper.cc", + ], + hdrs = [ + "liquid_legions_v2_encryption_utility_helper.h", + ], + strip_include_prefix = _INCLUDE_PREFIX, + deps = [ + "@any_sketch//src/main/cc/estimation:estimators", + "@com_google_absl//absl/status:statusor", + "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", + "//src/main/proto/wfa/measurement/internal/duchy:crypto_cc_proto", + "//src/main/proto/wfa/measurement/internal/duchy:differential_privacy_cc_proto", + ], +) + cc_library( name = "liquid_legions_v2_encryption_utility", srcs = [ @@ -36,6 +54,33 @@ cc_library( ], ) +cc_library( + name = "reach_only_liquid_legions_v2_encryption_utility", + srcs = [ + "reach_only_liquid_legions_v2_encryption_utility.cc", + ], + hdrs = [ + "reach_only_liquid_legions_v2_encryption_utility.h", + ], + strip_include_prefix = _INCLUDE_PREFIX, + deps = [ + ":multithreading_helper", + ":noise_parameters_computation", + "//src/main/cc/wfa/measurement/common/crypto:constants", + "//src/main/cc/wfa/measurement/common/crypto:encryption_utility_helper", + "//src/main/cc/wfa/measurement/common/crypto:protocol_cryptor", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_cc_proto", + "@any_sketch//src/main/cc/estimation:estimators", + "@any_sketch//src/main/cc/math:distributed_discrete_gaussian_noiser", + "@any_sketch//src/main/cc/math:distributed_geometric_noiser", + "@com_google_absl//absl/algorithm:container", + "@com_google_private_join_and_compute//private_join_and_compute/crypto:commutative_elgamal", + "@wfa_common_cpp//src/main/cc/common_cpp/jni:jni_wrap", + "@wfa_common_cpp//src/main/cc/common_cpp/macros", + "@wfa_common_cpp//src/main/cc/common_cpp/time:started_thread_cpu_timer", + ], +) + cc_library( name = "liquid_legions_v2_encryption_utility_wrapper", srcs = [ @@ -88,3 +133,4 @@ cc_library( "@wfa_common_cpp//src/main/cc/common_cpp/macros", ], ) + diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility.cc b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility.cc index 00b4bbf6603..c257a8c6f6f 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility.cc +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility.cc @@ -240,39 +240,6 @@ absl::StatusOr> GetSameKeyAggregatorMatrixBase( return std::move(result); } -absl::Status EncryptCompositeElGamalAndAppendToString( - ProtocolCryptor& protocol_cryptor, CompositeType composite_type, - absl::string_view plaintext_ec, std::string& data) { - ASSIGN_OR_RETURN( - ElGamalCiphertext key, - protocol_cryptor.EncryptCompositeElGamal(plaintext_ec, composite_type)); - data.append(key.first); - data.append(key.second); - return absl::OkStatus(); -} - -// Encrypts plaintext and writes bytes of the cipher text to a target string at -// a certain position. -// Bytes are written by replacing content of the string starting at pos. The -// length of bytes written is kBytesPerCipherText = kBytesPerEcPoint * 2. -// Returns a Status with code `INVALID_ARGUMENT` when the result string is not -// long enough. -absl::Status EncryptCompositeElGamalAndWriteToString( - ProtocolCryptor& protocol_cryptor, CompositeType composite_type, - absl::string_view plaintext_ec, size_t pos, std::string& result) { - if (pos + kBytesPerCipherText > result.size()) { - return absl::InvalidArgumentError("result is not long enough to write."); - } - ASSIGN_OR_RETURN( - ElGamalCiphertext key, - protocol_cryptor.EncryptCompositeElGamal(plaintext_ec, composite_type)); - - result.replace(pos, kBytesPerEcPoint, key.first); - result.replace(pos + kBytesPerEcPoint, kBytesPerEcPoint, key.second); - - return absl::OkStatus(); -} - // Adds encrypted blinded-histogram-noise registers to the end of data. // returns the number of such noise registers added. absl::StatusOr AddBlindedHistogramNoise( diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc new file mode 100644 index 00000000000..7e4e497ab5a --- /dev/null +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc @@ -0,0 +1,82 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h" + +#include "estimation/estimators.h" + +namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 { + +using ::wfa::any_sketch::Sketch; +using ::wfa::any_sketch::SketchConfig; +using ::wfa::measurement::internal::duchy::ElGamalPublicKey; + +::wfa::any_sketch::crypto::ElGamalPublicKey ToAnysketchElGamalKey( + ElGamalPublicKey key) { + ::wfa::any_sketch::crypto::ElGamalPublicKey result; + result.set_generator(key.generator()); + result.set_element(key.element()); + return result; +} + +ElGamalPublicKey ToCmmsElGamalKey( + ::wfa::any_sketch::crypto::ElGamalPublicKey key) { + ElGamalPublicKey result; + result.set_generator(key.generator()); + result.set_element(key.element()); + return result; +} + +Sketch CreateEmptyLiquidLegionsSketch() { + Sketch plain_sketch; + plain_sketch.mutable_config()->add_values()->set_aggregator( + SketchConfig::ValueSpec::UNIQUE); + plain_sketch.mutable_config()->add_values()->set_aggregator( + SketchConfig::ValueSpec::SUM); + return plain_sketch; +} + +Sketch CreateReachOnlyEmptyLiquidLegionsSketch() { + Sketch plain_sketch; + return plain_sketch; +} + +DifferentialPrivacyParams MakeDifferentialPrivacyParams(double epsilon, + double delta) { + DifferentialPrivacyParams params; + params.set_epsilon(epsilon); + params.set_delta(delta); + return params; +} + +absl::StatusOr EstimateReach(double liquid_legions_decay_rate, + int64_t liquid_legions_size, + size_t non_empty_register_count, + float sampling_rate) { + if (liquid_legions_decay_rate <= 1.0) { + return absl::InvalidArgumentError(absl::StrCat( + "The decay rate should be > 1, but is ", liquid_legions_decay_rate)); + } + if (liquid_legions_size <= non_empty_register_count) { + return absl::InvalidArgumentError(absl::StrCat( + "liquid legions size (", liquid_legions_size, + ") should be greater then the number of non empty registers (", + non_empty_register_count, ").")); + } + return wfa::estimation::EstimateCardinalityLiquidLegions( + liquid_legions_decay_rate, liquid_legions_size, non_empty_register_count, + sampling_rate); +} + +} // namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h new file mode 100644 index 00000000000..041912eebec --- /dev/null +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h @@ -0,0 +1,48 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_HELPER_H_ +#define SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_HELPER_H_ + +#include "absl/status/statusor.h" +#include "any_sketch/crypto/sketch_encrypter.h" +#include "wfa/measurement/internal/duchy/crypto.pb.h" +#include "wfa/measurement/internal/duchy/differential_privacy.pb.h" + +namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 { + +using ::wfa::any_sketch::Sketch; +using ::wfa::measurement::internal::duchy::DifferentialPrivacyParams; +using ::wfa::measurement::internal::duchy::ElGamalPublicKey; + +absl::StatusOr EstimateReach(double liquid_legions_decay_rate, + int64_t liquid_legions_size, + size_t non_empty_register_count, + float sampling_rate = 1.0); + +::wfa::any_sketch::crypto::ElGamalPublicKey ToAnysketchElGamalKey( + ElGamalPublicKey key); + +ElGamalPublicKey ToCmmsElGamalKey( + ::wfa::any_sketch::crypto::ElGamalPublicKey key); + +Sketch CreateEmptyLiquidLegionsSketch(); + +Sketch CreateReachOnlyEmptyLiquidLegionsSketch(); + +DifferentialPrivacyParams MakeDifferentialPrivacyParams(double epsilon, + double delta); +} // namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 + +#endif // SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_HELPER_H_ diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.cc b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.cc new file mode 100644 index 00000000000..aaf4e93aa6c --- /dev/null +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.cc @@ -0,0 +1,629 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/memory/memory.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "common_cpp/macros/macros.h" +#include "common_cpp/time/started_thread_cpu_timer.h" +#include "estimation/estimators.h" +#include "math/distributed_noiser.h" +#include "private_join_and_compute/crypto/commutative_elgamal.h" +#include "wfa/measurement/common/crypto/constants.h" +#include "wfa/measurement/common/crypto/encryption_utility_helper.h" +#include "wfa/measurement/common/crypto/protocol_cryptor.h" +#include "wfa/measurement/common/string_block_sorter.h" +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/multithreading_helper.h" +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/noise_parameters_computation.h" + +namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 { + +namespace { + +using ::private_join_and_compute::BigNum; +using ::private_join_and_compute::CommutativeElGamal; +using ::private_join_and_compute::Context; +using ::private_join_and_compute::ECGroup; +using ::wfa::measurement::common::SortStringByBlock; +using ::wfa::measurement::common::crypto::Action; +using ::wfa::measurement::common::crypto::CompositeType; +using ::wfa::measurement::common::crypto::CreateProtocolCryptor; +using ::wfa::measurement::common::crypto::ElGamalCiphertext; +using ::wfa::measurement::common::crypto::ElGamalEcPointPair; +using ::wfa::measurement::common::crypto::ExtractElGamalCiphertextFromString; +using ::wfa::measurement::common::crypto::ExtractKeyCountPairFromRegisters; +using ::wfa::measurement::common::crypto::GetCountValuesPlaintext; +using ::wfa::measurement::common::crypto::GetEcPointPairFromString; +using ::wfa::measurement::common::crypto::GetElGamalEcPoints; +using ::wfa::measurement::common::crypto::GetNumberOfBlocks; +using ::wfa::measurement::common::crypto::kBlindedHistogramNoiseRegisterKey; +using ::wfa::measurement::common::crypto::kBytesPerCipherText; +using ::wfa::measurement::common::crypto::kBytesPerEcPoint; +using ::wfa::measurement::common::crypto::kBytesPerFlagsCountTuple; +using ::wfa::measurement::common::crypto::kDefaultEllipticCurveId; +using ::wfa::measurement::common::crypto::kDestroyedRegisterKey; +using ::wfa::measurement::common::crypto::KeyCountPairCipherText; +using ::wfa::measurement::common::crypto::kFlagZeroBase; +using ::wfa::measurement::common::crypto::kGenerateNewCompositeCipher; +using ::wfa::measurement::common::crypto::kGenerateNewParitialCompositeCipher; +using ::wfa::measurement::common::crypto::kGenerateWithNewElGamalPrivateKey; +using ::wfa::measurement::common::crypto::kGenerateWithNewElGamalPublicKey; +using ::wfa::measurement::common::crypto::kGenerateWithNewPohligHellmanKey; +using ::wfa::measurement::common::crypto::kPaddingNoiseRegisterId; +using ::wfa::measurement::common::crypto::kPublisherNoiseRegisterId; +using ::wfa::measurement::common::crypto::kUnitECPointSeed; +using ::wfa::measurement::common::crypto::MultiplyEcPointPairByScalar; +using ::wfa::measurement::common::crypto::ProtocolCryptor; +using ::wfa::measurement::common::crypto::ProtocolCryptorOptions; +using ::wfa::measurement::internal::duchy::ElGamalPublicKey; +using ::wfa::measurement::internal::duchy::protocol::LiquidLegionsV2NoiseConfig; + +absl::StatusOr EstimateReach(double liquid_legions_decay_rate, + int64_t liquid_legions_size, + size_t non_empty_register_count, + float sampling_rate = 1.0) { + if (liquid_legions_decay_rate <= 1.0) { + return absl::InvalidArgumentError(absl::StrCat( + "The decay rate should be > 1, but is ", liquid_legions_decay_rate)); + } + if (liquid_legions_size <= non_empty_register_count) { + return absl::InvalidArgumentError(absl::StrCat( + "liquid legions size (", liquid_legions_size, + ") should be greater then the number of non empty registers (", + non_empty_register_count, ").")); + } + return wfa::estimation::EstimateCardinalityLiquidLegions( + liquid_legions_decay_rate, liquid_legions_size, non_empty_register_count, + sampling_rate); +} + +int64_t CountUniqueElements(const std::vector& arr) { + if (arr.empty()) { + return 0; + } + // Create a sorting permutation of the array, such that we don't need to + // modify the data, whose size could be huge. + std::vector permutation(arr.size()); + absl::c_iota(permutation, 0); + absl::c_sort(permutation, + [&](size_t a, size_t b) { return arr[a] < arr[b]; }); + + // Counting the number of unique elements by iterating through the indices. + int64_t count = 1; + int start = 0; + for (size_t i = 0; i < arr.size(); ++i) { + if (arr[permutation[i]] == arr[permutation[start]]) { + // This register has the same index, it belongs to the same group; + continue; + } else { + // This register belongs to a new group. Increase the unique register + // count by 1. + count++; + // Reset the starting point. + start = i; + } + } + return count; +} + +// Adds encrypted blinded-histogram-noise registers to the end of data. +// returns the number of such noise registers added. +absl::StatusOr AddReachOnlyBlindedHistogramNoise( + ProtocolCryptor& protocol_cryptor, int total_sketches_count, + const math::DistributedNoiser& distributed_noiser, size_t pos, + std::string& data, int64_t& num_unique_noise_id) { + ASSIGN_OR_RETURN( + std::string blinded_histogram_noise_key_ec, + protocol_cryptor.MapToCurve(kBlindedHistogramNoiseRegisterKey)); + + int64_t noise_register_added = 0; + num_unique_noise_id = 0; + + for (int k = 1; k <= total_sketches_count; ++k) { + // The random number of distinct register_ids that should appear k times. + ASSIGN_OR_RETURN(int64_t noise_register_count_for_bucket_k, + distributed_noiser.GenerateNoiseComponent()); + num_unique_noise_id += noise_register_count_for_bucket_k; + + // Add noise_register_count_for_bucket_k such distinct register ids. + for (int i = 0; i < noise_register_count_for_bucket_k; ++i) { + // The prefix is to ensure the value is not in the regular id space. + std::string register_id = + absl::StrCat("blinded_histogram_noise", + protocol_cryptor.NextRandomBigNumAsString()); + ASSIGN_OR_RETURN(std::string register_id_ec, + protocol_cryptor.MapToCurve(register_id)); + // Add k registers with the same register_id. + for (int j = 0; j < k; ++j) { + // Add register_id + RETURN_IF_ERROR(EncryptCompositeElGamalAndWriteToString( + protocol_cryptor, CompositeType::kFull, register_id_ec, pos, data)); + pos += kBytesPerCipherText; + + ++noise_register_added; + } + } + } + + return noise_register_added; +} + +// Adds encrypted noise-for-publisher-noise registers to the end of data. +// returns the number of such noise registers added. +absl::StatusOr AddReachOnlyNoiseForPublisherNoise( + MultithreadingHelper& helper, + const math::DistributedNoiser& distributed_noiser, size_t pos, + std::string& data) { + ASSIGN_OR_RETURN( + std::string publisher_noise_register_id_ec, + helper.GetProtocolCryptor().MapToCurve(kPublisherNoiseRegisterId)); + + ASSIGN_OR_RETURN(int64_t noise_registers_count, + distributed_noiser.GenerateNoiseComponent()); + // Make sure that there is at least one publisher noise added. + noise_registers_count++; + absl::AnyInvocable f = + [&](ProtocolCryptor& cryptor, size_t index) -> absl::Status { + size_t current_pos = pos + kBytesPerCipherText * index; + RETURN_IF_ERROR(EncryptCompositeElGamalAndWriteToString( + cryptor, CompositeType::kFull, publisher_noise_register_id_ec, + current_pos, data)); + + return absl::OkStatus(); + }; + RETURN_IF_ERROR(helper.Execute(noise_registers_count, f)); + + return noise_registers_count; +} + +// Adds encrypted global-reach-DP-noise registers to the end of data. +// returns the number of such noise registers added. +absl::StatusOr AddReachOnlyGlobalReachDpNoise( + MultithreadingHelper& helper, + const math::DistributedNoiser& distributed_noiser, size_t pos, + std::string& data) { + ASSIGN_OR_RETURN(int64_t noise_registers_count, + distributed_noiser.GenerateNoiseComponent()); + absl::AnyInvocable f = + [&](ProtocolCryptor& cryptor, size_t index) -> absl::Status { + size_t current_pos = pos + kBytesPerCipherText * index; + // Add register id, a random number. + // The prefix is to ensure the value is not in the regular id space. + std::string register_id = + absl::StrCat("reach_dp_noise", cryptor.NextRandomBigNumAsString()); + ASSIGN_OR_RETURN(std::string register_id_ec, + cryptor.MapToCurve(register_id)); + RETURN_IF_ERROR(EncryptCompositeElGamalAndWriteToString( + cryptor, CompositeType::kFull, register_id_ec, current_pos, data)); + + return absl::OkStatus(); + }; + RETURN_IF_ERROR(helper.Execute(noise_registers_count, f)); + + return noise_registers_count; +} + +// Adds encrypted padding-noise registers to the end of data. +absl::Status AddReachOnlyPaddingReachNoise(MultithreadingHelper& helper, + int64_t count, size_t pos, + std::string& data) { + if (count < 0) { + return absl::InvalidArgumentError("Count should >= 0."); + } + + ASSIGN_OR_RETURN( + std::string padding_noise_register_id_ec, + helper.GetProtocolCryptor().MapToCurve(kPaddingNoiseRegisterId)); + + absl::AnyInvocable f = + [&](ProtocolCryptor& cryptor, size_t index) -> absl::Status { + size_t current_pos = pos + kBytesPerCipherText * index; + // Add register_id, a predefined constant + RETURN_IF_ERROR(EncryptCompositeElGamalAndWriteToString( + cryptor, CompositeType::kFull, padding_noise_register_id_ec, + current_pos, data)); + + return absl::OkStatus(); + }; + RETURN_IF_ERROR(helper.Execute(count, f)); + pos += kBytesPerCipherText * count; + + return absl::OkStatus(); +} + +absl::Status ValidateReachOnlySetupNoiseParameters( + const RegisterNoiseGenerationParameters& parameters) { + if (parameters.contributors_count() < 1) { + return absl::InvalidArgumentError("contributors_count should be positive."); + } + if (parameters.total_sketches_count() < 1) { + return absl::InvalidArgumentError( + "total_sketches_count should be positive."); + } + if (parameters.dp_params().blind_histogram().epsilon() <= 0 || + parameters.dp_params().blind_histogram().delta() <= 0) { + return absl::InvalidArgumentError( + "Invalid blind_histogram dp parameter. epsilon/delta should be " + "positive."); + } + if (parameters.dp_params().noise_for_publisher_noise().epsilon() <= 0 || + parameters.dp_params().noise_for_publisher_noise().delta() <= 0) { + return absl::InvalidArgumentError( + "Invalid noise_for_publisher_noise dp parameter. epsilon/delta should " + "be positive."); + } + if (parameters.dp_params().global_reach_dp_noise().epsilon() <= 0 || + parameters.dp_params().global_reach_dp_noise().delta() <= 0) { + return absl::InvalidArgumentError( + "Invalid global_reach_dp_noise dp parameter. epsilon/delta should be " + "positive."); + } + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr +CompleteReachOnlyInitializationPhase( + const CompleteReachOnlyInitializationPhaseRequest& request) { + StartedThreadCpuTimer timer; + + ASSIGN_OR_RETURN( + std::unique_ptr cipher, + CommutativeElGamal::CreateWithNewKeyPair(request.curve_id())); + ASSIGN_OR_RETURN(ElGamalCiphertext public_key, cipher->GetPublicKeyBytes()); + ASSIGN_OR_RETURN(std::string private_key, cipher->GetPrivateKeyBytes()); + + CompleteReachOnlyInitializationPhaseResponse response; + response.mutable_el_gamal_key_pair()->mutable_public_key()->set_generator( + public_key.first); + response.mutable_el_gamal_key_pair()->mutable_public_key()->set_element( + public_key.second); + response.mutable_el_gamal_key_pair()->set_secret_key(private_key); + + response.set_elapsed_cpu_time_millis(timer.ElapsedMillis()); + return response; +} + +absl::StatusOr CompleteReachOnlySetupPhase( + const CompleteReachOnlySetupPhaseRequest& request) { + StartedThreadCpuTimer timer; + + CompleteReachOnlySetupPhaseResponse response; + std::string* response_crv = response.mutable_combined_register_vector(); + + *response_crv = request.combined_register_vector(); + + ProtocolCryptorOptions protocol_cryptor_options{ + .curve_id = static_cast(request.curve_id()), + .local_el_gamal_public_key = kGenerateWithNewElGamalPublicKey, + .local_el_gamal_private_key = + std::string(kGenerateWithNewElGamalPrivateKey), + .local_pohlig_hellman_private_key = + std::string(kGenerateWithNewPohligHellmanKey), + .composite_el_gamal_public_key = + std::make_pair(request.composite_el_gamal_public_key().generator(), + request.composite_el_gamal_public_key().element()), + .partial_composite_el_gamal_public_key = + kGenerateNewParitialCompositeCipher}; + + int64_t excessive_noise_count = 0; + + if (request.has_noise_parameters()) { + const RegisterNoiseGenerationParameters& noise_parameters = + request.noise_parameters(); + + auto blind_histogram_noiser = GetBlindHistogramNoiser( + noise_parameters.dp_params().blind_histogram(), + noise_parameters.contributors_count(), request.noise_mechanism()); + + auto publisher_noiser = GetPublisherNoiser( + noise_parameters.dp_params().noise_for_publisher_noise(), + noise_parameters.total_sketches_count(), + noise_parameters.contributors_count(), request.noise_mechanism()); + + auto global_reach_dp_noiser = GetGlobalReachDpNoiser( + noise_parameters.dp_params().global_reach_dp_noise(), + noise_parameters.contributors_count(), request.noise_mechanism()); + + // The total noise registers added. There are additional 2 noise count here + // to make sure that at least 1 publisher noise and 1 padding noise will be + // added. + int64_t total_noise_registers_count = + publisher_noiser->options().shift_offset * 2 + + global_reach_dp_noiser->options().shift_offset * 2 + + blind_histogram_noiser->options().shift_offset * + noise_parameters.total_sketches_count() * + (noise_parameters.total_sketches_count() + 1) + + 2; + + // Resize the space to hold all output data. + size_t pos = response_crv->size(); + response_crv->resize(request.combined_register_vector().size() + + total_noise_registers_count * kBytesPerCipherText); + + RETURN_IF_ERROR(ValidateReachOnlySetupNoiseParameters(noise_parameters)); + ASSIGN_OR_RETURN(auto multithreading_helper, + MultithreadingHelper::CreateMultithreadingHelper( + request.parallelism(), protocol_cryptor_options)); + + // 1. Add blinded histogram noise. + ASSIGN_OR_RETURN( + int64_t blinded_histogram_noise_count, + AddReachOnlyBlindedHistogramNoise( + multithreading_helper->GetProtocolCryptor(), + noise_parameters.total_sketches_count(), *blind_histogram_noiser, + pos, *response_crv, excessive_noise_count)); + pos += kBytesPerCipherText * blinded_histogram_noise_count; + // 2. Add noise for publisher noise. Publisher noise count is at least 1. + ASSIGN_OR_RETURN( + int64_t publisher_noise_count, + AddReachOnlyNoiseForPublisherNoise( + *multithreading_helper, *publisher_noiser, pos, *response_crv)); + pos += kBytesPerCipherText * publisher_noise_count; + // 3. Add reach DP noise. + ASSIGN_OR_RETURN(int64_t reach_dp_noise_count, + AddReachOnlyGlobalReachDpNoise(*multithreading_helper, + *global_reach_dp_noiser, + pos, *response_crv)); + pos += kBytesPerCipherText * reach_dp_noise_count; + // 4. Add padding noise. Padding noise count will be at least 1. + int64_t padding_noise_count = total_noise_registers_count - + blinded_histogram_noise_count - + publisher_noise_count - reach_dp_noise_count; + RETURN_IF_ERROR(AddReachOnlyPaddingReachNoise( + *multithreading_helper, padding_noise_count, pos, *response_crv)); + } + + // Encrypt the excessive noise. + ASSIGN_OR_RETURN(std::unique_ptr protocol_cryptor, + CreateProtocolCryptor(protocol_cryptor_options)); + ASSIGN_OR_RETURN( + std::string serialized_excessive_noise_ciphertext, + protocol_cryptor->EncryptIntegerWithCompositElGamalAndWriteToString( + excessive_noise_count)); + + response.set_serialized_excessive_noise_ciphertext( + serialized_excessive_noise_ciphertext); + + RETURN_IF_ERROR(SortStringByBlock( + *response.mutable_combined_register_vector())); + + response.set_elapsed_cpu_time_millis(timer.ElapsedMillis()); + return response; +} + +absl::StatusOr +CompleteReachOnlySetupPhaseAtAggregator( + const CompleteReachOnlySetupPhaseRequest& request) { + StartedThreadCpuTimer timer; + ASSIGN_OR_RETURN(CompleteReachOnlySetupPhaseResponse response, + CompleteReachOnlySetupPhase(request)); + + // Get the ElGamal encryption of the excessive noise on the aggregator. + ASSIGN_OR_RETURN( + ElGamalEcPointPair ec_point, + GetEcPointPairFromString(response.serialized_excessive_noise_ciphertext(), + request.curve_id())); + + // Combined the excessive_noise ciphertexts. + int num_ciphertexts = request.serialized_excessive_noise_ciphertext().size() / + kBytesPerCipherText; + for (int i = 0; i < num_ciphertexts; i++) { + ASSIGN_OR_RETURN(ElGamalEcPointPair temp, + GetEcPointPairFromString( + request.serialized_excessive_noise_ciphertext().substr( + i * kBytesPerCipherText, kBytesPerCipherText), + request.curve_id())); + ASSIGN_OR_RETURN(ec_point, AddEcPointPairs(ec_point, temp)); + } + + std::string excessive_noise_string; + excessive_noise_string.resize(kBytesPerCipherText); + RETURN_IF_ERROR( + WriteEcPointPairToString(ec_point, 0, excessive_noise_string)); + + response.set_serialized_excessive_noise_ciphertext(excessive_noise_string); + + response.set_elapsed_cpu_time_millis(timer.ElapsedMillis()); + return response; +} + +absl::StatusOr +CompleteReachOnlyExecutionPhase( + const CompleteReachOnlyExecutionPhaseRequest& request) { + StartedThreadCpuTimer timer; + + ASSIGN_OR_RETURN(size_t register_count, + GetNumberOfBlocks(request.combined_register_vector(), + kBytesPerCipherText)); + + ProtocolCryptorOptions protocol_cryptor_options{ + .curve_id = static_cast(request.curve_id()), + .local_el_gamal_public_key = std::make_pair( + request.local_el_gamal_key_pair().public_key().generator(), + request.local_el_gamal_key_pair().public_key().element()), + .local_el_gamal_private_key = + request.local_el_gamal_key_pair().secret_key(), + .local_pohlig_hellman_private_key = + std::string(kGenerateWithNewPohligHellmanKey), + .composite_el_gamal_public_key = kGenerateNewCompositeCipher, + .partial_composite_el_gamal_public_key = + kGenerateNewParitialCompositeCipher}; + ASSIGN_OR_RETURN(auto multithreading_helper, + MultithreadingHelper::CreateMultithreadingHelper( + request.parallelism(), protocol_cryptor_options)); + + CompleteReachOnlyExecutionPhaseResponse response; + // Partially decrypt the aggregated excessive noise ciphertext. + ASSIGN_OR_RETURN(std::unique_ptr protocol_cryptor, + CreateProtocolCryptor(protocol_cryptor_options)); + + std::string updated_noise_ciphertext; + updated_noise_ciphertext.resize(kBytesPerCipherText); + RETURN_IF_ERROR(protocol_cryptor->BatchProcess( + request.serialized_excessive_noise_ciphertext(), + {Action::kPartialDecrypt}, 0, updated_noise_ciphertext)); + response.set_serialized_excessive_noise_ciphertext(updated_noise_ciphertext); + + std::string* response_crv = response.mutable_combined_register_vector(); + // The output crv is the same size with the input crv. + size_t start_pos = 0; + response_crv->resize(request.combined_register_vector().size()); + + absl::AnyInvocable f = + [&](ProtocolCryptor& cryptor, size_t index) -> absl::Status { + absl::string_view current_block = + absl::string_view(request.combined_register_vector()) + .substr(index * kBytesPerCipherText, kBytesPerCipherText); + size_t pos = start_pos + kBytesPerCipherText * index; + + RETURN_IF_ERROR(cryptor.BatchProcess(current_block, {Action::kBlind}, pos, + *response_crv)); + + return absl::OkStatus(); + }; + + RETURN_IF_ERROR(multithreading_helper->Execute(register_count, f)); + RETURN_IF_ERROR(SortStringByBlock(*response_crv)); + + response.set_elapsed_cpu_time_millis(timer.ElapsedMillis()); + return response; +} + +absl::StatusOr +CompleteReachOnlyExecutionPhaseAtAggregator( + const CompleteReachOnlyExecutionPhaseAtAggregatorRequest& request) { + StartedThreadCpuTimer timer; + + if (request.combined_register_vector().size() % kBytesPerCipherText != 0) { + return absl::InvalidArgumentError(absl::StrCat( + "The size of byte array is not divisible by the block_size: ", + request.combined_register_vector().size())); + } + + ProtocolCryptorOptions protocol_cryptor_options{ + .curve_id = static_cast(request.curve_id()), + .local_el_gamal_public_key = std::make_pair( + request.local_el_gamal_key_pair().public_key().generator(), + request.local_el_gamal_key_pair().public_key().element()), + .local_el_gamal_private_key = + request.local_el_gamal_key_pair().secret_key(), + .local_pohlig_hellman_private_key = + std::string(kGenerateWithNewPohligHellmanKey), + .composite_el_gamal_public_key = kGenerateNewCompositeCipher, + .partial_composite_el_gamal_public_key = + kGenerateNewParitialCompositeCipher}; + + // Decrypt the aggregated excessive noise ciphertext to get the excessive + // noise count. + int64_t excessive_noise_count = 0; + ASSIGN_OR_RETURN(std::unique_ptr protocol_cryptor, + CreateProtocolCryptor(protocol_cryptor_options)); + ASSIGN_OR_RETURN(ElGamalCiphertext ciphertext, + ExtractElGamalCiphertextFromString( + request.serialized_excessive_noise_ciphertext())); + ASSIGN_OR_RETURN( + bool isZero, + protocol_cryptor->IsDecryptLocalElGamalResultZero(ciphertext)); + if (!isZero) { + ASSIGN_OR_RETURN(std::string plaintext, + protocol_cryptor->DecryptLocalElGamal(ciphertext)); + + auto blind_histogram_noiser = GetBlindHistogramNoiser( + request.noise_parameters().dp_params().blind_histogram(), + request.noise_parameters().contributors_count(), + request.noise_mechanism()); + int max_excessive_noise = + blind_histogram_noiser->options().shift_offset * 2 * + request.noise_parameters().total_sketches_count() * + request.noise_parameters().total_sketches_count(); + // The lookup table stores the max_excessive_noise EC points where + // ec_lookup_table[i] = (i+1)*ec_generator. + ASSIGN_OR_RETURN( + std::vector ec_lookup_table, + GetCountValuesPlaintext(max_excessive_noise, request.curve_id())); + // Decrypt the excessive noise using the lookup table. + for (int i = 0; i < ec_lookup_table.size(); i++) { + if (ec_lookup_table[i] == plaintext) { + excessive_noise_count = i + 1; + break; + } + } + } + + ASSIGN_OR_RETURN(auto multithreading_helper, + MultithreadingHelper::CreateMultithreadingHelper( + request.parallelism(), protocol_cryptor_options)); + + ASSIGN_OR_RETURN(std::vector blinded_register_indexes, + GetRollv2BlindedRegisterIndexes( + request.combined_register_vector(), + multithreading_helper->GetProtocolCryptor())); + CompleteReachOnlyExecutionPhaseAtAggregatorResponse response; + + // Counting the number of unique registers. + int64_t non_empty_register_count = + CountUniqueElements(blinded_register_indexes); + + // Excluding the blind histogram, padding noise, and the excessive noise from + // the unique register count. It is guaranteed that if noise is added, then + // there exist publisher noise and padding noise. + if (request.has_noise_parameters()) { + non_empty_register_count -= 2; + } + // Remove the total excessive blind histogram noise. + non_empty_register_count = non_empty_register_count - excessive_noise_count; + // Remove the reach dp noise baseline. + if (request.has_reach_dp_noise_baseline()) { + auto noiser = GetGlobalReachDpNoiser( + request.reach_dp_noise_baseline().global_reach_dp_noise(), + request.reach_dp_noise_baseline().contributors_count(), + request.noise_mechanism()); + const auto& noise_options = noiser->options(); + int64_t global_reach_dp_noise_baseline = + noise_options.shift_offset * noise_options.contributor_count; + non_empty_register_count -= global_reach_dp_noise_baseline; + } + + // If the noise added is less than the baseline, the non empty register count + // could be negative. Make sure that it is non-negative. + if (non_empty_register_count < 0) { + non_empty_register_count = 0; + } + + // Estimate the reach + ASSIGN_OR_RETURN( + int64_t reach, + EstimateReach(request.liquid_legions_parameters().decay_rate(), + request.liquid_legions_parameters().size(), + non_empty_register_count, + request.vid_sampling_interval_width())); + + response.set_reach(reach); + response.set_elapsed_cpu_time_millis(timer.ElapsedMillis()); + return response; +} + +} // namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h new file mode 100644 index 00000000000..a410bf4eae1 --- /dev/null +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h @@ -0,0 +1,93 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_REACH_ONLY_LIQUID_LEGIONS_V2_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_H_ +#define SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_REACH_ONLY_LIQUID_LEGIONS_V2_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_H_ + +#include "absl/status/statusor.h" +#include "wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.pb.h" + +namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 { + +using ::wfa::measurement::internal::duchy::protocol:: + CompleteReachOnlyExecutionPhaseAtAggregatorRequest; +using ::wfa::measurement::internal::duchy::protocol:: + CompleteReachOnlyExecutionPhaseAtAggregatorResponse; +using ::wfa::measurement::internal::duchy::protocol:: + CompleteReachOnlyExecutionPhaseRequest; +using ::wfa::measurement::internal::duchy::protocol:: + CompleteReachOnlyExecutionPhaseResponse; +using ::wfa::measurement::internal::duchy::protocol:: + CompleteReachOnlyInitializationPhaseRequest; +using ::wfa::measurement::internal::duchy::protocol:: + CompleteReachOnlyInitializationPhaseResponse; +using ::wfa::measurement::internal::duchy::protocol:: + CompleteReachOnlySetupPhaseAtAggregatorResponse; +using ::wfa::measurement::internal::duchy::protocol:: + CompleteReachOnlySetupPhaseRequest; +using ::wfa::measurement::internal::duchy::protocol:: + CompleteReachOnlySetupPhaseResponse; + +// Complete work in the initialization phase at both the aggregator and +// non-aggregator workers. More specifically, the worker would generate a random +// set of ElGamal Key pair. +absl::StatusOr +CompleteReachOnlyInitializationPhase( + const CompleteReachOnlyInitializationPhaseRequest& request); + +// Complete work in the setup phase at the non-aggregator workers. More +// specifically, the worker would +// 1. add local noise registers (if configured to). +// 2. shuffle all registers. +// 3. stores the amount of excessive noise that it can remove to the database. +absl::StatusOr CompleteReachOnlySetupPhase( + const CompleteReachOnlySetupPhaseRequest& request); + +// Complete work in the setup phase at the aggregator. More specifically, the +// aggregator would +// 1. add local noise registers (if configured to). +// 2. shuffle all registers. +// 3. sample a Paillier keypair +// 4. encrypt the excessive noise using Paillier encryption. +absl::StatusOr +CompleteReachOnlySetupPhaseAtAggregator( + const CompleteReachOnlySetupPhaseRequest& request); + +// Complete work in the execution phase one at a non-aggregator worker. +// More specifically, the worker would +// 1. blind the positions (decrypt local ElGamal layer and then add another +// layer of deterministic pohlig_hellman encryption. +// 2. re-randomize keys and counts. +// 3. shuffle all registers. +// 4. adds its excessive noise to the ciphertext that stores the aggregated +// excessive noise to be removed. +absl::StatusOr +CompleteReachOnlyExecutionPhase( + const CompleteReachOnlyExecutionPhaseRequest& request); + +// Complete work in the execution phase one at the aggregator worker. +// More specifically, the worker would +// 1. decrypt the local ElGamal encryption on the positions. +// 2. join the registers by positions. +// 3. count the number of unique registers, excluding the blinded histogram +// noise and the publisher noise. +// 4. decrypt the Paillier ciphertext that stores the aggregated excessive +// noise and subtract it from the total register count. +absl::StatusOr +CompleteReachOnlyExecutionPhaseAtAggregator( + const CompleteReachOnlyExecutionPhaseAtAggregatorRequest& request); + +} // namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 + +#endif // SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_REACH_ONLY_LIQUID_LEGIONS_V2_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_H_ diff --git a/src/main/proto/wfa/measurement/internal/duchy/crypto.proto b/src/main/proto/wfa/measurement/internal/duchy/crypto.proto index 8992321dc27..3b4c3061b62 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/crypto.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/crypto.proto @@ -59,25 +59,3 @@ message EncryptionPublicKey { // decrypt messages given a private key. bytes data = 2; } - -// Holds a Paillier Public Key. -message PaillierPublicKey { - // The Paillier modulus n. - optional bytes n = 1; - // Contains the Damgard-Jurik exponent corresponding to this key. The Paillier - // modulus will be n^(s+1), and the message space will be n^s. - optional int32 s = 2; -} - -// Holds a Paillier Private Key. -message PaillierPrivateKey { - // One of the two large prime factors of the Paillier modulus n. - optional bytes p = 1; - - // One of the two large prime factors of the Paillier modulus n. - optional bytes q = 2; - - // Contains the Damgard-Jurik exponent corresponding to this key. The Paillier - // modulus will be n^(s+1), and the message space will be n^s. - optional int32 s = 3; -} diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto index 8e754dc8039..37ec2855620 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto @@ -151,16 +151,6 @@ message ReachOnlyLiquidLegionsSketchAggregationV2 { // TODO(@ple13): delete this field when we switch to use a secure key // store for duchy private keys. ElGamalKeyPair local_elgamal_key = 7; - - // Paillier Private Key used to decrypt the total excessive noise to be - // removed from the register count in the execution phase. - // TODO(@ple13): delete this field when we switch to use a secure key - // store for duchy private keys. Only the aggregator samples and stores this - // key. - PaillierPrivateKey local_paillier_key = 8; - - // Noise to be removed. - int64 excess_noise = 9; } // Details about a particular attempt of running a stage of the LiquidLegionV2 diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto index cff012019dc..a332441e1ca 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto @@ -53,13 +53,20 @@ message CompleteReachOnlySetupPhaseRequest { // The CRV is only needed so the noise can be interleaved and hidden in the // CRV. The registers in the CRV are unchanged, except for their orders. bytes combined_register_vector = 1; + // The elliptical curve to work on. + int64 curve_id = 2; // The parameters required for generating noise registers. // if unset, the worker only shuffles the register without adding any noise. - RegisterNoiseGenerationParameters noise_parameters = 2; + RegisterNoiseGenerationParameters noise_parameters = 3; // The mechanism used to generate noise. - LiquidLegionsV2NoiseConfig.NoiseMechanism noise_mechanism = 3; + LiquidLegionsV2NoiseConfig.NoiseMechanism noise_mechanism = 4; + // Public Key of the composite ElGamal cipher. Used to encrypt the excessive + // noise (which is zero) when noise_parameters is not available. + ElGamalPublicKey composite_el_gamal_public_key = 5; + // The attached encrypted excessive noises. Only for the aggregator. + bytes serialized_excessive_noise_ciphertext = 6; // The maximum number of threads used by crypto actions. - int32 parallelism = 4; + int32 parallelism = 7; } // Response of the CompleteReachOnlySetupPhase method. @@ -68,27 +75,23 @@ message CompleteReachOnlySetupPhaseResponse { // and noise registers. bytes combined_register_vector = 1; // The excessive noise that can be removed in the execution phase. - int64 excessive_noise = 2; + bytes serialized_excessive_noise_ciphertext = 2; // The CPU time of processing the request. int64 elapsed_cpu_time_millis = 3; } // Response of the CompleteReachOnlySetupPhase method at the aggregate worker. -// Different from the non-aggregator, the aggregator samples the Paillier key +// Different from the non-aggregator, the aggregator samples the El Gamal key // pair and encrypts its excessive noise with the public key. message CompleteReachOnlySetupPhaseAtAggregatorResponse { // The output combined register vector (CRV), which contains shuffled input // and noise registers. bytes combined_register_vector = 1; - // The Paillier private key. - PaillierPrivateKey paillier_private_key = 2; - // The Paillier public key. - PaillierPublicKey paillier_public_key = 3; - // The serialized Paillier ciphertext that encrypts the aggregated excessive + // The serialized El Gamal ciphertext that encrypts the aggregated excessive // noise of the aggregator. - bytes serialized_aggregated_noise_ciphertext = 4; + bytes serialized_excessive_noise_ciphertext = 2; // The CPU time of processing the request. - int64 elapsed_cpu_time_millis = 5; + int64 elapsed_cpu_time_millis = 3; } // The request to complete work in the execution phase at a non-aggregator @@ -100,21 +103,13 @@ message CompleteReachOnlyExecutionPhaseRequest { bytes combined_register_vector = 1; // Key pair of the local ElGamal cipher. Required. ElGamalKeyPair local_el_gamal_key_pair = 2; - // Public Key of the composite ElGamal cipher. Used to re-randomize the keys - // and counts. - ElGamalPublicKey composite_el_gamal_public_key = 3; // The elliptical curve to work on. - int64 curve_id = 4; - // The excessive noise that will be removed. The noise was computed during the - // Setup phase, and stored in the database. - int64 excessive_noise = 5; - // The Paillier public key. - PaillierPublicKey paillier_public_key = 6; - // The serialized Paillier ciphertext that encrypts the aggregated excessive + int64 curve_id = 3; + // The serialized El Gamal ciphertext that encrypts the aggregated excessive // noise. - bytes serialized_aggregated_noise_ciphertext = 7; + bytes serialized_excessive_noise_ciphertext = 4; // The maximum number of threads used by crypto actions. - int32 parallelism = 8; + int32 parallelism = 5; } // Response of the CompleteReachOnlyExecution method. @@ -124,9 +119,9 @@ message CompleteReachOnlyExecutionPhaseResponse { // bytes ElGamal ciphertext. In other words, the CRV size should be divisible // by 66. bytes combined_register_vector = 1; - // The serialized Paillier ciphertext that encrypts the aggregated excessive + // The serialized El Gamal ciphertext that encrypts the aggregated excessive // noise. - bytes serialized_aggregated_noise_ciphertext = 2; + bytes serialized_excessive_noise_ciphertext = 2; // The CPU time of processing the request. int64 elapsed_cpu_time_millis = 3; } @@ -142,29 +137,33 @@ message CompleteReachOnlyExecutionPhaseAtAggregatorRequest { bytes combined_register_vector = 1; // Key pair of the local ElGamal cipher. Required. ElGamalKeyPair local_el_gamal_key_pair = 2; - // Public Key of the composite ElGamal cipher. Used to encrypt the random - // numbers in SameKeyAggregation. - ElGamalPublicKey composite_el_gamal_public_key = 3; // The elliptical curve to work on. - int64 curve_id = 4; - // The Paillier private key to decrypt the aggregated noise ciphertext. - PaillierPrivateKey paillier_private_key = 5; - // The serialized Paillier ciphertext that encrypts the aggregated excessive + int64 curve_id = 3; + // The serialized El Gamal ciphertext that encrypts the aggregated excessive // noise. - bytes serialized_aggregated_noise_ciphertext = 6; + bytes serialized_excessive_noise_ciphertext = 4; + // Parameters for computing the noise baseline of the global reach DP noise + // registers added in the setup phase. + // The baseline is subtracted before reach is estimated. + GlobalReachDpNoiseBaseline reach_dp_noise_baseline = 5; // LiquidLegions parameters used for reach estimation. - LiquidLegionsSketchParameters liquid_legions_parameters = 7; + LiquidLegionsSketchParameters liquid_legions_parameters = 6; // The sampling rate to be used by the LiquidLegionsV2 protocol. // This is taken from the VidSamplingInterval.width parameter in the // MeasurementSpec. - float vid_sampling_interval_width = 8; + float vid_sampling_interval_width = 7; + // The parameters required for generating noise registers. + // if unset, the worker only shuffles the register without adding any noise. + RegisterNoiseGenerationParameters noise_parameters = 8; + // The mechanism used to generate noise in previous phases. + LiquidLegionsV2NoiseConfig.NoiseMechanism noise_mechanism = 9; // The maximum number of threads used by crypto actions. - int32 parallelism = 9; + int32 parallelism = 10; } // The response of the CompleteReachOnlyExecutionAtAggregator method. message CompleteReachOnlyExecutionPhaseAtAggregatorResponse { - // The number of register count. + // The estimated reach. int64 reach = 1; // The CPU time of processing the request. int64 elapsed_cpu_time_millis = 2; diff --git a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel index c1b7631349a..ebf6eba79b1 100644 --- a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel +++ b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel @@ -1,51 +1,70 @@ load("@rules_cc//cc:defs.bzl", "cc_test") -cc_test( - name = "liquid_legions_v2_encryption_utility_test", - size = "small", +cc_test(name = "liquid_legions_v2_encryption_utility_test", size = "small", timeout = "moderate", - srcs = [ - "liquid_legions_v2_encryption_utility_test.cc", - ], - deps = [ - "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:liquid_legions_v2_encryption_utility", - "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_cc_proto", - "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", - "@any_sketch//src/main/cc/estimation:estimators", - "@any_sketch//src/main/proto/wfa/any_sketch:sketch_cc_proto", - "@com_google_googletest//:gtest", - "@com_google_googletest//:gtest_main", - "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", - ], -) + srcs = + [ + "liquid_legions_v2_encryption_utility_test.cc", + ], + deps = + [ + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:liquid_legions_v2_encryption_utility", + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:liquid_legions_v2_encryption_utility_helper", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_cc_proto", + "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", + "@any_sketch//src/main/cc/estimation:estimators", + "@any_sketch//src/main/proto/wfa/any_sketch:sketch_cc_proto", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", + ], ) cc_test( - name = "noise_parameters_computation_test", - size = "small", - srcs = [ - "noise_parameters_computation_test.cc", - ], - deps = [ - "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:noise_parameters_computation", - "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_cc_proto", - "@any_sketch//src/main/cc/math:distributed_discrete_gaussian_noiser", - "@any_sketch//src/main/cc/math:distributed_geometric_noiser", - "@com_google_googletest//:gtest", - "@com_google_googletest//:gtest_main", - ], -) + name = "reach_only_liquid_legions_v2_encryption_utility_test", + size = "small", timeout = "moderate", + srcs = + [ + "reach_only_liquid_legions_v2_encryption_utility_test.cc", + ], + deps = + [ + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:reach_only_liquid_legions_v2_encryption_utility", + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:liquid_legions_v2_encryption_utility_helper", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_cc_proto", + "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", + "@any_sketch//src/main/cc/estimation:estimators", + "@any_sketch//src/main/proto/wfa/any_sketch:sketch_cc_proto", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", + ], ) + +cc_test( + name = "noise_parameters_computation_test", size = "small", + srcs = + [ + "noise_parameters_computation_test.cc", + ], + deps = + [ + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:noise_parameters_computation", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_cc_proto", + "@any_sketch//src/main/cc/math:distributed_discrete_gaussian_noiser", + "@any_sketch//src/main/cc/math:distributed_geometric_noiser", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], ) cc_test( - name = "multithreading_helper_test", - size = "small", - srcs = [ - "multithreading_helper_test.cc", - ], + name = "multithreading_helper_test", size = "small", + srcs = + [ + "multithreading_helper_test.cc", + ], deps = [ "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:multithreading_helper", "@com_google_absl//absl/functional:any_invocable", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", - ], -) + ], ) diff --git a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_test.cc b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_test.cc index d3c9696f8ef..1921500f5b2 100644 --- a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_test.cc +++ b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_test.cc @@ -30,6 +30,7 @@ #include "wfa/measurement/common/crypto/constants.h" #include "wfa/measurement/common/crypto/ec_point_util.h" #include "wfa/measurement/common/crypto/encryption_utility_helper.h" +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h" #include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2_encryption_methods.pb.h" namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 { @@ -86,39 +87,6 @@ void AddRegister(Sketch* sketch, const int index, const int key, register_ptr->add_values(count); } -::wfa::any_sketch::crypto::ElGamalPublicKey ToAnysketchElGamalKey( - ElGamalPublicKey key) { - ::wfa::any_sketch::crypto::ElGamalPublicKey result; - result.set_generator(key.generator()); - result.set_element(key.element()); - return result; -} - -ElGamalPublicKey ToCmmsElGamalKey( - ::wfa::any_sketch::crypto::ElGamalPublicKey key) { - ElGamalPublicKey result; - result.set_generator(key.generator()); - result.set_element(key.element()); - return result; -} - -Sketch CreateEmptyLiquidLegionsSketch() { - Sketch plain_sketch; - plain_sketch.mutable_config()->add_values()->set_aggregator( - SketchConfig::ValueSpec::UNIQUE); - plain_sketch.mutable_config()->add_values()->set_aggregator( - SketchConfig::ValueSpec::SUM); - return plain_sketch; -} - -DifferentialPrivacyParams MakeDifferentialPrivacyParams(double epsilon, - double delta) { - DifferentialPrivacyParams params; - params.set_epsilon(epsilon); - params.set_delta(delta); - return params; -} - // Partition the char vector 33 by 33, and convert the results to strings std::vector GetCipherStrings(absl::string_view bytes) { ABSL_ASSERT(bytes.size() % 66 == 0); diff --git a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_test.cc b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_test.cc new file mode 100644 index 00000000000..8e4a1e3af9b --- /dev/null +++ b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_test.cc @@ -0,0 +1,654 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h" + +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "any_sketch/crypto/sketch_encrypter.h" +#include "common_cpp/testing/status_macros.h" +#include "common_cpp/testing/status_matchers.h" +#include "estimation/estimators.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "openssl/obj_mac.h" +#include "private_join_and_compute/crypto/commutative_elgamal.h" +#include "private_join_and_compute/crypto/ec_commutative_cipher.h" +#include "wfa/any_sketch/sketch.pb.h" +#include "wfa/measurement/common/crypto/constants.h" +#include "wfa/measurement/common/crypto/ec_point_util.h" +#include "wfa/measurement/common/crypto/encryption_utility_helper.h" +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h" +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2_encryption_methods.pb.h" + +namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 { +namespace { + +using ::private_join_and_compute::BigNum; +using ::private_join_and_compute::CommutativeElGamal; +using ::private_join_and_compute::Context; +using ::private_join_and_compute::ECCommutativeCipher; +using ::private_join_and_compute::ECGroup; +using ::private_join_and_compute::ECPoint; +using ::testing::DoubleNear; +using ::testing::Pair; +using ::testing::SizeIs; +using ::testing::UnorderedElementsAre; +using ::wfa::any_sketch::Sketch; +using ::wfa::any_sketch::SketchConfig; +using ::wfa::measurement::common::crypto::ElGamalCiphertext; +using ::wfa::measurement::common::crypto::ExtractElGamalCiphertextFromString; +using ::wfa::measurement::common::crypto::GetCountValuesPlaintext; +using ::wfa::measurement::common::crypto::kBlindedHistogramNoiseRegisterKey; +using ::wfa::measurement::common::crypto::kBytesPerCipherText; +using ::wfa::measurement::common::crypto::kDestroyedRegisterKey; +using ::wfa::measurement::common::crypto::kPaddingNoiseRegisterId; +using ::wfa::measurement::common::crypto::kPublisherNoiseRegisterId; +using ::wfa::measurement::internal::duchy::DifferentialPrivacyParams; +using ::wfa::measurement::internal::duchy::ElGamalKeyPair; +using ::wfa::measurement::internal::duchy::ElGamalPublicKey; +using ::wfa::measurement::internal::duchy::protocol::LiquidLegionsV2NoiseConfig; + +constexpr int kWorkerCount = 3; +constexpr int kPublisherCount = 3; +constexpr int kMaxFrequency = 10; +constexpr int kTestCurveId = NID_X9_62_prime256v1; +constexpr int kParallelism = 3; +constexpr int kBytesPerEcPoint = 33; +constexpr int kBytesCipherText = kBytesPerEcPoint * 2; +constexpr int kDecayRate = 12; +constexpr int kLiquidLegionsSize = 100 * 1000; +constexpr float kVidSamplingIntervalWidth = 0.5; + +struct ReachOnlyMpcResult { + int64_t reach; +}; + +void AddRegister(Sketch* sketch, const int index) { + auto register_ptr = sketch->add_registers(); + register_ptr->set_index(index); +} + +MATCHER_P(IsBlockSorted, block_size, "") { + if (arg.length() % block_size != 0) { + return false; + } + for (size_t i = block_size; i < arg.length(); i += block_size) { + if (arg.substr(i, block_size) < arg.substr(i - block_size, block_size)) { + return false; + } + } + return true; +} + +// The ReachOnlyTest generates cipher keys for 3 duchies, and the combined +// public key for the data providers. +class ReachOnlyTest { + public: + ElGamalKeyPair duchy_1_el_gamal_key_pair_; + std::string duchy_1_p_h_key_; + ElGamalKeyPair duchy_2_el_gamal_key_pair_; + std::string duchy_2_p_h_key_; + ElGamalKeyPair duchy_3_el_gamal_key_pair_; + std::string duchy_3_p_h_key_; + ElGamalPublicKey client_el_gamal_public_key_; // combined from 3 duchy keys; + ElGamalPublicKey duchy_2_3_composite_public_key_; // combined from duchy 2 + // and duchy_3 public keys; + std::unique_ptr sketch_encrypter_; + + ReachOnlyTest() { + CompleteReachOnlyInitializationPhaseRequest + complete_reach_only_initialization_phase_request; + complete_reach_only_initialization_phase_request.set_curve_id(kTestCurveId); + + duchy_1_el_gamal_key_pair_ = + CompleteReachOnlyInitializationPhase( + complete_reach_only_initialization_phase_request) + ->el_gamal_key_pair(); + duchy_2_el_gamal_key_pair_ = + CompleteReachOnlyInitializationPhase( + complete_reach_only_initialization_phase_request) + ->el_gamal_key_pair(); + duchy_3_el_gamal_key_pair_ = + CompleteReachOnlyInitializationPhase( + complete_reach_only_initialization_phase_request) + ->el_gamal_key_pair(); + + // Combine the el_gamal keys from all duchies to generate the data provider + // el_gamal key. + client_el_gamal_public_key_ = ToCmmsElGamalKey( + any_sketch::crypto::CombineElGamalPublicKeys( + kTestCurveId, + {ToAnysketchElGamalKey(duchy_1_el_gamal_key_pair_.public_key()), + ToAnysketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), + ToAnysketchElGamalKey(duchy_3_el_gamal_key_pair_.public_key())}) + .value()); + duchy_2_3_composite_public_key_ = ToCmmsElGamalKey( + any_sketch::crypto::CombineElGamalPublicKeys( + kTestCurveId, + {ToAnysketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), + ToAnysketchElGamalKey(duchy_3_el_gamal_key_pair_.public_key())}) + .value()); + + any_sketch::crypto::CiphertextString client_public_key = { + .u = client_el_gamal_public_key_.generator(), + .e = client_el_gamal_public_key_.element(), + }; + + // Create a sketch_encrypter for encrypting plaintext any_sketch data. + sketch_encrypter_ = any_sketch::crypto::CreateWithPublicKey( + kTestCurveId, kMaxFrequency, client_public_key) + .value(); + } + + absl::StatusOr EncryptWithFlaggedKey(const Sketch& sketch) { + return sketch_encrypter_->Encrypt( + sketch, any_sketch::crypto::EncryptSketchRequest::FLAGGED_KEY); + } + + // Helper function to go through the entire MPC protocol using the input data. + // The final ReachOnlyMpcResult are returned. + absl::StatusOr GoThroughEntireMpcProtocol( + const std::string& encrypted_sketch, + RegisterNoiseGenerationParameters* reach_noise_parameters, + LiquidLegionsV2NoiseConfig::NoiseMechanism noise_mechanism) { + // Setup phase at Duchy 1. + // We assume all test data comes from duchy 1 in the test. + CompleteReachOnlySetupPhaseRequest + complete_reach_only_setup_phase_request_1; + complete_reach_only_setup_phase_request_1.set_combined_register_vector( + encrypted_sketch); + + if (reach_noise_parameters != nullptr) { + *complete_reach_only_setup_phase_request_1.mutable_noise_parameters() = + *reach_noise_parameters; + } + complete_reach_only_setup_phase_request_1.set_noise_mechanism( + noise_mechanism); + complete_reach_only_setup_phase_request_1.set_curve_id(kTestCurveId); + *complete_reach_only_setup_phase_request_1 + .mutable_composite_el_gamal_public_key() = client_el_gamal_public_key_; + complete_reach_only_setup_phase_request_1.set_parallelism(kParallelism); + + ASSIGN_OR_RETURN( + CompleteReachOnlySetupPhaseResponse + complete_reach_only_setup_phase_response_1, + CompleteReachOnlySetupPhase(complete_reach_only_setup_phase_request_1)); + EXPECT_THAT( + complete_reach_only_setup_phase_response_1.combined_register_vector(), + IsBlockSorted(kBytesCipherText)); + + // Setup phase at Duchy 2. + // We assume all test data comes from duchy 1 in the test, so there is only + // noise from duchy 2 (if configured) + CompleteReachOnlySetupPhaseRequest + complete_reach_only_setup_phase_request_2; + if (reach_noise_parameters != nullptr) { + *complete_reach_only_setup_phase_request_2.mutable_noise_parameters() = + *reach_noise_parameters; + } + complete_reach_only_setup_phase_request_2.set_noise_mechanism( + noise_mechanism); + complete_reach_only_setup_phase_request_2.set_curve_id(kTestCurveId); + *complete_reach_only_setup_phase_request_2 + .mutable_composite_el_gamal_public_key() = client_el_gamal_public_key_; + complete_reach_only_setup_phase_request_2.set_parallelism(kParallelism); + + ASSIGN_OR_RETURN( + CompleteReachOnlySetupPhaseResponse + complete_reach_only_setup_phase_response_2, + CompleteReachOnlySetupPhase(complete_reach_only_setup_phase_request_2)); + EXPECT_THAT( + complete_reach_only_setup_phase_response_2.combined_register_vector(), + IsBlockSorted(kBytesCipherText)); + + // Setup phase at Duchy 3. + // We assume all test data comes from duchy 1 in the test, so there is only + // noise from duchy 3 (if configured) + CompleteReachOnlySetupPhaseRequest + complete_reach_only_setup_phase_request_3; + if (reach_noise_parameters != nullptr) { + *complete_reach_only_setup_phase_request_3.mutable_noise_parameters() = + *reach_noise_parameters; + } + complete_reach_only_setup_phase_request_3.set_curve_id(kTestCurveId); + complete_reach_only_setup_phase_request_3.set_noise_mechanism( + noise_mechanism); + *complete_reach_only_setup_phase_request_3 + .mutable_composite_el_gamal_public_key() = client_el_gamal_public_key_; + complete_reach_only_setup_phase_request_3.set_parallelism(kParallelism); + + std::string serialized_excessive_noise_ciphertext = + absl::StrCat(complete_reach_only_setup_phase_response_1 + .serialized_excessive_noise_ciphertext(), + complete_reach_only_setup_phase_response_2 + .serialized_excessive_noise_ciphertext()); + complete_reach_only_setup_phase_request_3 + .set_serialized_excessive_noise_ciphertext( + serialized_excessive_noise_ciphertext); + + // Combine all CRVs from the workers. + std::string combine_data = absl::StrCat( + complete_reach_only_setup_phase_response_1.combined_register_vector(), + complete_reach_only_setup_phase_response_2.combined_register_vector()); + complete_reach_only_setup_phase_request_3.set_combined_register_vector( + combine_data); + + ASSIGN_OR_RETURN(CompleteReachOnlySetupPhaseResponse + complete_reach_only_setup_phase_response_3, + CompleteReachOnlySetupPhaseAtAggregator( + complete_reach_only_setup_phase_request_3)); + EXPECT_THAT( + complete_reach_only_setup_phase_response_3.combined_register_vector(), + IsBlockSorted(kBytesCipherText)); + + // Execution phase at duchy 1 (non-aggregator). + CompleteReachOnlyExecutionPhaseRequest + complete_reach_only_execution_phase_request_1; + *complete_reach_only_execution_phase_request_1 + .mutable_local_el_gamal_key_pair() = duchy_1_el_gamal_key_pair_; + complete_reach_only_execution_phase_request_1.set_curve_id(kTestCurveId); + *complete_reach_only_execution_phase_request_1 + .mutable_serialized_excessive_noise_ciphertext() = + complete_reach_only_setup_phase_response_3 + .serialized_excessive_noise_ciphertext(); + complete_reach_only_execution_phase_request_1.set_parallelism(kParallelism); + complete_reach_only_execution_phase_request_1.set_combined_register_vector( + complete_reach_only_setup_phase_response_3.combined_register_vector()); + complete_reach_only_execution_phase_request_1 + .set_serialized_excessive_noise_ciphertext( + complete_reach_only_setup_phase_response_3 + .serialized_excessive_noise_ciphertext()); + ASSIGN_OR_RETURN(CompleteReachOnlyExecutionPhaseResponse + complete_reach_only_execution_phase_response_1, + CompleteReachOnlyExecutionPhase( + complete_reach_only_execution_phase_request_1)); + EXPECT_THAT(complete_reach_only_execution_phase_response_1 + .combined_register_vector(), + IsBlockSorted(kBytesCipherText)); + + // Execution phase at duchy 2 (non-aggregator). + CompleteReachOnlyExecutionPhaseRequest + complete_reach_only_execution_phase_request_2; + *complete_reach_only_execution_phase_request_2 + .mutable_local_el_gamal_key_pair() = duchy_2_el_gamal_key_pair_; + complete_reach_only_execution_phase_request_2.set_curve_id(kTestCurveId); + *complete_reach_only_execution_phase_request_2 + .mutable_serialized_excessive_noise_ciphertext() = + complete_reach_only_execution_phase_response_1 + .serialized_excessive_noise_ciphertext(); + complete_reach_only_execution_phase_request_2.set_parallelism(kParallelism); + complete_reach_only_execution_phase_request_2.set_combined_register_vector( + complete_reach_only_execution_phase_response_1 + .combined_register_vector()); + complete_reach_only_execution_phase_request_2 + .set_serialized_excessive_noise_ciphertext( + complete_reach_only_execution_phase_response_1 + .serialized_excessive_noise_ciphertext()); + ASSIGN_OR_RETURN(CompleteReachOnlyExecutionPhaseResponse + complete_execution_phase_one_response_2, + CompleteReachOnlyExecutionPhase( + complete_reach_only_execution_phase_request_2)); + EXPECT_THAT( + complete_execution_phase_one_response_2.combined_register_vector(), + IsBlockSorted(kBytesCipherText)); + + // Execution phase at duchy 3 (aggregator). + CompleteReachOnlyExecutionPhaseAtAggregatorRequest + complete_reach_only_execution_phase_at_aggregator_request; + complete_reach_only_execution_phase_at_aggregator_request + .set_combined_register_vector( + complete_execution_phase_one_response_2.combined_register_vector()); + *complete_reach_only_execution_phase_at_aggregator_request + .mutable_local_el_gamal_key_pair() = duchy_3_el_gamal_key_pair_; + complete_reach_only_execution_phase_at_aggregator_request.set_curve_id( + kTestCurveId); + complete_reach_only_execution_phase_at_aggregator_request.set_parallelism( + kParallelism); + *complete_reach_only_execution_phase_at_aggregator_request + .mutable_serialized_excessive_noise_ciphertext() = + complete_execution_phase_one_response_2 + .serialized_excessive_noise_ciphertext(); + if (reach_noise_parameters != nullptr) { + complete_reach_only_execution_phase_at_aggregator_request + .mutable_reach_dp_noise_baseline() + ->set_contributors_count(3); + *complete_reach_only_execution_phase_at_aggregator_request + .mutable_reach_dp_noise_baseline() + ->mutable_global_reach_dp_noise() = + reach_noise_parameters->dp_params().global_reach_dp_noise(); + *complete_reach_only_execution_phase_at_aggregator_request + .mutable_noise_parameters() = *reach_noise_parameters; + } + complete_reach_only_execution_phase_at_aggregator_request + .mutable_liquid_legions_parameters() + ->set_decay_rate(kDecayRate); + complete_reach_only_execution_phase_at_aggregator_request + .mutable_liquid_legions_parameters() + ->set_size(kLiquidLegionsSize); + complete_reach_only_execution_phase_at_aggregator_request + .set_vid_sampling_interval_width(kVidSamplingIntervalWidth); + complete_reach_only_execution_phase_at_aggregator_request + .set_noise_mechanism(noise_mechanism); + complete_reach_only_execution_phase_at_aggregator_request + .set_serialized_excessive_noise_ciphertext( + complete_execution_phase_one_response_2 + .serialized_excessive_noise_ciphertext()); + ASSIGN_OR_RETURN( + CompleteReachOnlyExecutionPhaseAtAggregatorResponse + complete_reach_only_execution_phase_at_aggregator_response, + CompleteReachOnlyExecutionPhaseAtAggregator( + complete_reach_only_execution_phase_at_aggregator_request)); + + ReachOnlyMpcResult result; + result.reach = + complete_reach_only_execution_phase_at_aggregator_response.reach(); + return result; + } +}; + +TEST(CompleteReachOnlySetupPhase, WrongInputSketchSizeShouldThrow) { + ReachOnlyTest test_data; + CompleteReachOnlySetupPhaseRequest request; + request.set_curve_id(kTestCurveId); + *request.mutable_composite_el_gamal_public_key() = + test_data.client_el_gamal_public_key_; + request.set_combined_register_vector("1234"); + request.set_parallelism(kParallelism); + + auto result = CompleteReachOnlySetupPhase(request); + ASSERT_FALSE(result.ok()); + EXPECT_THAT(result.status(), + StatusIs(absl::StatusCode::kInvalidArgument, "not divisible")); +} + +TEST(CompleteReachOnlySetupPhase, SetupPhaseWorksAsExpectedWithoutNoise) { + ReachOnlyTest test_data; + CompleteReachOnlySetupPhaseRequest request; + request.set_curve_id(kTestCurveId); + *request.mutable_composite_el_gamal_public_key() = + test_data.client_el_gamal_public_key_; + request.set_parallelism(kParallelism); + + std::string register1 = "abc"; + std::string register2 = "def"; + for (int i = 3; i < kBytesCipherText; i++) { + register1 = register1 + " "; + register2 = register2 + " "; + } + std::string registers = register1 + register2; + + request.set_combined_register_vector(registers); + + auto result = CompleteReachOnlySetupPhase(request); + ASSERT_TRUE(result.ok()); + + std::string response_crv = result->combined_register_vector(); + EXPECT_EQ(registers, response_crv); + EXPECT_EQ(registers.length(), response_crv.length()); + EXPECT_EQ("abc", response_crv.substr(0, 3)); + EXPECT_EQ("def", response_crv.substr(kBytesCipherText, 3)); +} + +TEST(CompleteReachOnlySetupPhase, SetupPhaseWorksAsExpectedWithGeometricNoise) { + Context ctx; + ASSERT_OK_AND_ASSIGN(ECGroup ec_group, ECGroup::Create(kTestCurveId, &ctx)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr el_gamal_cipher, + CommutativeElGamal::CreateWithNewKeyPair(kTestCurveId)); + ASSERT_OK_AND_ASSIGN(auto public_key_pair, + el_gamal_cipher->GetPublicKeyBytes()); + ElGamalPublicKey public_key; + public_key.set_generator(public_key_pair.first); + public_key.set_element(public_key_pair.second); + + int64_t computed_blinded_histogram_noise_offset = 7; + int64_t computed_publisher_noise_offset = 7; + int64_t computed_reach_dp_noise_offset = 4; + int64_t expected_total_register_count = + computed_publisher_noise_offset * 2 + computed_reach_dp_noise_offset * 2 + + computed_blinded_histogram_noise_offset * kPublisherCount * + (kPublisherCount + 1) + + 2; + + CompleteReachOnlySetupPhaseRequest request; + request.set_curve_id(kTestCurveId); + RegisterNoiseGenerationParameters* noise_parameters = + request.mutable_noise_parameters(); + noise_parameters->set_curve_id(kTestCurveId); + noise_parameters->set_total_sketches_count(kPublisherCount); + noise_parameters->set_contributors_count(kWorkerCount); + *noise_parameters->mutable_composite_el_gamal_public_key() = public_key; + // resulted p ~= 0 , offset = 7 + *noise_parameters->mutable_dp_params()->mutable_blind_histogram() = + MakeDifferentialPrivacyParams(40, std::exp(-80)); + // resulted p ~= 0 , offset = 7 + *noise_parameters->mutable_dp_params()->mutable_noise_for_publisher_noise() = + MakeDifferentialPrivacyParams(40, std::exp(-40)); + // resulted p ~= 0 , offset = 4 + *noise_parameters->mutable_dp_params()->mutable_global_reach_dp_noise() = + MakeDifferentialPrivacyParams(40, std::exp(-80)); + request.set_noise_mechanism(LiquidLegionsV2NoiseConfig::GEOMETRIC); + request.set_parallelism(kParallelism); + + ASSERT_OK_AND_ASSIGN(CompleteReachOnlySetupPhaseResponse response, + CompleteReachOnlySetupPhase(request)); + + // There was no data in the request, so all registers in the response are + // noise. + std::string noises = response.combined_register_vector(); + ASSERT_THAT(noises, + SizeIs(expected_total_register_count * kBytesPerCipherText)); +} + +TEST(CompleteReachOnlySetupPhase, SetupPhaseWorksAsExpectedWithGaussianNoise) { + Context ctx; + ASSERT_OK_AND_ASSIGN(ECGroup ec_group, ECGroup::Create(kTestCurveId, &ctx)); + ASSERT_OK_AND_ASSIGN(std::unique_ptr el_gamal_cipher, + CommutativeElGamal::CreateWithNewKeyPair(kTestCurveId)); + ASSERT_OK_AND_ASSIGN(auto public_key_pair, + el_gamal_cipher->GetPublicKeyBytes()); + ElGamalPublicKey public_key; + public_key.set_generator(public_key_pair.first); + public_key.set_element(public_key_pair.second); + + int64_t computed_blinded_histogram_noise_offset = 3; + int64_t computed_publisher_noise_offset = 2; + int64_t computed_reach_dp_noise_offset = 3; + int64_t expected_total_register_count = + computed_publisher_noise_offset * 2 + computed_reach_dp_noise_offset * 2 + + computed_blinded_histogram_noise_offset * kPublisherCount * + (kPublisherCount + 1) + + 2; + + CompleteReachOnlySetupPhaseRequest request; + request.set_curve_id(kTestCurveId); + RegisterNoiseGenerationParameters* noise_parameters = + request.mutable_noise_parameters(); + noise_parameters->set_curve_id(kTestCurveId); + noise_parameters->set_total_sketches_count(kPublisherCount); + noise_parameters->set_contributors_count(kWorkerCount); + *noise_parameters->mutable_composite_el_gamal_public_key() = public_key; + // resulted sigma_distributed ~= 0.18, offset = 3 + *noise_parameters->mutable_dp_params()->mutable_blind_histogram() = + MakeDifferentialPrivacyParams(40, std::exp(-80)); + // resulted sigma_distributed ~= 0.13, offset = 2 + *noise_parameters->mutable_dp_params()->mutable_noise_for_publisher_noise() = + MakeDifferentialPrivacyParams(40, std::exp(-40)); + // resulted sigma_distributed ~= 0.18, offset = 3 + *noise_parameters->mutable_dp_params()->mutable_global_reach_dp_noise() = + MakeDifferentialPrivacyParams(40, std::exp(-80)); + request.set_noise_mechanism(LiquidLegionsV2NoiseConfig::DISCRETE_GAUSSIAN); + request.set_parallelism(kParallelism); + + ASSERT_OK_AND_ASSIGN(CompleteReachOnlySetupPhaseResponse response, + CompleteReachOnlySetupPhase(request)); + // There was no data in the request, so all registers in the response are + // noise. + std::string noises = response.combined_register_vector(); + ASSERT_THAT(noises, + SizeIs(expected_total_register_count * kBytesPerCipherText)); +} + +TEST(CompleteReachOnlyExecutionPhase, WrongInputSketchSizeShouldThrow) { + CompleteReachOnlyExecutionPhaseRequest request; + request.set_combined_register_vector("1234"); + request.set_parallelism(kParallelism); + + auto result = CompleteReachOnlyExecutionPhase(request); + ASSERT_FALSE(result.ok()); + EXPECT_THAT(result.status(), + StatusIs(absl::StatusCode::kInvalidArgument, "not divisible")); +} + +TEST(CompleteReachOnlyExecutionPhaseAtAggregator, + WrongInputSketchSizeShouldThrow) { + CompleteReachOnlyExecutionPhaseAtAggregatorRequest request; + request.set_curve_id(kTestCurveId); + request.set_combined_register_vector("1234"); + request.set_parallelism(kParallelism); + + auto result = CompleteReachOnlyExecutionPhaseAtAggregator(request); + ASSERT_FALSE(result.ok()); + EXPECT_THAT(result.status(), + StatusIs(absl::StatusCode::kInvalidArgument, "not divisible")); +} + +TEST(EndToEnd, SumOfCountsShouldBeCorrectWithoutNoise) { + ReachOnlyTest test_data; + Sketch plain_sketch = CreateReachOnlyEmptyLiquidLegionsSketch(); + int num_registers = 100; + for (int i = 1; i <= num_registers; i++) { + AddRegister(&plain_sketch, /*index=*/i); + } + + std::string encrypted_sketch = + test_data.EncryptWithFlaggedKey(plain_sketch).value(); + int64_t expected_reach = wfa::estimation::EstimateCardinalityLiquidLegions( + kDecayRate, kLiquidLegionsSize, num_registers, kVidSamplingIntervalWidth); + + ASSERT_OK_AND_ASSIGN(ReachOnlyMpcResult result_with_geometric_noise, + test_data.GoThroughEntireMpcProtocol( + encrypted_sketch, /*reach_noise=*/nullptr, + LiquidLegionsV2NoiseConfig::GEOMETRIC)); + + EXPECT_EQ(result_with_geometric_noise.reach, expected_reach); + + ASSERT_OK_AND_ASSIGN(ReachOnlyMpcResult result_with_gaussian_noise, + test_data.GoThroughEntireMpcProtocol( + encrypted_sketch, /*reach_noise=*/nullptr, + LiquidLegionsV2NoiseConfig::DISCRETE_GAUSSIAN)); + EXPECT_EQ(result_with_gaussian_noise.reach, expected_reach); +} + +TEST(EndToEnd, CombinedCasesWithDeterministicReachDpNoises) { + ReachOnlyTest test_data; + Sketch plain_sketch = CreateReachOnlyEmptyLiquidLegionsSketch(); + int valid_register_count = 30; + for (int i = 1; i <= valid_register_count; i++) { + AddRegister(&plain_sketch, /*index=*/i); + } + + std::string encrypted_sketch = + test_data.EncryptWithFlaggedKey(plain_sketch).value(); + + RegisterNoiseGenerationParameters reach_noise_parameters; + reach_noise_parameters.set_curve_id(kTestCurveId); + reach_noise_parameters.set_total_sketches_count(3); + reach_noise_parameters.set_contributors_count(kWorkerCount); + // For geometric noise, resulted p = 0.716531, offset = 15. + // Random blind histogram noise. + *reach_noise_parameters.mutable_dp_params()->mutable_blind_histogram() = + MakeDifferentialPrivacyParams(0.11, 0.11); + // For geometric noise, resulted p = 0.716531, offset = 10. + // Random noise for publisher noise. + *reach_noise_parameters.mutable_dp_params() + ->mutable_noise_for_publisher_noise() = + MakeDifferentialPrivacyParams(1, 1); + // For geometric noise, resulted p ~= 0 , offset = 3. + // Deterministic reach dp noise. + *reach_noise_parameters.mutable_dp_params()->mutable_global_reach_dp_noise() = + MakeDifferentialPrivacyParams(40, std::exp(-80)); + *reach_noise_parameters.mutable_composite_el_gamal_public_key() = + test_data.client_el_gamal_public_key_; + + int64_t expected_reach = wfa::estimation::EstimateCardinalityLiquidLegions( + kDecayRate, kLiquidLegionsSize, valid_register_count, + kVidSamplingIntervalWidth); + + ASSERT_OK_AND_ASSIGN(ReachOnlyMpcResult result_with_geometric_noise, + test_data.GoThroughEntireMpcProtocol( + encrypted_sketch, &reach_noise_parameters, + LiquidLegionsV2NoiseConfig::GEOMETRIC)); + + EXPECT_EQ(result_with_geometric_noise.reach, expected_reach); + + ASSERT_OK_AND_ASSIGN(ReachOnlyMpcResult result_with_gaussian_noise, + test_data.GoThroughEntireMpcProtocol( + encrypted_sketch, &reach_noise_parameters, + LiquidLegionsV2NoiseConfig::DISCRETE_GAUSSIAN)); + + EXPECT_EQ(result_with_gaussian_noise.reach, expected_reach); +} + +TEST(ReachEstimation, NonDpNoiseShouldNotImpactTheResult) { + ReachOnlyTest test_data; + Sketch plain_sketch = CreateReachOnlyEmptyLiquidLegionsSketch(); + int valid_register_count = 30; + for (int i = 1; i <= valid_register_count; ++i) { + AddRegister(&plain_sketch, /*index =*/i); + } + + RegisterNoiseGenerationParameters reach_noise_parameters; + reach_noise_parameters.set_curve_id(kTestCurveId); + reach_noise_parameters.set_total_sketches_count(kPublisherCount); + reach_noise_parameters.set_contributors_count(kWorkerCount); + // For geometric noise, resulted p = 0.716531, offset = 15. + // Random blind histogram noise. + *reach_noise_parameters.mutable_dp_params()->mutable_blind_histogram() = + MakeDifferentialPrivacyParams(0.1, 0.1); + // For geometric noise, resulted p = 0.716531, offset = 10. + // Random noise for publisher noise. + *reach_noise_parameters.mutable_dp_params() + ->mutable_noise_for_publisher_noise() = + MakeDifferentialPrivacyParams(1, 1); + // For geometric noise, resulted p ~= 0 , offset = 3. + // Deterministic reach dp noise. + *reach_noise_parameters.mutable_dp_params()->mutable_global_reach_dp_noise() = + MakeDifferentialPrivacyParams(40, std::exp(-80)); + *reach_noise_parameters.mutable_composite_el_gamal_public_key() = + test_data.client_el_gamal_public_key_; + + std::string encrypted_sketch = + test_data.EncryptWithFlaggedKey(plain_sketch).value(); + + int64_t expected_reach = wfa::estimation::EstimateCardinalityLiquidLegions( + kDecayRate, kLiquidLegionsSize, valid_register_count, + kVidSamplingIntervalWidth); + + ASSERT_OK_AND_ASSIGN(ReachOnlyMpcResult result_with_geometric_noise, + test_data.GoThroughEntireMpcProtocol( + encrypted_sketch, &reach_noise_parameters, + LiquidLegionsV2NoiseConfig::GEOMETRIC)); + EXPECT_EQ(result_with_geometric_noise.reach, expected_reach); + + ASSERT_OK_AND_ASSIGN(ReachOnlyMpcResult result_with_gaussian_noise, + test_data.GoThroughEntireMpcProtocol( + encrypted_sketch, &reach_noise_parameters, + LiquidLegionsV2NoiseConfig::DISCRETE_GAUSSIAN)); + EXPECT_EQ(result_with_gaussian_noise.reach, expected_reach); +} + +} // namespace +} // namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 From 3e15807d27ac2b0797da32fc671c6b3d1088f979 Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Sat, 22 Jul 2023 02:51:58 -0400 Subject: [PATCH 02/15] Fixed some typo and format files. --- ...liquid_legions_v2_encryption_methods.proto | 2 - .../protocol/liquid_legions_v2/BUILD.bazel | 99 +++++++++---------- 2 files changed, 48 insertions(+), 53 deletions(-) diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto index a332441e1ca..46df969a70f 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto @@ -81,8 +81,6 @@ message CompleteReachOnlySetupPhaseResponse { } // Response of the CompleteReachOnlySetupPhase method at the aggregate worker. -// Different from the non-aggregator, the aggregator samples the El Gamal key -// pair and encrypts its excessive noise with the public key. message CompleteReachOnlySetupPhaseAtAggregatorResponse { // The output combined register vector (CRV), which contains shuffled input // and noise registers. diff --git a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel index ebf6eba79b1..aeac7682ee4 100644 --- a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel +++ b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel @@ -1,70 +1,67 @@ load("@rules_cc//cc:defs.bzl", "cc_test") -cc_test(name = "liquid_legions_v2_encryption_utility_test", size = "small", +cc_test( + name = "liquid_legions_v2_encryption_utility_test", size = "small", timeout = "moderate", - srcs = - [ - "liquid_legions_v2_encryption_utility_test.cc", - ], - deps = - [ - "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:liquid_legions_v2_encryption_utility", - "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:liquid_legions_v2_encryption_utility_helper", - "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_cc_proto", - "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", - "@any_sketch//src/main/cc/estimation:estimators", - "@any_sketch//src/main/proto/wfa/any_sketch:sketch_cc_proto", - "@com_google_googletest//:gtest", - "@com_google_googletest//:gtest_main", - "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", - ], ) + srcs = [ + "liquid_legions_v2_encryption_utility_test.cc", + ], + deps = [ + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:liquid_legions_v2_encryption_utility", + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:liquid_legions_v2_encryption_utility_helper", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_cc_proto", + "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", + "@any_sketch//src/main/cc/estimation:estimators", + "@any_sketch//src/main/proto/wfa/any_sketch:sketch_cc_proto", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", + ], +) cc_test( name = "reach_only_liquid_legions_v2_encryption_utility_test", size = "small", timeout = "moderate", - srcs = - [ - "reach_only_liquid_legions_v2_encryption_utility_test.cc", - ], - deps = - [ - "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:reach_only_liquid_legions_v2_encryption_utility", - "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:liquid_legions_v2_encryption_utility_helper", - "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_cc_proto", - "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", - "@any_sketch//src/main/cc/estimation:estimators", - "@any_sketch//src/main/proto/wfa/any_sketch:sketch_cc_proto", - "@com_google_googletest//:gtest", - "@com_google_googletest//:gtest_main", - "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", - ], ) + srcs = [ + "reach_only_liquid_legions_v2_encryption_utility_test.cc", + ], + deps = [ + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:reach_only_liquid_legions_v2_encryption_utility", + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:liquid_legions_v2_encryption_utility_helper", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_cc_proto", + "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", + "@any_sketch//src/main/cc/estimation:estimators", + "@any_sketch//src/main/proto/wfa/any_sketch:sketch_cc_proto", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", + ], ) cc_test( name = "noise_parameters_computation_test", size = "small", - srcs = - [ - "noise_parameters_computation_test.cc", - ], - deps = - [ - "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:noise_parameters_computation", - "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_cc_proto", - "@any_sketch//src/main/cc/math:distributed_discrete_gaussian_noiser", - "@any_sketch//src/main/cc/math:distributed_geometric_noiser", - "@com_google_googletest//:gtest", - "@com_google_googletest//:gtest_main", - ], ) + srcs = [ + "noise_parameters_computation_test.cc", + ], + deps = [ + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:noise_parameters_computation", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_cc_proto", + "@any_sketch//src/main/cc/math:distributed_discrete_gaussian_noiser", + "@any_sketch//src/main/cc/math:distributed_geometric_noiser", + "@com_google_googletest//:gtest", + "@com_google_googletest//:gtest_main", + ], +) cc_test( name = "multithreading_helper_test", size = "small", - srcs = - [ - "multithreading_helper_test.cc", - ], + srcs = [ + "multithreading_helper_test.cc", + ], deps = [ "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:multithreading_helper", "@com_google_absl//absl/functional:any_invocable", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", - ], ) + ], +) From 655c2fdb08e14d2c51b8957c353d071ad56e7eb1 Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Sat, 22 Jul 2023 02:56:07 -0400 Subject: [PATCH 03/15] Format build file. --- .../duchy/protocol/liquid_legions_v2/BUILD.bazel | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel index aeac7682ee4..4c8e5198f89 100644 --- a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel +++ b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel @@ -1,7 +1,8 @@ load("@rules_cc//cc:defs.bzl", "cc_test") cc_test( - name = "liquid_legions_v2_encryption_utility_test", size = "small", + name = "liquid_legions_v2_encryption_utility_test", + size = "small", timeout = "moderate", srcs = [ "liquid_legions_v2_encryption_utility_test.cc", @@ -21,7 +22,8 @@ cc_test( cc_test( name = "reach_only_liquid_legions_v2_encryption_utility_test", - size = "small", timeout = "moderate", + size = "small", + timeout = "moderate", srcs = [ "reach_only_liquid_legions_v2_encryption_utility_test.cc", ], @@ -38,7 +40,8 @@ cc_test( ], ) cc_test( - name = "noise_parameters_computation_test", size = "small", + name = "noise_parameters_computation_test", + size = "small", srcs = [ "noise_parameters_computation_test.cc", ], @@ -53,7 +56,8 @@ cc_test( ) cc_test( - name = "multithreading_helper_test", size = "small", + name = "multithreading_helper_test", + size = "small", srcs = [ "multithreading_helper_test.cc", ], From e30b04e7ac45a61490042e63c8d36041a5fe173c Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Mon, 24 Jul 2023 01:30:52 -0400 Subject: [PATCH 04/15] Update comments and fix typos. --- .../protocol/liquid_legions_v2/BUILD.bazel | 2 ++ ...id_legions_v2_encryption_utility_helper.cc | 19 ------------------- ...uid_legions_v2_encryption_utility_helper.h | 6 +----- ...nly_liquid_legions_v2_encryption_utility.h | 18 ++++++++---------- 4 files changed, 11 insertions(+), 34 deletions(-) diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel index 2bd3a0af4cb..f3e4b20510e 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel @@ -37,6 +37,7 @@ cc_library( ], strip_include_prefix = _INCLUDE_PREFIX, deps = [ + ":liquid_legions_v2_encryption_utility_helper", ":multithreading_helper", ":noise_parameters_computation", "//src/main/cc/wfa/measurement/common/crypto:constants", @@ -64,6 +65,7 @@ cc_library( ], strip_include_prefix = _INCLUDE_PREFIX, deps = [ + ":liquid_legions_v2_encryption_utility_helper", ":multithreading_helper", ":noise_parameters_computation", "//src/main/cc/wfa/measurement/common/crypto:constants", diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc index 7e4e497ab5a..499dfd5ef78 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc @@ -60,23 +60,4 @@ DifferentialPrivacyParams MakeDifferentialPrivacyParams(double epsilon, return params; } -absl::StatusOr EstimateReach(double liquid_legions_decay_rate, - int64_t liquid_legions_size, - size_t non_empty_register_count, - float sampling_rate) { - if (liquid_legions_decay_rate <= 1.0) { - return absl::InvalidArgumentError(absl::StrCat( - "The decay rate should be > 1, but is ", liquid_legions_decay_rate)); - } - if (liquid_legions_size <= non_empty_register_count) { - return absl::InvalidArgumentError(absl::StrCat( - "liquid legions size (", liquid_legions_size, - ") should be greater then the number of non empty registers (", - non_empty_register_count, ").")); - } - return wfa::estimation::EstimateCardinalityLiquidLegions( - liquid_legions_decay_rate, liquid_legions_size, non_empty_register_count, - sampling_rate); -} - } // namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h index 041912eebec..3e4ec061377 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h @@ -26,11 +26,6 @@ using ::wfa::any_sketch::Sketch; using ::wfa::measurement::internal::duchy::DifferentialPrivacyParams; using ::wfa::measurement::internal::duchy::ElGamalPublicKey; -absl::StatusOr EstimateReach(double liquid_legions_decay_rate, - int64_t liquid_legions_size, - size_t non_empty_register_count, - float sampling_rate = 1.0); - ::wfa::any_sketch::crypto::ElGamalPublicKey ToAnysketchElGamalKey( ElGamalPublicKey key); @@ -43,6 +38,7 @@ Sketch CreateReachOnlyEmptyLiquidLegionsSketch(); DifferentialPrivacyParams MakeDifferentialPrivacyParams(double epsilon, double delta); + } // namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 #endif // SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_HELPER_H_ diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h index a410bf4eae1..3a4ced9dc3d 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h @@ -50,7 +50,8 @@ CompleteReachOnlyInitializationPhase( // specifically, the worker would // 1. add local noise registers (if configured to). // 2. shuffle all registers. -// 3. stores the amount of excessive noise that it can remove to the database. +// 3. encrypt the amount of excessive noise with the composit ElGamal public +// key. absl::StatusOr CompleteReachOnlySetupPhase( const CompleteReachOnlySetupPhaseRequest& request); @@ -58,8 +59,8 @@ absl::StatusOr CompleteReachOnlySetupPhase( // aggregator would // 1. add local noise registers (if configured to). // 2. shuffle all registers. -// 3. sample a Paillier keypair -// 4. encrypt the excessive noise using Paillier encryption. +// 3. encrypt its excessive noise using the composite ElGamal public key. +// 4. combine its noise ciphertext with those from the workers. absl::StatusOr CompleteReachOnlySetupPhaseAtAggregator( const CompleteReachOnlySetupPhaseRequest& request); @@ -68,10 +69,9 @@ CompleteReachOnlySetupPhaseAtAggregator( // More specifically, the worker would // 1. blind the positions (decrypt local ElGamal layer and then add another // layer of deterministic pohlig_hellman encryption. -// 2. re-randomize keys and counts. +// 2. partially decrypt the noise ciphertext using its partial ElGamal +// private key. // 3. shuffle all registers. -// 4. adds its excessive noise to the ciphertext that stores the aggregated -// excessive noise to be removed. absl::StatusOr CompleteReachOnlyExecutionPhase( const CompleteReachOnlyExecutionPhaseRequest& request); @@ -79,11 +79,9 @@ CompleteReachOnlyExecutionPhase( // Complete work in the execution phase one at the aggregator worker. // More specifically, the worker would // 1. decrypt the local ElGamal encryption on the positions. -// 2. join the registers by positions. +// 2. decrypt the total excessive noise. // 3. count the number of unique registers, excluding the blinded histogram -// noise and the publisher noise. -// 4. decrypt the Paillier ciphertext that stores the aggregated excessive -// noise and subtract it from the total register count. +// noise, the publisher noise, and the excessive noise. absl::StatusOr CompleteReachOnlyExecutionPhaseAtAggregator( const CompleteReachOnlyExecutionPhaseAtAggregatorRequest& request); From 7ed4d9b90feed62cb037f138c8a8a21706d3b556 Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Mon, 24 Jul 2023 02:44:14 -0400 Subject: [PATCH 05/15] Fix cpplint. --- .../common/crypto/encryption_utility_helper.cc | 1 + .../duchy/protocol/liquid_legions_v2/BUILD.bazel | 9 ++++----- .../reach_only_liquid_legions_v2_encryption_utility.h | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc index 43a22d736d6..de1e0d44e77 100644 --- a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc +++ b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc @@ -14,6 +14,7 @@ #include "wfa/measurement/common/crypto/encryption_utility_helper.h" +#include #include #include "absl/status/status.h" diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel index f3e4b20510e..4d6b5ab38d7 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel @@ -19,11 +19,11 @@ cc_library( ], strip_include_prefix = _INCLUDE_PREFIX, deps = [ - "@any_sketch//src/main/cc/estimation:estimators", - "@com_google_absl//absl/status:statusor", - "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", "//src/main/proto/wfa/measurement/internal/duchy:crypto_cc_proto", "//src/main/proto/wfa/measurement/internal/duchy:differential_privacy_cc_proto", + "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", + "@any_sketch//src/main/cc/estimation:estimators", + "@com_google_absl//absl/status:statusor", ], ) @@ -82,7 +82,7 @@ cc_library( "@wfa_common_cpp//src/main/cc/common_cpp/time:started_thread_cpu_timer", ], ) - + cc_library( name = "liquid_legions_v2_encryption_utility_wrapper", srcs = [ @@ -135,4 +135,3 @@ cc_library( "@wfa_common_cpp//src/main/cc/common_cpp/macros", ], ) - diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h index 3a4ced9dc3d..afb5f8fcfa6 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_REACH_ONLY_LIQUID_LEGIONS_V2_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_H_ -#define SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_REACH_ONLY_LIQUID_LEGIONS_V2_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_H_ +#ifndef SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_REACH_ONLY_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_H_ +#define SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_REACH_ONLY_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_H_ #include "absl/status/statusor.h" #include "wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.pb.h" @@ -88,4 +88,4 @@ CompleteReachOnlyExecutionPhaseAtAggregator( } // namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 -#endif // SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_REACH_ONLY_LIQUID_LEGIONS_V2_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_H_ +#endif // SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_REACH_ONLY_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_H_ From c453df7ef9e72de352a7fefe504db48412e85882 Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Mon, 24 Jul 2023 17:19:26 -0400 Subject: [PATCH 06/15] Fixed typo and cpplint. --- .../liquid_legions_v2_encryption_utility_helper.cc | 2 +- .../liquid_legions_v2_encryption_utility_helper.h | 2 +- .../duchy/protocol/liquid_legions_v2/BUILD.bazel | 3 ++- .../liquid_legions_v2_encryption_utility_test.cc | 10 +++++----- ...h_only_liquid_legions_v2_encryption_utility_test.cc | 10 +++++----- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc index 499dfd5ef78..f0b59c717ab 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc @@ -22,7 +22,7 @@ using ::wfa::any_sketch::Sketch; using ::wfa::any_sketch::SketchConfig; using ::wfa::measurement::internal::duchy::ElGamalPublicKey; -::wfa::any_sketch::crypto::ElGamalPublicKey ToAnysketchElGamalKey( +::wfa::any_sketch::crypto::ElGamalPublicKey ToAnySketchElGamalKey( ElGamalPublicKey key) { ::wfa::any_sketch::crypto::ElGamalPublicKey result; result.set_generator(key.generator()); diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h index 3e4ec061377..091d6c1c79e 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h @@ -26,7 +26,7 @@ using ::wfa::any_sketch::Sketch; using ::wfa::measurement::internal::duchy::DifferentialPrivacyParams; using ::wfa::measurement::internal::duchy::ElGamalPublicKey; -::wfa::any_sketch::crypto::ElGamalPublicKey ToAnysketchElGamalKey( +::wfa::any_sketch::crypto::ElGamalPublicKey ToAnySketchElGamalKey( ElGamalPublicKey key); ElGamalPublicKey ToCmmsElGamalKey( diff --git a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel index 4c8e5198f89..b17fe2d0751 100644 --- a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel +++ b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel @@ -37,7 +37,8 @@ cc_test( "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", "@wfa_common_cpp//src/main/cc/common_cpp/testing:status", - ], ) + ], +) cc_test( name = "noise_parameters_computation_test", diff --git a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_test.cc b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_test.cc index 1921500f5b2..b45b23ce045 100644 --- a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_test.cc +++ b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_test.cc @@ -261,15 +261,15 @@ class TestData { client_el_gamal_public_key_ = ToCmmsElGamalKey( any_sketch::crypto::CombineElGamalPublicKeys( kTestCurveId, - {ToAnysketchElGamalKey(duchy_1_el_gamal_key_pair_.public_key()), - ToAnysketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), - ToAnysketchElGamalKey(duchy_3_el_gamal_key_pair_.public_key())}) + {ToAnySketchElGamalKey(duchy_1_el_gamal_key_pair_.public_key()), + ToAnySketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), + ToAnySketchElGamalKey(duchy_3_el_gamal_key_pair_.public_key())}) .value()); duchy_2_3_composite_public_key_ = ToCmmsElGamalKey( any_sketch::crypto::CombineElGamalPublicKeys( kTestCurveId, - {ToAnysketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), - ToAnysketchElGamalKey(duchy_3_el_gamal_key_pair_.public_key())}) + {ToAnySketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), + ToAnySketchElGamalKey(duchy_3_el_gamal_key_pair_.public_key())}) .value()); any_sketch::crypto::CiphertextString client_public_key = { diff --git a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_test.cc b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_test.cc index 8e4a1e3af9b..22f54d649a5 100644 --- a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_test.cc +++ b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_test.cc @@ -131,15 +131,15 @@ class ReachOnlyTest { client_el_gamal_public_key_ = ToCmmsElGamalKey( any_sketch::crypto::CombineElGamalPublicKeys( kTestCurveId, - {ToAnysketchElGamalKey(duchy_1_el_gamal_key_pair_.public_key()), - ToAnysketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), - ToAnysketchElGamalKey(duchy_3_el_gamal_key_pair_.public_key())}) + {ToAnySketchElGamalKey(duchy_1_el_gamal_key_pair_.public_key()), + ToAnySketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), + ToAnySketchElGamalKey(duchy_3_el_gamal_key_pair_.public_key())}) .value()); duchy_2_3_composite_public_key_ = ToCmmsElGamalKey( any_sketch::crypto::CombineElGamalPublicKeys( kTestCurveId, - {ToAnysketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), - ToAnysketchElGamalKey(duchy_3_el_gamal_key_pair_.public_key())}) + {ToAnySketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), + ToAnySketchElGamalKey(duchy_3_el_gamal_key_pair_.public_key())}) .value()); any_sketch::crypto::CiphertextString client_public_key = { From a9664fe8105c2cc91d1a1e5f52ab4e0a31c4cd9e Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Wed, 26 Jul 2023 17:32:22 -0400 Subject: [PATCH 07/15] Adding comments, throws error when excessive noise ciphertext cannot be decrypted, making GetBlindedRegisterIndexes and GetRollv2BlindedRegisterIndexes work with MultithreadHelper. --- .../wfa/measurement/common/crypto/BUILD.bazel | 1 + .../crypto/encryption_utility_helper.cc | 45 ++++++++++------- .../common/crypto/encryption_utility_helper.h | 7 ++- .../common/crypto/protocol_cryptor.cc | 49 +++++++++++++------ .../common/crypto/protocol_cryptor.h | 7 +-- .../protocol/liquid_legions_v2/BUILD.bazel | 21 +------- .../liquid_legions_v2_encryption_utility.cc | 8 +-- ...ly_liquid_legions_v2_encryption_utility.cc | 38 ++++++++------ .../liquid_legions_v2/testing/BUILD.bazel | 29 +++++++++++ ...id_legions_v2_encryption_utility_helper.cc | 4 +- ...uid_legions_v2_encryption_utility_helper.h | 8 +-- .../daemon/herald/LiquidLegionsV2Starter.kt | 4 +- ...iquidLegionsSketchAggregationV2Protocol.kt | 4 +- .../computationcontrol/ProtocolStages.kt | 4 +- ...liquid_legions_sketch_aggregation_v2.proto | 2 +- ...liquid_legions_v2_encryption_methods.proto | 3 +- .../protocol/liquid_legions_v2/BUILD.bazel | 4 +- ...quid_legions_v2_encryption_utility_test.cc | 6 +-- ...quid_legions_v2_encryption_utility_test.cc | 18 +++++-- ...etchAggregationV2ProtocolEnumStagesTest.kt | 2 +- 20 files changed, 162 insertions(+), 102 deletions(-) create mode 100644 src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/BUILD.bazel rename src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/{ => testing}/liquid_legions_v2_encryption_utility_helper.cc (95%) rename src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/{ => testing}/liquid_legions_v2_encryption_utility_helper.h (87%) diff --git a/src/main/cc/wfa/measurement/common/crypto/BUILD.bazel b/src/main/cc/wfa/measurement/common/crypto/BUILD.bazel index 825cda1c722..765279f7f4e 100644 --- a/src/main/cc/wfa/measurement/common/crypto/BUILD.bazel +++ b/src/main/cc/wfa/measurement/common/crypto/BUILD.bazel @@ -28,6 +28,7 @@ cc_library( ":constants", ":ec_point_util", ":protocol_cryptor", + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:multithreading_helper", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@wfa_common_cpp//src/main/cc/common_cpp/macros", diff --git a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc index de1e0d44e77..35a0ac999f0 100644 --- a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc +++ b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc @@ -25,6 +25,9 @@ namespace wfa::measurement::common::crypto { +using ::wfa::measurement::internal::duchy::protocol::liquid_legions_v2:: + MultithreadingHelper; + absl::StatusOr GetNumberOfBlocks(absl::string_view data, size_t block_size) { if (block_size == 0) { @@ -49,42 +52,48 @@ absl::StatusOr ExtractElGamalCiphertextFromString( } absl::StatusOr> GetBlindedRegisterIndexes( - absl::string_view data, ProtocolCryptor& protocol_cryptor) { + absl::string_view data, MultithreadingHelper& helper) { ASSIGN_OR_RETURN(size_t register_count, GetNumberOfBlocks(data, kBytesPerCipherRegister)); std::vector blinded_register_indexes; - blinded_register_indexes.reserve(register_count); - for (size_t index = 0; index < register_count; ++index) { - // The size of data_block is guaranteed to be equal to - // kBytesPerCipherText + blinded_register_indexes.resize(register_count); + + absl::AnyInvocable f = + [&](ProtocolCryptor& cryptor, size_t index) -> absl::Status { absl::string_view data_block = data.substr(index * kBytesPerCipherRegister, kBytesPerCipherText); ASSIGN_OR_RETURN(ElGamalCiphertext ciphertext, ExtractElGamalCiphertextFromString(data_block)); ASSIGN_OR_RETURN(std::string decrypted_el_gamal, - protocol_cryptor.DecryptLocalElGamal(ciphertext)); - blinded_register_indexes.push_back(std::move(decrypted_el_gamal)); - } + cryptor.DecryptLocalElGamal(ciphertext)); + blinded_register_indexes[index] = std::move(decrypted_el_gamal); + return absl::OkStatus(); + }; + RETURN_IF_ERROR(helper.Execute(register_count, f)); + return blinded_register_indexes; } absl::StatusOr> GetRollv2BlindedRegisterIndexes( - absl::string_view data, ProtocolCryptor& protocol_cryptor) { + absl::string_view data, MultithreadingHelper& helper) { ASSIGN_OR_RETURN(size_t register_count, GetNumberOfBlocks(data, kBytesPerCipherText)); std::vector blinded_register_indexes; - blinded_register_indexes.reserve(register_count); - for (size_t index = 0; index < register_count; ++index) { - // The size of data_block is guaranteed to be equal to - // kBytesPerCipherText + blinded_register_indexes.resize(register_count); + + absl::AnyInvocable f = + [&](ProtocolCryptor& cryptor, size_t index) -> absl::Status { absl::string_view data_block = data.substr(index * kBytesPerCipherText, kBytesPerCipherText); ASSIGN_OR_RETURN(ElGamalCiphertext ciphertext, ExtractElGamalCiphertextFromString(data_block)); ASSIGN_OR_RETURN(std::string decrypted_el_gamal, - protocol_cryptor.DecryptLocalElGamal(ciphertext)); - blinded_register_indexes.push_back(std::move(decrypted_el_gamal)); - } + cryptor.DecryptLocalElGamal(ciphertext)); + blinded_register_indexes[index] = std::move(decrypted_el_gamal); + return absl::OkStatus(); + }; + RETURN_IF_ERROR(helper.Execute(register_count, f)); + return blinded_register_indexes; } @@ -144,8 +153,8 @@ absl::Status WriteEcPointPairToString(const ElGamalEcPointPair& ec_point_pair, absl::StatusOr GetEcPointPairFromString( absl::string_view str, int curve_id) { - std::unique_ptr context(new Context); - ASSIGN_OR_RETURN(ECGroup ec_group, ECGroup::Create(curve_id, context.get())); + Context ctx; + ASSIGN_OR_RETURN(ECGroup ec_group, ECGroup::Create(curve_id, &ctx)); ASSIGN_OR_RETURN(ElGamalCiphertext ciphertext, ExtractElGamalCiphertextFromString(str)); ASSIGN_OR_RETURN(ElGamalEcPointPair ec_point, diff --git a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.h b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.h index e55e91e1499..834d7b9c2ba 100644 --- a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.h +++ b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.h @@ -23,10 +23,13 @@ #include "absl/strings/string_view.h" #include "wfa/measurement/common/crypto/ec_point_util.h" #include "wfa/measurement/common/crypto/protocol_cryptor.h" +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/multithreading_helper.h" namespace wfa::measurement::common::crypto { using ::wfa::measurement::common::crypto::CompositeType; +using ::wfa::measurement::internal::duchy::protocol::liquid_legions_v2:: + MultithreadingHelper; // A pair of ciphertexts which store the key and count values of a liquidlegions // register. @@ -46,12 +49,12 @@ absl::StatusOr ExtractElGamalCiphertextFromString( // Blinds the last layer of ElGamal Encryption of register indexes, and return // the deterministically encrypted results. absl::StatusOr> GetBlindedRegisterIndexes( - absl::string_view data, ProtocolCryptor& protocol_cryptor); + absl::string_view data, MultithreadingHelper& helper); // Blinds the last layer of ElGamal Encryption of register indexes, and return // the deterministically encrypted results. absl::StatusOr> GetRollv2BlindedRegisterIndexes( - absl::string_view data, ProtocolCryptor& protocol_cryptor); + absl::string_view data, MultithreadingHelper& helper); // Extracts a KeyCountPairCipherText from a string_view. absl::StatusOr ExtractKeyCountPairFromSubstring( diff --git a/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.cc b/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.cc index de677943496..ae1e040b5c4 100644 --- a/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.cc +++ b/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.cc @@ -66,7 +66,7 @@ class ProtocolCryptorImpl : public ProtocolCryptor { CompositeType composite_type) override; absl::StatusOr EncryptCompositeElGamal( absl::string_view plain_ec_point, CompositeType composite_type) override; - absl::StatusOr EncryptIntegerWithCompositElGamalAndWriteToString( + absl::StatusOr EncryptIntegerToStringCompositeElGamal( int64_t value) override; absl::StatusOr ReRandomize( const ElGamalCiphertext& ciphertext, @@ -176,8 +176,7 @@ absl::StatusOr ProtocolCryptorImpl::EncryptCompositeElGamal( } absl::StatusOr -ProtocolCryptorImpl::EncryptIntegerWithCompositElGamalAndWriteToString( - int64_t value) { +ProtocolCryptorImpl::EncryptIntegerToStringCompositeElGamal(int64_t value) { Context ctx; std::string ciphertext; ciphertext.resize(kBytesPerCipherText); @@ -189,11 +188,20 @@ ProtocolCryptorImpl::EncryptIntegerWithCompositElGamalAndWriteToString( ASSIGN_OR_RETURN( ElGamalEcPointPair zero_ec, EncryptIdentityElementToEcPointsCompositeElGamal(CompositeType::kFull)); - std::string temp; - ASSIGN_OR_RETURN(temp, zero_ec.u.ToBytesCompressed()); - ciphertext.replace(0, kBytesPerEcPoint, temp); - ASSIGN_OR_RETURN(temp, zero_ec.e.ToBytesCompressed()); - ciphertext.replace(kBytesPerEcPoint, kBytesPerEcPoint, temp); + + if (absl::StatusOr result = zero_ec.u.ToBytesCompressed(); + result.ok()) { + ciphertext.replace(0, kBytesPerEcPoint, *result); + } else { + return result.status(); + } + + if (absl::StatusOr result = zero_ec.e.ToBytesCompressed(); + result.ok()) { + ciphertext.replace(kBytesPerEcPoint, kBytesPerEcPoint, *result); + } else { + return result.status(); + } } else { ASSIGN_OR_RETURN(ElGamalEcPointPair one_ec, EncryptPlaintextToEcPointsCompositeElGamal( @@ -201,11 +209,20 @@ ProtocolCryptorImpl::EncryptIntegerWithCompositElGamalAndWriteToString( ASSIGN_OR_RETURN( ElGamalEcPointPair point_ec, MultiplyEcPointPairByScalar(one_ec, ctx.CreateBigNum(value))); - std::string temp; - ASSIGN_OR_RETURN(temp, point_ec.u.ToBytesCompressed()); - ciphertext.replace(0, kBytesPerEcPoint, temp); - ASSIGN_OR_RETURN(temp, point_ec.e.ToBytesCompressed()); - ciphertext.replace(kBytesPerEcPoint, kBytesPerEcPoint, temp); + + if (absl::StatusOr result = point_ec.u.ToBytesCompressed(); + result.ok()) { + ciphertext.replace(0, kBytesPerEcPoint, *result); + } else { + return result.status(); + } + + if (absl::StatusOr result = point_ec.e.ToBytesCompressed(); + result.ok()) { + ciphertext.replace(kBytesPerEcPoint, kBytesPerEcPoint, *result); + } else { + return result.status(); + } } return ciphertext; } @@ -286,9 +303,9 @@ absl::Status ProtocolCryptorImpl::BatchProcess(absl::string_view data, } case Action::kPartialDecrypt: { ASSIGN_OR_RETURN(std::string temp, DecryptLocalElGamal(ciphertext)); - // The first part of the ciphertext is the random number which is still - // required to decrypt the other layers of ElGamal encryptions (at the - // subsequent duchies. So we keep it. + // The first part of the ciphertext is the random number which is + // still required to decrypt the other layers of ElGamal encryptions + // (at the subsequent duchies. So we keep it. result.replace(pos, kBytesPerEcPoint, ciphertext.first); pos += kBytesPerEcPoint; result.replace(pos, kBytesPerEcPoint, temp); diff --git a/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.h b/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.h index f1c2c844955..f7a7ecc68e0 100644 --- a/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.h +++ b/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.h @@ -70,9 +70,10 @@ class ProtocolCryptor { // Encrypts the plain EcPoint using the full or partial composite ElGamal Key. virtual absl::StatusOr EncryptCompositeElGamal( absl::string_view plain_ec_point, CompositeType composite_type) = 0; - // Encrypts an integer with the full composite ElGamal Key. - virtual absl::StatusOr - EncryptIntegerWithCompositElGamalAndWriteToString(int64_t value) = 0; + // Maps the integer onto the curve and then encrypts the EcPoint with the full + // composite ElGamal Key, returns the string representation of the ciphertext. + virtual absl::StatusOr EncryptIntegerToStringCompositeElGamal( + int64_t value) = 0; // Encrypts the Identity Element using the full or partial composite ElGamal // Key, returns the result as an ElGamalEcPointPair. virtual absl::StatusOr diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel index 4d6b5ab38d7..79de5b8dfca 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel @@ -4,29 +4,12 @@ load("@wfa_common_jvm//build:defs.bzl", "test_target") package(default_visibility = [ ":__pkg__", test_target(":__pkg__"), + "//src/main/cc/wfa/measurement:__subpackages__", "//src/main/swig/protocol:__subpackages__", ]) _INCLUDE_PREFIX = "/src/main/cc" -cc_library( - name = "liquid_legions_v2_encryption_utility_helper", - srcs = [ - "liquid_legions_v2_encryption_utility_helper.cc", - ], - hdrs = [ - "liquid_legions_v2_encryption_utility_helper.h", - ], - strip_include_prefix = _INCLUDE_PREFIX, - deps = [ - "//src/main/proto/wfa/measurement/internal/duchy:crypto_cc_proto", - "//src/main/proto/wfa/measurement/internal/duchy:differential_privacy_cc_proto", - "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", - "@any_sketch//src/main/cc/estimation:estimators", - "@com_google_absl//absl/status:statusor", - ], -) - cc_library( name = "liquid_legions_v2_encryption_utility", srcs = [ @@ -37,7 +20,6 @@ cc_library( ], strip_include_prefix = _INCLUDE_PREFIX, deps = [ - ":liquid_legions_v2_encryption_utility_helper", ":multithreading_helper", ":noise_parameters_computation", "//src/main/cc/wfa/measurement/common/crypto:constants", @@ -65,7 +47,6 @@ cc_library( ], strip_include_prefix = _INCLUDE_PREFIX, deps = [ - ":liquid_legions_v2_encryption_utility_helper", ":multithreading_helper", ":noise_parameters_computation", "//src/main/cc/wfa/measurement/common/crypto:constants", diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility.cc b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility.cc index c257a8c6f6f..bb6fcfde660 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility.cc +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility.cc @@ -50,6 +50,7 @@ using ::wfa::measurement::common::crypto::ElGamalCiphertext; using ::wfa::measurement::common::crypto::ElGamalEcPointPair; using ::wfa::measurement::common::crypto::ExtractElGamalCiphertextFromString; using ::wfa::measurement::common::crypto::ExtractKeyCountPairFromRegisters; +using ::wfa::measurement::common::crypto::GetBlindedRegisterIndexes; using ::wfa::measurement::common::crypto::GetCountValuesPlaintext; using ::wfa::measurement::common::crypto::GetNumberOfBlocks; using ::wfa::measurement::common::crypto::kBlindedHistogramNoiseRegisterKey; @@ -906,10 +907,9 @@ CompleteExecutionPhaseOneAtAggregator( MultithreadingHelper::CreateMultithreadingHelper( request.parallelism(), protocol_cryptor_options)); - ASSIGN_OR_RETURN( - std::vector blinded_register_indexes, - GetBlindedRegisterIndexes(request.combined_register_vector(), - multithreading_helper->GetProtocolCryptor())); + ASSIGN_OR_RETURN(std::vector blinded_register_indexes, + GetBlindedRegisterIndexes(request.combined_register_vector(), + *multithreading_helper)); // Create a sorting permutation of the blinded register indexes, such that we // don't need to modify the sketch data, whose size could be huge. We only diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.cc b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.cc index aaf4e93aa6c..9d6e0aef48e 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.cc +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.cc @@ -57,6 +57,7 @@ using ::wfa::measurement::common::crypto::GetCountValuesPlaintext; using ::wfa::measurement::common::crypto::GetEcPointPairFromString; using ::wfa::measurement::common::crypto::GetElGamalEcPoints; using ::wfa::measurement::common::crypto::GetNumberOfBlocks; +using ::wfa::measurement::common::crypto::GetRollv2BlindedRegisterIndexes; using ::wfa::measurement::common::crypto::kBlindedHistogramNoiseRegisterKey; using ::wfa::measurement::common::crypto::kBytesPerCipherText; using ::wfa::measurement::common::crypto::kBytesPerEcPoint; @@ -133,10 +134,6 @@ absl::StatusOr AddReachOnlyBlindedHistogramNoise( ProtocolCryptor& protocol_cryptor, int total_sketches_count, const math::DistributedNoiser& distributed_noiser, size_t pos, std::string& data, int64_t& num_unique_noise_id) { - ASSIGN_OR_RETURN( - std::string blinded_histogram_noise_key_ec, - protocol_cryptor.MapToCurve(kBlindedHistogramNoiseRegisterKey)); - int64_t noise_register_added = 0; num_unique_noise_id = 0; @@ -181,8 +178,11 @@ absl::StatusOr AddReachOnlyNoiseForPublisherNoise( ASSIGN_OR_RETURN(int64_t noise_registers_count, distributed_noiser.GenerateNoiseComponent()); - // Make sure that there is at least one publisher noise added. + // Make sure that there is at least one publisher noise added so that we can + // always subtract 1 for the publisher noise later. This is to avoid the + // corner case where the noise_registers_count is zero for all workers. noise_registers_count++; + absl::AnyInvocable f = [&](ProtocolCryptor& cryptor, size_t index) -> absl::Status { size_t current_pos = pos + kBytesPerCipherText * index; @@ -399,10 +399,9 @@ absl::StatusOr CompleteReachOnlySetupPhase( // Encrypt the excessive noise. ASSIGN_OR_RETURN(std::unique_ptr protocol_cryptor, CreateProtocolCryptor(protocol_cryptor_options)); - ASSIGN_OR_RETURN( - std::string serialized_excessive_noise_ciphertext, - protocol_cryptor->EncryptIntegerWithCompositElGamalAndWriteToString( - excessive_noise_count)); + ASSIGN_OR_RETURN(std::string serialized_excessive_noise_ciphertext, + protocol_cryptor->EncryptIntegerToStringCompositeElGamal( + excessive_noise_count)); response.set_serialized_excessive_noise_ciphertext( serialized_excessive_noise_ciphertext); @@ -555,32 +554,41 @@ CompleteReachOnlyExecutionPhaseAtAggregator( request.noise_parameters().dp_params().blind_histogram(), request.noise_parameters().contributors_count(), request.noise_mechanism()); + // For each a in [1; number_of_EDPs], each worker samples at most + // blind_histogram_noiser->options().shift_offset * 2 noise registers. So + // all workers sample at most #EDPs*#max_per_worker noise registers. int max_excessive_noise = blind_histogram_noiser->options().shift_offset * 2 * request.noise_parameters().total_sketches_count() * - request.noise_parameters().total_sketches_count(); + request.noise_parameters().contributors_count(); // The lookup table stores the max_excessive_noise EC points where // ec_lookup_table[i] = (i+1)*ec_generator. ASSIGN_OR_RETURN( std::vector ec_lookup_table, GetCountValuesPlaintext(max_excessive_noise, request.curve_id())); // Decrypt the excessive noise using the lookup table. - for (int i = 0; i < ec_lookup_table.size(); i++) { + int i = 0; + for (i = 0; i < ec_lookup_table.size(); i++) { if (ec_lookup_table[i] == plaintext) { excessive_noise_count = i + 1; break; } } + // Throws an error if the decryption fails. + if (i == ec_lookup_table.size()) { + return absl::InternalError( + "Failed to decrypt the excessive noise ciphertext."); + } } ASSIGN_OR_RETURN(auto multithreading_helper, MultithreadingHelper::CreateMultithreadingHelper( request.parallelism(), protocol_cryptor_options)); - ASSIGN_OR_RETURN(std::vector blinded_register_indexes, - GetRollv2BlindedRegisterIndexes( - request.combined_register_vector(), - multithreading_helper->GetProtocolCryptor())); + ASSIGN_OR_RETURN( + std::vector blinded_register_indexes, + GetRollv2BlindedRegisterIndexes(request.combined_register_vector(), + *multithreading_helper)); CompleteReachOnlyExecutionPhaseAtAggregatorResponse response; // Counting the number of unique registers. diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/BUILD.bazel b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/BUILD.bazel new file mode 100644 index 00000000000..0f6f2be24d5 --- /dev/null +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/BUILD.bazel @@ -0,0 +1,29 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") +load("@wfa_common_jvm//build:defs.bzl", "test_target") + +package( + default_testonly = True, + default_visibility = [ + "//src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:__subpackages__", + ], +) + +_INCLUDE_PREFIX = "/src/main/cc" + +cc_library( + name = "liquid_legions_v2_encryption_utility_helper", + srcs = [ + "liquid_legions_v2_encryption_utility_helper.cc", + ], + hdrs = [ + "liquid_legions_v2_encryption_utility_helper.h", + ], + strip_include_prefix = _INCLUDE_PREFIX, + deps = [ + "//src/main/proto/wfa/measurement/internal/duchy:crypto_cc_proto", + "//src/main/proto/wfa/measurement/internal/duchy:differential_privacy_cc_proto", + "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", + "@any_sketch//src/main/cc/estimation:estimators", + "@com_google_absl//absl/status:statusor", + ], +) diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/liquid_legions_v2_encryption_utility_helper.cc similarity index 95% rename from src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc rename to src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/liquid_legions_v2_encryption_utility_helper.cc index f0b59c717ab..601e6326225 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.cc +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/liquid_legions_v2_encryption_utility_helper.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h" +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/liquid_legions_v2_encryption_utility_helper.h" #include "estimation/estimators.h" @@ -30,7 +30,7 @@ ::wfa::any_sketch::crypto::ElGamalPublicKey ToAnySketchElGamalKey( return result; } -ElGamalPublicKey ToCmmsElGamalKey( +ElGamalPublicKey ToDuchyInternalElGamalKey( ::wfa::any_sketch::crypto::ElGamalPublicKey key) { ElGamalPublicKey result; result.set_generator(key.generator()); diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/liquid_legions_v2_encryption_utility_helper.h similarity index 87% rename from src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h rename to src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/liquid_legions_v2_encryption_utility_helper.h index 091d6c1c79e..e7775c8f67e 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/liquid_legions_v2_encryption_utility_helper.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_HELPER_H_ -#define SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_HELPER_H_ +#ifndef SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_TESTING_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_HELPER_H_ +#define SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_TESTING_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_HELPER_H_ #include "absl/status/statusor.h" #include "any_sketch/crypto/sketch_encrypter.h" @@ -29,7 +29,7 @@ using ::wfa::measurement::internal::duchy::ElGamalPublicKey; ::wfa::any_sketch::crypto::ElGamalPublicKey ToAnySketchElGamalKey( ElGamalPublicKey key); -ElGamalPublicKey ToCmmsElGamalKey( +ElGamalPublicKey ToDuchyInternalElGamalKey( ::wfa::any_sketch::crypto::ElGamalPublicKey key); Sketch CreateEmptyLiquidLegionsSketch(); @@ -41,4 +41,4 @@ DifferentialPrivacyParams MakeDifferentialPrivacyParams(double epsilon, } // namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 -#endif // SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_HELPER_H_ +#endif // SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_TESTING_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_HELPER_H_ diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt index f84898a73af..7a0bc4e3a22 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt @@ -197,7 +197,7 @@ object LiquidLegionsV2Starter { // For weird stages, we throw. Stage.UNRECOGNIZED, - Stage.STAGE_UNKNOWN -> { + Stage.STAGE_UNSPECIFIED -> { error("[id=${token.globalComputationId}]: Unrecognized stage '$stage'") } } @@ -254,7 +254,7 @@ object LiquidLegionsV2Starter { // For weird stages, we throw. Stage.UNRECOGNIZED, - Stage.STAGE_UNKNOWN -> { + Stage.STAGE_UNSPECIFIED -> { error("[id=${token.globalComputationId}]: Unrecognized stage '$stage'") } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/LiquidLegionsSketchAggregationV2Protocol.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/LiquidLegionsSketchAggregationV2Protocol.kt index cb5cdac58f8..d8de7f5dea7 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/LiquidLegionsSketchAggregationV2Protocol.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/LiquidLegionsSketchAggregationV2Protocol.kt @@ -121,7 +121,7 @@ object LiquidLegionsSketchAggregationV2Protocol { COMPLETE -> error("Computation should be ended with call to endComputation(...)") // Stages that we can't transition to ever. UNRECOGNIZED, - LiquidLegionsSketchAggregationV2.Stage.STAGE_UNKNOWN, + LiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED, INITIALIZATION_PHASE -> error("Cannot make transition function to stage $stage") } } @@ -151,7 +151,7 @@ object LiquidLegionsSketchAggregationV2Protocol { COMPLETE -> error("Computation should be ended with call to endComputation(...)") // Stages that we can't transition to ever. UNRECOGNIZED, - LiquidLegionsSketchAggregationV2.Stage.STAGE_UNKNOWN, + LiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED, INITIALIZATION_PHASE -> error("Cannot make transition function to stage $stage") } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ProtocolStages.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ProtocolStages.kt index c7f591a1628..1af30e1d058 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ProtocolStages.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ProtocolStages.kt @@ -79,7 +79,7 @@ class LiquidLegionsV2Stages() : Stage.EXECUTION_PHASE_TWO, Stage.EXECUTION_PHASE_THREE, Stage.COMPLETE, - Stage.STAGE_UNKNOWN, + Stage.STAGE_UNSPECIFIED, Stage.UNRECOGNIZED -> throw IllegalStageException(token.computationStage) { "Unexpected $stageType stage: $protocolStage" @@ -104,7 +104,7 @@ class LiquidLegionsV2Stages() : Stage.EXECUTION_PHASE_TWO, Stage.EXECUTION_PHASE_THREE, Stage.COMPLETE, - Stage.STAGE_UNKNOWN, + Stage.STAGE_UNSPECIFIED, Stage.UNRECOGNIZED -> throw IllegalStageException(stage) { "Next $stageType stage unknown for $protocolStage" } }.toProtocolStage() diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto index 0568f525341..764905c0d32 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto @@ -28,7 +28,7 @@ option java_multiple_files = true; message LiquidLegionsSketchAggregationV2 { enum Stage { // The computation stage is unknown. This is never set intentionally. - STAGE_UNKNOWN = 0; + STAGE_UNSPECIFIED = 0; // The worker is in the initialization phase. // More specifically, each worker will create a new ElGamal key pair solely diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto index 46df969a70f..351e32c3abc 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto @@ -63,7 +63,8 @@ message CompleteReachOnlySetupPhaseRequest { // Public Key of the composite ElGamal cipher. Used to encrypt the excessive // noise (which is zero) when noise_parameters is not available. ElGamalPublicKey composite_el_gamal_public_key = 5; - // The attached encrypted excessive noises. Only for the aggregator. + // This field is only set for the aggregator. There will be one encrypted + // noise element for each non-aggregator worker. bytes serialized_excessive_noise_ciphertext = 6; // The maximum number of threads used by crypto actions. int32 parallelism = 7; diff --git a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel index b17fe2d0751..27c843d8678 100644 --- a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel +++ b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel @@ -9,7 +9,7 @@ cc_test( ], deps = [ "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:liquid_legions_v2_encryption_utility", - "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:liquid_legions_v2_encryption_utility_helper", + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing:liquid_legions_v2_encryption_utility_helper", "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_cc_proto", "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", "@any_sketch//src/main/cc/estimation:estimators", @@ -29,7 +29,7 @@ cc_test( ], deps = [ "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:reach_only_liquid_legions_v2_encryption_utility", - "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:liquid_legions_v2_encryption_utility_helper", + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing:liquid_legions_v2_encryption_utility_helper", "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_cc_proto", "@any_sketch//src/main/cc/any_sketch/crypto:sketch_encrypter", "@any_sketch//src/main/cc/estimation:estimators", diff --git a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_test.cc b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_test.cc index b45b23ce045..909c98ef13f 100644 --- a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_test.cc +++ b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_test.cc @@ -30,7 +30,7 @@ #include "wfa/measurement/common/crypto/constants.h" #include "wfa/measurement/common/crypto/ec_point_util.h" #include "wfa/measurement/common/crypto/encryption_utility_helper.h" -#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h" +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/liquid_legions_v2_encryption_utility_helper.h" #include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2_encryption_methods.pb.h" namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 { @@ -258,14 +258,14 @@ class TestData { // Combine the el_gamal keys from all duchies to generate the data provider // el_gamal key. - client_el_gamal_public_key_ = ToCmmsElGamalKey( + client_el_gamal_public_key_ = ToDuchyInternalElGamalKey( any_sketch::crypto::CombineElGamalPublicKeys( kTestCurveId, {ToAnySketchElGamalKey(duchy_1_el_gamal_key_pair_.public_key()), ToAnySketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), ToAnySketchElGamalKey(duchy_3_el_gamal_key_pair_.public_key())}) .value()); - duchy_2_3_composite_public_key_ = ToCmmsElGamalKey( + duchy_2_3_composite_public_key_ = ToDuchyInternalElGamalKey( any_sketch::crypto::CombineElGamalPublicKeys( kTestCurveId, {ToAnySketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), diff --git a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_test.cc b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_test.cc index 22f54d649a5..575e5285cae 100644 --- a/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_test.cc +++ b/src/test/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_test.cc @@ -30,7 +30,7 @@ #include "wfa/measurement/common/crypto/constants.h" #include "wfa/measurement/common/crypto/ec_point_util.h" #include "wfa/measurement/common/crypto/encryption_utility_helper.h" -#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility_helper.h" +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/liquid_legions_v2_encryption_utility_helper.h" #include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2_encryption_methods.pb.h" namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 { @@ -94,7 +94,8 @@ MATCHER_P(IsBlockSorted, block_size, "") { } // The ReachOnlyTest generates cipher keys for 3 duchies, and the combined -// public key for the data providers. +// public key for the data providers. The duchy 1 and 2 are non-aggregator +// workers, while the duchy 3 is the aggregator. class ReachOnlyTest { public: ElGamalKeyPair duchy_1_el_gamal_key_pair_; @@ -128,14 +129,14 @@ class ReachOnlyTest { // Combine the el_gamal keys from all duchies to generate the data provider // el_gamal key. - client_el_gamal_public_key_ = ToCmmsElGamalKey( + client_el_gamal_public_key_ = ToDuchyInternalElGamalKey( any_sketch::crypto::CombineElGamalPublicKeys( kTestCurveId, {ToAnySketchElGamalKey(duchy_1_el_gamal_key_pair_.public_key()), ToAnySketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), ToAnySketchElGamalKey(duchy_3_el_gamal_key_pair_.public_key())}) .value()); - duchy_2_3_composite_public_key_ = ToCmmsElGamalKey( + duchy_2_3_composite_public_key_ = ToDuchyInternalElGamalKey( any_sketch::crypto::CombineElGamalPublicKeys( kTestCurveId, {ToAnySketchElGamalKey(duchy_2_el_gamal_key_pair_.public_key()), @@ -189,6 +190,9 @@ class ReachOnlyTest { EXPECT_THAT( complete_reach_only_setup_phase_response_1.combined_register_vector(), IsBlockSorted(kBytesCipherText)); + EXPECT_THAT(complete_reach_only_setup_phase_response_1 + .serialized_excessive_noise_ciphertext(), + SizeIs(kBytesCipherText)); // Setup phase at Duchy 2. // We assume all test data comes from duchy 1 in the test, so there is only @@ -213,6 +217,9 @@ class ReachOnlyTest { EXPECT_THAT( complete_reach_only_setup_phase_response_2.combined_register_vector(), IsBlockSorted(kBytesCipherText)); + EXPECT_THAT(complete_reach_only_setup_phase_response_2 + .serialized_excessive_noise_ciphertext(), + SizeIs(kBytesCipherText)); // Setup phase at Duchy 3. // We assume all test data comes from duchy 1 in the test, so there is only @@ -253,6 +260,9 @@ class ReachOnlyTest { EXPECT_THAT( complete_reach_only_setup_phase_response_3.combined_register_vector(), IsBlockSorted(kBytesCipherText)); + EXPECT_THAT(complete_reach_only_setup_phase_response_3 + .serialized_excessive_noise_ciphertext(), + SizeIs(kBytesCipherText)); // Execution phase at duchy 1 (non-aggregator). CompleteReachOnlyExecutionPhaseRequest diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/LiquidLegionsSketchAggregationV2ProtocolEnumStagesTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/LiquidLegionsSketchAggregationV2ProtocolEnumStagesTest.kt index 4e761a83b78..341f4606ab3 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/LiquidLegionsSketchAggregationV2ProtocolEnumStagesTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/LiquidLegionsSketchAggregationV2ProtocolEnumStagesTest.kt @@ -86,7 +86,7 @@ class LiquidLegionsSketchAggregationV2ProtocolEnumStagesTest { assertFalse { LiquidLegionsSketchAggregationV2Protocol.EnumStages.validTransition( - LiquidLegionsSketchAggregationV2.Stage.STAGE_UNKNOWN, + LiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED, LiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE ) } From 55fe2b2f4079d568e0055ff3632b8611557883e7 Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Thu, 27 Jul 2023 22:21:44 -0400 Subject: [PATCH 08/15] Moving the function GetBlindedRegisterIndexes and GetRollv2BlindedRegisterIndexes into wfa::measurement::internal::duchy::protocol::liquid_ligions_v2. --- .../wfa/measurement/common/crypto/BUILD.bazel | 1 - .../crypto/encryption_utility_helper.cc | 49 ------------------- .../common/crypto/encryption_utility_helper.h | 13 ----- .../common/crypto/protocol_cryptor.cc | 6 +-- .../liquid_legions_v2_encryption_utility.cc | 26 +++++++++- ...ly_liquid_legions_v2_encryption_utility.cc | 28 ++++++++++- .../liquid_legions_v2/testing/BUILD.bazel | 1 - 7 files changed, 54 insertions(+), 70 deletions(-) diff --git a/src/main/cc/wfa/measurement/common/crypto/BUILD.bazel b/src/main/cc/wfa/measurement/common/crypto/BUILD.bazel index 765279f7f4e..825cda1c722 100644 --- a/src/main/cc/wfa/measurement/common/crypto/BUILD.bazel +++ b/src/main/cc/wfa/measurement/common/crypto/BUILD.bazel @@ -28,7 +28,6 @@ cc_library( ":constants", ":ec_point_util", ":protocol_cryptor", - "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:multithreading_helper", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", "@wfa_common_cpp//src/main/cc/common_cpp/macros", diff --git a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc index 35a0ac999f0..70e7f1541c5 100644 --- a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc +++ b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.cc @@ -25,9 +25,6 @@ namespace wfa::measurement::common::crypto { -using ::wfa::measurement::internal::duchy::protocol::liquid_legions_v2:: - MultithreadingHelper; - absl::StatusOr GetNumberOfBlocks(absl::string_view data, size_t block_size) { if (block_size == 0) { @@ -51,52 +48,6 @@ absl::StatusOr ExtractElGamalCiphertextFromString( std::string(str.substr(kBytesPerEcPoint, kBytesPerEcPoint))); } -absl::StatusOr> GetBlindedRegisterIndexes( - absl::string_view data, MultithreadingHelper& helper) { - ASSIGN_OR_RETURN(size_t register_count, - GetNumberOfBlocks(data, kBytesPerCipherRegister)); - std::vector blinded_register_indexes; - blinded_register_indexes.resize(register_count); - - absl::AnyInvocable f = - [&](ProtocolCryptor& cryptor, size_t index) -> absl::Status { - absl::string_view data_block = - data.substr(index * kBytesPerCipherRegister, kBytesPerCipherText); - ASSIGN_OR_RETURN(ElGamalCiphertext ciphertext, - ExtractElGamalCiphertextFromString(data_block)); - ASSIGN_OR_RETURN(std::string decrypted_el_gamal, - cryptor.DecryptLocalElGamal(ciphertext)); - blinded_register_indexes[index] = std::move(decrypted_el_gamal); - return absl::OkStatus(); - }; - RETURN_IF_ERROR(helper.Execute(register_count, f)); - - return blinded_register_indexes; -} - -absl::StatusOr> GetRollv2BlindedRegisterIndexes( - absl::string_view data, MultithreadingHelper& helper) { - ASSIGN_OR_RETURN(size_t register_count, - GetNumberOfBlocks(data, kBytesPerCipherText)); - std::vector blinded_register_indexes; - blinded_register_indexes.resize(register_count); - - absl::AnyInvocable f = - [&](ProtocolCryptor& cryptor, size_t index) -> absl::Status { - absl::string_view data_block = - data.substr(index * kBytesPerCipherText, kBytesPerCipherText); - ASSIGN_OR_RETURN(ElGamalCiphertext ciphertext, - ExtractElGamalCiphertextFromString(data_block)); - ASSIGN_OR_RETURN(std::string decrypted_el_gamal, - cryptor.DecryptLocalElGamal(ciphertext)); - blinded_register_indexes[index] = std::move(decrypted_el_gamal); - return absl::OkStatus(); - }; - RETURN_IF_ERROR(helper.Execute(register_count, f)); - - return blinded_register_indexes; -} - absl::StatusOr ExtractKeyCountPairFromSubstring( absl::string_view str) { if (str.size() != kBytesPerCipherText * 2) { diff --git a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.h b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.h index 834d7b9c2ba..65da3709569 100644 --- a/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.h +++ b/src/main/cc/wfa/measurement/common/crypto/encryption_utility_helper.h @@ -23,13 +23,10 @@ #include "absl/strings/string_view.h" #include "wfa/measurement/common/crypto/ec_point_util.h" #include "wfa/measurement/common/crypto/protocol_cryptor.h" -#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/multithreading_helper.h" namespace wfa::measurement::common::crypto { using ::wfa::measurement::common::crypto::CompositeType; -using ::wfa::measurement::internal::duchy::protocol::liquid_legions_v2:: - MultithreadingHelper; // A pair of ciphertexts which store the key and count values of a liquidlegions // register. @@ -46,16 +43,6 @@ absl::StatusOr GetNumberOfBlocks(absl::string_view data, absl::StatusOr ExtractElGamalCiphertextFromString( absl::string_view str); -// Blinds the last layer of ElGamal Encryption of register indexes, and return -// the deterministically encrypted results. -absl::StatusOr> GetBlindedRegisterIndexes( - absl::string_view data, MultithreadingHelper& helper); - -// Blinds the last layer of ElGamal Encryption of register indexes, and return -// the deterministically encrypted results. -absl::StatusOr> GetRollv2BlindedRegisterIndexes( - absl::string_view data, MultithreadingHelper& helper); - // Extracts a KeyCountPairCipherText from a string_view. absl::StatusOr ExtractKeyCountPairFromSubstring( absl::string_view str); diff --git a/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.cc b/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.cc index ae1e040b5c4..ea3cb0b6493 100644 --- a/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.cc +++ b/src/main/cc/wfa/measurement/common/crypto/protocol_cryptor.cc @@ -303,9 +303,9 @@ absl::Status ProtocolCryptorImpl::BatchProcess(absl::string_view data, } case Action::kPartialDecrypt: { ASSIGN_OR_RETURN(std::string temp, DecryptLocalElGamal(ciphertext)); - // The first part of the ciphertext is the random number which is - // still required to decrypt the other layers of ElGamal encryptions - // (at the subsequent duchies. So we keep it. + // The first part of the ciphertext is the random number which is still + // required to decrypt the other layers of ElGamal encryptions (at the + // subsequent duchies). So we keep it. result.replace(pos, kBytesPerEcPoint, ciphertext.first); pos += kBytesPerEcPoint; result.replace(pos, kBytesPerEcPoint, temp); diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility.cc b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility.cc index bb6fcfde660..62b7efc8c77 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility.cc +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/liquid_legions_v2_encryption_utility.cc @@ -50,7 +50,6 @@ using ::wfa::measurement::common::crypto::ElGamalCiphertext; using ::wfa::measurement::common::crypto::ElGamalEcPointPair; using ::wfa::measurement::common::crypto::ExtractElGamalCiphertextFromString; using ::wfa::measurement::common::crypto::ExtractKeyCountPairFromRegisters; -using ::wfa::measurement::common::crypto::GetBlindedRegisterIndexes; using ::wfa::measurement::common::crypto::GetCountValuesPlaintext; using ::wfa::measurement::common::crypto::GetNumberOfBlocks; using ::wfa::measurement::common::crypto::kBlindedHistogramNoiseRegisterKey; @@ -76,6 +75,31 @@ using ::wfa::measurement::common::crypto::ProtocolCryptorOptions; using ::wfa::measurement::internal::duchy::ElGamalPublicKey; using ::wfa::measurement::internal::duchy::protocol::LiquidLegionsV2NoiseConfig; +// Blinds the last layer of ElGamal Encryption of register indexes, and return +// the deterministically encrypted results. +absl::StatusOr> GetBlindedRegisterIndexes( + absl::string_view data, MultithreadingHelper& helper) { + ASSIGN_OR_RETURN(size_t register_count, + GetNumberOfBlocks(data, kBytesPerCipherRegister)); + std::vector blinded_register_indexes; + blinded_register_indexes.resize(register_count); + + absl::AnyInvocable f = + [&](ProtocolCryptor& cryptor, size_t index) -> absl::Status { + absl::string_view data_block = + data.substr(index * kBytesPerCipherRegister, kBytesPerCipherText); + ASSIGN_OR_RETURN(ElGamalCiphertext ciphertext, + ExtractElGamalCiphertextFromString(data_block)); + ASSIGN_OR_RETURN(std::string decrypted_el_gamal, + cryptor.DecryptLocalElGamal(ciphertext)); + blinded_register_indexes[index] = std::move(decrypted_el_gamal); + return absl::OkStatus(); + }; + RETURN_IF_ERROR(helper.Execute(register_count, f)); + + return blinded_register_indexes; +} + // Merge all the counts in each group using the SameKeyAggregation algorithm. // The calculated (flag_1, flag_2, flag_3, count) tuple is appended to the // response. 'sub_permutation' contains the locations of the registers belonging diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.cc b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.cc index 9d6e0aef48e..bb52b3b6be0 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.cc +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.cc @@ -57,7 +57,6 @@ using ::wfa::measurement::common::crypto::GetCountValuesPlaintext; using ::wfa::measurement::common::crypto::GetEcPointPairFromString; using ::wfa::measurement::common::crypto::GetElGamalEcPoints; using ::wfa::measurement::common::crypto::GetNumberOfBlocks; -using ::wfa::measurement::common::crypto::GetRollv2BlindedRegisterIndexes; using ::wfa::measurement::common::crypto::kBlindedHistogramNoiseRegisterKey; using ::wfa::measurement::common::crypto::kBytesPerCipherText; using ::wfa::measurement::common::crypto::kBytesPerEcPoint; @@ -80,6 +79,31 @@ using ::wfa::measurement::common::crypto::ProtocolCryptorOptions; using ::wfa::measurement::internal::duchy::ElGamalPublicKey; using ::wfa::measurement::internal::duchy::protocol::LiquidLegionsV2NoiseConfig; +// Blinds the last layer of ElGamal Encryption of register indexes, and return +// the deterministically encrypted results. +absl::StatusOr> GetRollv2BlindedRegisterIndexes( + absl::string_view data, MultithreadingHelper& helper) { + ASSIGN_OR_RETURN(size_t register_count, + GetNumberOfBlocks(data, kBytesPerCipherText)); + std::vector blinded_register_indexes; + blinded_register_indexes.resize(register_count); + + absl::AnyInvocable f = + [&](ProtocolCryptor& cryptor, size_t index) -> absl::Status { + absl::string_view data_block = + data.substr(index * kBytesPerCipherText, kBytesPerCipherText); + ASSIGN_OR_RETURN(ElGamalCiphertext ciphertext, + ExtractElGamalCiphertextFromString(data_block)); + ASSIGN_OR_RETURN(std::string decrypted_el_gamal, + cryptor.DecryptLocalElGamal(ciphertext)); + blinded_register_indexes[index] = std::move(decrypted_el_gamal); + return absl::OkStatus(); + }; + RETURN_IF_ERROR(helper.Execute(register_count, f)); + + return blinded_register_indexes; +} + absl::StatusOr EstimateReach(double liquid_legions_decay_rate, int64_t liquid_legions_size, size_t non_empty_register_count, @@ -574,7 +598,7 @@ CompleteReachOnlyExecutionPhaseAtAggregator( break; } } - // Throws an error if the decryption fails. + // Returns an error if the decryption fails. if (i == ec_lookup_table.size()) { return absl::InternalError( "Failed to decrypt the excessive noise ciphertext."); diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/BUILD.bazel b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/BUILD.bazel index 0f6f2be24d5..0fceac27edc 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/BUILD.bazel +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/testing/BUILD.bazel @@ -1,5 +1,4 @@ load("@rules_cc//cc:defs.bzl", "cc_library") -load("@wfa_common_jvm//build:defs.bzl", "test_target") package( default_testonly = True, From 49e8cd6af8ba45eb63878110cb8738a1ec3b5234 Mon Sep 17 00:00:00 2001 From: ple13 Date: Fri, 28 Jul 2023 03:46:59 -0400 Subject: [PATCH 09/15] Removing unused message CompleteReachOnlySetupPhaseAtAggregatorResponse. --- ...reach_only_liquid_legions_v2_encryption_utility.h | 2 -- ...h_only_liquid_legions_v2_encryption_methods.proto | 12 ------------ 2 files changed, 14 deletions(-) diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h index afb5f8fcfa6..311ab5e9f29 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h @@ -32,8 +32,6 @@ using ::wfa::measurement::internal::duchy::protocol:: CompleteReachOnlyInitializationPhaseRequest; using ::wfa::measurement::internal::duchy::protocol:: CompleteReachOnlyInitializationPhaseResponse; -using ::wfa::measurement::internal::duchy::protocol:: - CompleteReachOnlySetupPhaseAtAggregatorResponse; using ::wfa::measurement::internal::duchy::protocol:: CompleteReachOnlySetupPhaseRequest; using ::wfa::measurement::internal::duchy::protocol:: diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto index 351e32c3abc..743baf1187e 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto @@ -81,18 +81,6 @@ message CompleteReachOnlySetupPhaseResponse { int64 elapsed_cpu_time_millis = 3; } -// Response of the CompleteReachOnlySetupPhase method at the aggregate worker. -message CompleteReachOnlySetupPhaseAtAggregatorResponse { - // The output combined register vector (CRV), which contains shuffled input - // and noise registers. - bytes combined_register_vector = 1; - // The serialized El Gamal ciphertext that encrypts the aggregated excessive - // noise of the aggregator. - bytes serialized_excessive_noise_ciphertext = 2; - // The CPU time of processing the request. - int64 elapsed_cpu_time_millis = 3; -} - // The request to complete work in the execution phase at a non-aggregator // worker. message CompleteReachOnlyExecutionPhaseRequest { From 1f8e67be6b598bbc48eed637e4fc776d25cd5d63 Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Fri, 4 Aug 2023 16:33:55 -0400 Subject: [PATCH 10/15] Add the reach only protocol to the Duchy Mill. --- build/repositories.bzl | 5 +- .../protocol/liquid_legions_v2/BUILD.bazel | 17 + ...d_legions_v2_encryption_utility_wrapper.cc | 62 + ...id_legions_v2_encryption_utility_wrapper.h | 45 + ...ggregator_protocols_setup_config.textproto | 4 + ...ggregator_protocols_setup_config.textproto | 4 + .../measurement/duchy/ComputationStage.kt | 16 +- .../duchy/daemon/herald/BUILD.bazel | 1 + .../measurement/duchy/daemon/herald/Herald.kt | 33 +- .../herald/ReachOnlyLiquidLegionsV2Starter.kt | 356 ++++ .../daemon/mill/liquidlegionsv2/BUILD.bazel | 38 + .../ReachOnlyLiquidLegionsV2Mill.kt | 763 +++++++ .../mill/liquidlegionsv2/crypto/BUILD.bazel | 21 +- .../JniReachOnlyLiquidLegionsV2Encryption.kt | 93 + .../ReachOnlyLiquidLegionsV2Encryption.kt | 58 + .../ComputationProtocolStageDetails.kt | 18 + .../computation/ComputationProtocolStages.kt | 19 + ...iquidLegionsSketchAggregationV2Protocol.kt | 246 +++ .../testing/FakeComputationsDatabase.kt | 2 + .../computationcontrol/ProtocolStages.kt | 128 +- .../computations/ComputationsService.kt | 8 + .../AdvanceComputationRequestHeaders.kt | 38 +- .../api/v2alpha/RequisitionsService.kt | 7 + .../v1alpha/ComputationParticipantsService.kt | 37 +- .../system/v1alpha/ProtoConversions.kt | 30 +- .../measurement/internal/duchy/BUILD.bazel | 2 + .../internal/duchy/computation_details.proto | 9 + .../duchy/computation_protocols.proto | 11 +- .../duchy/config/protocols_setup_config.proto | 6 +- ...liquid_legions_v2_encryption_methods.proto | 2 +- .../liquid_legions_v2_noise_config.proto | 3 +- ...liquid_legions_sketch_aggregation_v2.proto | 7 +- ...liquid_legions_v2_encryption_methods.proto | 8 +- .../kingdom/computation_participant.proto | 3 + .../computation_participants_service.proto | 3 + .../v1alpha/computation_control_service.proto | 21 + .../v1alpha/computation_participant.proto | 5 +- .../reachonlyliquidlegionsv2/BUILD.bazel | 16 + .../reachonlyliquidlegionsv2/README.md | 22 + ..._liquid_legions_v2_encryption_utility.swig | 67 + .../duchy/daemon/herald/HeraldTest.kt | 557 ++++- .../daemon/mill/liquidlegionsv2/BUILD.bazel | 30 + .../ReachOnlyLiquidLegionsV2MillTest.kt | 1827 +++++++++++++++++ .../mill/liquidlegionsv2/crypto/BUILD.bazel | 32 + ...iReachOnlyLiquidLegionsV2EncryptionTest.kt | 43 + ...nlyLiquidLegionsV2EncryptionUtilityTest.kt | 305 +++ .../duchy/db/computation/BUILD.bazel | 26 + .../computation/ComputationsEnumHelperTest.kt | 20 + ...regationV2ProtocolEnumStagesDetailsTest.kt | 64 + ...etchAggregationV2ProtocolEnumStagesTest.kt | 103 + .../internal/computationcontrol/BUILD.bazel | 16 + .../ReachOnlyLiquidLegionsV2StagesTest.kt | 124 ++ .../v1alpha/ComputationControlServiceTest.kt | 69 + .../ComputationParticipantsServiceTest.kt | 72 +- .../system/v1alpha/ComputationsServiceTest.kt | 166 +- 55 files changed, 5480 insertions(+), 208 deletions(-) create mode 100644 src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.cc create mode 100644 src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h create mode 100644 src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/ReachOnlyLiquidLegionsV2Starter.kt create mode 100644 src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt create mode 100644 src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2Encryption.kt create mode 100644 src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2Encryption.kt create mode 100644 src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ReachOnlyLiquidLegionsSketchAggregationV2Protocol.kt create mode 100644 src/main/swig/protocol/reachonlyliquidlegionsv2/BUILD.bazel create mode 100644 src/main/swig/protocol/reachonlyliquidlegionsv2/README.md create mode 100644 src/main/swig/protocol/reachonlyliquidlegionsv2/reach_only_liquid_legions_v2_encryption_utility.swig create mode 100644 src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt create mode 100644 src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2EncryptionTest.kt create mode 100644 src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt create mode 100644 src/test/kotlin/org/wfanet/measurement/duchy/db/computation/ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesDetailsTest.kt create mode 100644 src/test/kotlin/org/wfanet/measurement/duchy/db/computation/ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesTest.kt create mode 100644 src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ReachOnlyLiquidLegionsV2StagesTest.kt diff --git a/build/repositories.bzl b/build/repositories.bzl index 545d09c038c..31f4f956d48 100644 --- a/build/repositories.bzl +++ b/build/repositories.bzl @@ -41,8 +41,9 @@ def wfa_measurement_system_repositories(): wfa_repo_archive( name = "wfa_measurement_proto", repo = "cross-media-measurement-api", - sha256 = "22f32f247c95d5c6efab8b00ecf3019268f293caf5065e1e0ab738419ad3c1d0", - version = "0.38.1", + # DO_NOT_SUBMIT(renjiez): until using a release version. + sha256 = "6133f3d3c30ccb2e92ea9524432deb1f226f25589d17b9f34cf1f35250ff36b9", + commit = "e68544378b77b133b1c1389af375d53cd768d964", ) wfa_repo_archive( diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel index 79de5b8dfca..0dda2d3882e 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel @@ -81,6 +81,23 @@ cc_library( ], ) +cc_library( + name = "reach_only_liquid_legions_v2_encryption_utility_wrapper", + srcs = [ + "reach_only_liquid_legions_v2_encryption_utility_wrapper.cc", + ], + hdrs = [ + "reach_only_liquid_legions_v2_encryption_utility_wrapper.h", + ], + strip_include_prefix = _INCLUDE_PREFIX, + deps = [ + ":reach_only_liquid_legions_v2_encryption_utility", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_cc_proto", + "@wfa_common_cpp//src/main/cc/common_cpp/jni:jni_wrap", + "@wfa_common_cpp//src/main/cc/common_cpp/macros", + ], +) + cc_library( name = "noise_parameters_computation", srcs = [ diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.cc b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.cc new file mode 100644 index 00000000000..8a1d995ba58 --- /dev/null +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.cc @@ -0,0 +1,62 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h" + +#include + +#include "absl/status/statusor.h" +#include "common_cpp/jni/jni_wrap.h" +#include "common_cpp/macros/macros.h" +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h" +#include "wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.pb.h" + +namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 { + +absl::StatusOr CompleteReachOnlyInitializationPhase( + const std::string& serialized_request) { + return JniWrap( + serialized_request, CompleteReachOnlyInitializationPhase); +} + +absl::StatusOr CompleteReachOnlySetupPhase( + const std::string& serialized_request) { + return JniWrap( + serialized_request, CompleteReachOnlySetupPhase); +} + +absl::StatusOr CompleteReachOnlySetupPhaseAtAggregator( + const std::string& serialized_request) { + return JniWrap( + serialized_request, CompleteReachOnlySetupPhaseAtAggregator); +} + +absl::StatusOr CompleteReachOnlyExecutionPhase( + const std::string& serialized_request) { + return JniWrap( + serialized_request, CompleteReachOnlyExecutionPhase); +} + +absl::StatusOr CompleteReachOnlyExecutionPhaseAtAggregator( + const std::string& serialized_request) { + return JniWrap( + serialized_request, CompleteReachOnlyExecutionPhaseAtAggregator); +} + +} // namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h new file mode 100644 index 00000000000..0b940ce936f --- /dev/null +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h @@ -0,0 +1,45 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_REACH_ONLY_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_WRAPPER_H_ +#define SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_REACH_ONLY_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_WRAPPER_H_ + +#include + +#include "absl/status/statusor.h" + +// Wrapper methods used to generate the swig/JNI Java classes. +// The only functionality of these methods are converting between proto messages +// and their corresponding serialized strings, and then calling into the +// reach_only_liquid_legions_v2_encryption_utility methods. +namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 { + +absl::StatusOr CompleteReachOnlyInitializationPhase( + const std::string& serialized_request); + +absl::StatusOr CompleteReachOnlySetupPhase( + const std::string& serialized_request); + +absl::StatusOr CompleteReachOnlySetupPhaseAtAggregator( + const std::string& serialized_request); + +absl::StatusOr CompleteReachOnlyExecutionPhase( + const std::string& serialized_request); + +absl::StatusOr CompleteReachOnlyExecutionPhaseAtAggregator( + const std::string& serialized_request); + +} // namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 + +#endif // SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_REACH_ONLY_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_WRAPPER_H_ diff --git a/src/main/k8s/testing/secretfiles/aggregator_protocols_setup_config.textproto b/src/main/k8s/testing/secretfiles/aggregator_protocols_setup_config.textproto index e4b95869b4c..3e78df01981 100644 --- a/src/main/k8s/testing/secretfiles/aggregator_protocols_setup_config.textproto +++ b/src/main/k8s/testing/secretfiles/aggregator_protocols_setup_config.textproto @@ -4,3 +4,7 @@ liquid_legions_v2 { role: AGGREGATOR external_aggregator_duchy_id: "aggregator" } +reach_only_liquid_legions_v2 { + role: AGGREGATOR + external_aggregator_duchy_id: "aggregator" +} diff --git a/src/main/k8s/testing/secretfiles/non_aggregator_protocols_setup_config.textproto b/src/main/k8s/testing/secretfiles/non_aggregator_protocols_setup_config.textproto index fd50191bd76..f3fa939a7c8 100644 --- a/src/main/k8s/testing/secretfiles/non_aggregator_protocols_setup_config.textproto +++ b/src/main/k8s/testing/secretfiles/non_aggregator_protocols_setup_config.textproto @@ -4,3 +4,7 @@ liquid_legions_v2 { role: NON_AGGREGATOR external_aggregator_duchy_id: "aggregator" } +reach_only_liquid_legions_v2 { + role: NON_AGGREGATOR + external_aggregator_duchy_id: "aggregator" +} \ No newline at end of file diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/ComputationStage.kt b/src/main/kotlin/org/wfanet/measurement/duchy/ComputationStage.kt index dcc40783cc9..7d40eb704f3 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/ComputationStage.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/ComputationStage.kt @@ -15,7 +15,9 @@ package org.wfanet.measurement.duchy import org.wfanet.measurement.internal.duchy.ComputationStage +import org.wfanet.measurement.internal.duchy.computationStage import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 val ComputationStage.name: String @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. @@ -23,6 +25,8 @@ val ComputationStage.name: String when (stageCase) { ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> liquidLegionsSketchAggregationV2.name + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + reachOnlyLiquidLegionsSketchAggregationV2.name ComputationStage.StageCase.STAGE_NOT_SET -> error("Stage not set") } @@ -32,8 +36,16 @@ val ComputationStage.number: Int when (stageCase) { ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> liquidLegionsSketchAggregationV2.number + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + reachOnlyLiquidLegionsSketchAggregationV2.number ComputationStage.StageCase.STAGE_NOT_SET -> error("Stage not set") } -fun LiquidLegionsSketchAggregationV2.Stage.toProtocolStage(): ComputationStage = - ComputationStage.newBuilder().setLiquidLegionsSketchAggregationV2(this).build() +fun LiquidLegionsSketchAggregationV2.Stage.toProtocolStage(): ComputationStage = computationStage { + liquidLegionsSketchAggregationV2 = this@toProtocolStage +} + +fun ReachOnlyLiquidLegionsSketchAggregationV2.Stage.toProtocolStage(): ComputationStage = + computationStage { + reachOnlyLiquidLegionsSketchAggregationV2 = this@toProtocolStage + } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/BUILD.bazel index 4ff006b7913..a8fd04376f4 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/BUILD.bazel @@ -24,6 +24,7 @@ kt_jvm_library( srcs = [ "Herald.kt", "LiquidLegionsV2Starter.kt", + "ReachOnlyLiquidLegionsV2Starter.kt", ], runtime_deps = ["@wfa_common_jvm//imports/java/io/grpc/netty"], deps = [ diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/Herald.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/Herald.kt index d6c9a51d697..2f1a6eacd51 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/Herald.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/Herald.kt @@ -32,9 +32,7 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.sync.Semaphore import org.wfanet.measurement.common.grpc.grpcStatusCode import org.wfanet.measurement.common.protoTimestamp -import org.wfanet.measurement.duchy.daemon.utils.MeasurementType import org.wfanet.measurement.duchy.daemon.utils.key -import org.wfanet.measurement.duchy.daemon.utils.toMeasurementType import org.wfanet.measurement.duchy.service.internal.computations.toGetTokenRequest import org.wfanet.measurement.internal.duchy.ComputationDetails import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoroutineStub @@ -229,16 +227,22 @@ class Herald( val globalId: String = systemComputation.key.computationId logger.info("[id=$globalId] Creating Computation...") try { - when (systemComputation.toMeasurementType()) { - MeasurementType.REACH, - MeasurementType.REACH_AND_FREQUENCY -> { + when (systemComputation.mpcProtocolConfig.protocolCase) { + Computation.MpcProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2 -> LiquidLegionsV2Starter.createComputation( internalComputationsClient, systemComputation, protocolsSetupConfig.liquidLegionsV2, blobStorageBucket ) - } + Computation.MpcProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> + ReachOnlyLiquidLegionsV2Starter.createComputation( + internalComputationsClient, + systemComputation, + protocolsSetupConfig.reachOnlyLiquidLegionsV2, + blobStorageBucket + ) + else -> error("Unknown or unsupported protocol for creation.") } logger.info("[id=$globalId]: Created Computation") } catch (e: StatusException) { @@ -302,6 +306,13 @@ class Herald( systemComputation, protocolsSetupConfig.liquidLegionsV2.externalAggregatorDuchyId ) + ComputationDetails.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> + ReachOnlyLiquidLegionsV2Starter.updateRequisitionsAndKeySets( + token, + internalComputationsClient, + systemComputation, + protocolsSetupConfig.reachOnlyLiquidLegionsV2.externalAggregatorDuchyId + ) else -> error("Unknown or unsupported protocol.") } logger.info("[id=$globalId]: Confirmed Computation") @@ -317,6 +328,8 @@ class Herald( when (token.computationDetails.protocolCase) { ComputationDetails.ProtocolCase.LIQUID_LEGIONS_V2 -> LiquidLegionsV2Starter.startComputation(token, internalComputationsClient) + ComputationDetails.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> + ReachOnlyLiquidLegionsV2Starter.startComputation(token, internalComputationsClient) else -> error("Unknown or unsupported protocol.") } logger.info("[id=$globalId]: Started Computation") @@ -365,8 +378,10 @@ class Herald( } ?: return if ( - token.computationDetails.hasLiquidLegionsV2() && - token.computationStage == LiquidLegionsV2Starter.TERMINAL_STAGE + (token.computationDetails.hasLiquidLegionsV2() && + token.computationStage == LiquidLegionsV2Starter.TERMINAL_STAGE) || + (token.computationDetails.hasReachOnlyLiquidLegionsV2() && + token.computationStage == ReachOnlyLiquidLegionsV2Starter.TERMINAL_STAGE) ) { return } @@ -376,6 +391,8 @@ class Herald( endingComputationStage = when (token.computationDetails.protocolCase) { ComputationDetails.ProtocolCase.LIQUID_LEGIONS_V2 -> LiquidLegionsV2Starter.TERMINAL_STAGE + ComputationDetails.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> + ReachOnlyLiquidLegionsV2Starter.TERMINAL_STAGE else -> error { "Unknown or unsupported protocol." } } reason = ComputationDetails.CompletedReason.FAILED diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/ReachOnlyLiquidLegionsV2Starter.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/ReachOnlyLiquidLegionsV2Starter.kt new file mode 100644 index 00000000000..af7e002a445 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/ReachOnlyLiquidLegionsV2Starter.kt @@ -0,0 +1,356 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.daemon.herald + +import java.util.logging.Logger +import org.wfanet.measurement.api.Version +import org.wfanet.measurement.api.v2alpha.MeasurementSpec +import org.wfanet.measurement.duchy.daemon.utils.key +import org.wfanet.measurement.duchy.daemon.utils.sha1Hash +import org.wfanet.measurement.duchy.daemon.utils.toDuchyDifferentialPrivacyParams +import org.wfanet.measurement.duchy.daemon.utils.toDuchyElGamalPublicKey +import org.wfanet.measurement.duchy.daemon.utils.toKingdomComputationDetails +import org.wfanet.measurement.duchy.daemon.utils.toRequisitionEntries +import org.wfanet.measurement.duchy.db.computation.advanceComputationStage +import org.wfanet.measurement.duchy.service.internal.computations.outputPathList +import org.wfanet.measurement.duchy.toProtocolStage +import org.wfanet.measurement.internal.duchy.ComputationToken +import org.wfanet.measurement.internal.duchy.ComputationTypeEnum +import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt +import org.wfanet.measurement.internal.duchy.computationDetails +import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig +import org.wfanet.measurement.internal.duchy.createComputationRequest +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfig +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfigKt +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt +import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters +import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsV2NoiseConfig +import org.wfanet.measurement.internal.duchy.updateComputationDetailsRequest +import org.wfanet.measurement.system.v1alpha.Computation +import org.wfanet.measurement.system.v1alpha.ComputationParticipant + +private const val MIN_REACH_EPSILON = 0.00001 +private const val MIN_FREQUENCY_EPSILON = 0.00001 + +object ReachOnlyLiquidLegionsV2Starter { + + private val logger: Logger = Logger.getLogger(this::class.java.name) + + val TERMINAL_STAGE = ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE.toProtocolStage() + + suspend fun createComputation( + computationStorageClient: ComputationsGrpcKt.ComputationsCoroutineStub, + systemComputation: Computation, + liquidLegionsV2SetupConfig: LiquidLegionsV2SetupConfig, + blobStorageBucket: String + ) { + require(systemComputation.name.isNotEmpty()) { "Resource name not specified" } + val globalId: String = systemComputation.key.computationId + val initialComputationDetails = computationDetails { + blobsStoragePrefix = "$blobStorageBucket/$globalId" + kingdomComputation = systemComputation.toKingdomComputationDetails() + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = liquidLegionsV2SetupConfig.role + parameters = systemComputation.toReachOnlyLiquidLegionsV2Parameters() + } + } + val requisitions = + systemComputation.requisitionsList.toRequisitionEntries(systemComputation.measurementSpec) + + computationStorageClient.createComputation( + createComputationRequest { + computationType = + ComputationTypeEnum.ComputationType.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 + globalComputationId = globalId + computationDetails = initialComputationDetails + this.requisitions += requisitions + } + ) + } + + /** + * Orders the list of computation participants by their roles in the computation. The + * non-aggregators are shuffled by the sha1Hash of their elgamal public keys and the global + * computation id, the aggregator is placed at the end of the list. This return order is also the + * order of all participants in the MPC ring structure. + */ + private fun List< + ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant + > + .orderByRoles( + globalComputationId: String, + aggregatorId: String + ): List { + val aggregator = + this.find { it.duchyId == aggregatorId } + ?: error("Aggregator duchy is missing from the participants.") + val nonAggregators = this.filter { it.duchyId != aggregatorId } + return nonAggregators.sortedBy { + sha1Hash(it.elGamalPublicKey.toStringUtf8() + globalComputationId) + } + aggregator + } + + private suspend fun updateRequisitionsAndKeySetsInternal( + token: ComputationToken, + computationStorageClient: ComputationsGrpcKt.ComputationsCoroutineStub, + systemComputation: Computation, + aggregatorId: String + ) { + val updatedDetails = computationDetails { + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + participant += + systemComputation.computationParticipantsList + .map { it.toDuchyComputationParticipant(systemComputation.publicApiVersion) } + .orderByRoles(token.globalComputationId, aggregatorId) + } + } + val requisitions = + systemComputation.requisitionsList.toRequisitionEntries(systemComputation.measurementSpec) + val updateComputationDetailsRequest = updateComputationDetailsRequest { + this.token = token + details = updatedDetails + this.requisitions += requisitions + } + + val newToken = + computationStorageClient.updateComputationDetails(updateComputationDetailsRequest).token + logger.info( + "[id=${token.globalComputationId}] " + "Requisitions and Duchy Elgamal Keys are now updated." + ) + + computationStorageClient.advanceComputationStage( + computationToken = newToken, + stage = ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE.toProtocolStage() + ) + } + + suspend fun updateRequisitionsAndKeySets( + token: ComputationToken, + computationStorageClient: ComputationsGrpcKt.ComputationsCoroutineStub, + systemComputation: Computation, + aggregatorId: String, + ) { + require(token.computationDetails.hasReachOnlyLiquidLegionsV2()) { + "Reach Only Liquid Legions V2 ComputationDetails required" + } + + val stage = token.computationStage.reachOnlyLiquidLegionsSketchAggregationV2 + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + when (stage) { + // We expect stage WAIT_REQUISITIONS_AND_KEY_SET. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET -> { + updateRequisitionsAndKeySetsInternal( + token, + computationStorageClient, + systemComputation, + aggregatorId + ) + return + } + + // For past stages, we throw. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE -> { + error( + "[id=${token.globalComputationId}]: cannot update requisitions and key sets for " + + "computation still in state ${stage.name}" + ) + } + + // For future stages, we log and exit. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_INPUTS, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE -> { + logger.info( + "[id=${token.globalComputationId}]: not updating," + + " stage '$stage' is after WAIT_REQUISITIONS_AND_KEY_SET" + ) + return + } + + // For weird stages, we throw. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED -> { + error("[id=${token.globalComputationId}]: Unrecognized stage '$stage'") + } + } + } + + suspend fun startComputation( + token: ComputationToken, + computationStorageClient: ComputationsGrpcKt.ComputationsCoroutineStub + ) { + require(token.computationDetails.hasReachOnlyLiquidLegionsV2()) { + "Liquid Legions V2 computation required" + } + + val stage = token.computationStage.reachOnlyLiquidLegionsSketchAggregationV2 + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + when (stage) { + // We expect stage WAIT_TO_START. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START -> { + computationStorageClient.advanceComputationStage( + computationToken = token, + inputsToNextStage = token.outputPathList(), + stage = ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage() + ) + logger.info("[id=${token.globalComputationId}] Computation is now started") + return + } + + // For past stages, we throw. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE -> { + error( + "[id=${token.globalComputationId}]: cannot start a computation still" + + " in state ${stage.name}" + ) + } + + // For future stages, we log and exit. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_INPUTS, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE -> { + logger.info( + "[id=${token.globalComputationId}]: not starting," + + " stage '$stage' is after WAIT_TO_START" + ) + return + } + + // For weird stages, we throw. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED -> { + error("[id=${token.globalComputationId}]: Unrecognized stage '$stage'") + } + } + } + + private fun ComputationParticipant.toDuchyComputationParticipant( + publicApiVersion: String + ): ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant { + require(requisitionParams.hasReachOnlyLiquidLegionsV2()) { + "Missing reach-only liquid legions v2 requisition params." + } + return ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.computationParticipant { + duchyId = key.duchyId + publicKey = + requisitionParams.reachOnlyLiquidLegionsV2.elGamalPublicKey.toDuchyElGamalPublicKey( + Version.fromString(publicApiVersion) + ) + elGamalPublicKey = requisitionParams.reachOnlyLiquidLegionsV2.elGamalPublicKey + elGamalPublicKeySignature = + requisitionParams.reachOnlyLiquidLegionsV2.elGamalPublicKeySignature + duchyCertificateDer = requisitionParams.duchyCertificateDer + } + } + + private fun Computation.MpcProtocolConfig.NoiseMechanism.toInternalNoiseMechanism(): + LiquidLegionsV2NoiseConfig.NoiseMechanism { + return when (this) { + Computation.MpcProtocolConfig.NoiseMechanism.GEOMETRIC -> + LiquidLegionsV2NoiseConfig.NoiseMechanism.GEOMETRIC + Computation.MpcProtocolConfig.NoiseMechanism.DISCRETE_GAUSSIAN -> + LiquidLegionsV2NoiseConfig.NoiseMechanism.DISCRETE_GAUSSIAN + Computation.MpcProtocolConfig.NoiseMechanism.UNRECOGNIZED, + Computation.MpcProtocolConfig.NoiseMechanism.NOISE_MECHANISM_UNSPECIFIED -> + error("Invalid system NoiseMechanism") + } + } + + /** Creates a reach-only liquid legions v2 `Parameters` from the system Api computation. */ + private fun Computation.toReachOnlyLiquidLegionsV2Parameters(): + ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.Parameters { + require(mpcProtocolConfig.hasReachOnlyLiquidLegionsV2()) { + "Missing reachOnlyLiquidLegionV2 in the duchy protocol config." + } + + return ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters { + reachOnlyLiquidLegionsSketch = liquidLegionsSketchParameters { + decayRate = mpcProtocolConfig.reachOnlyLiquidLegionsV2.sketchParams.decayRate + size = mpcProtocolConfig.reachOnlyLiquidLegionsV2.sketchParams.maxSize + } + ellipticCurveId = mpcProtocolConfig.reachOnlyLiquidLegionsV2.ellipticCurveId + noise = liquidLegionsV2NoiseConfig { + noiseMechanism = + mpcProtocolConfig.reachOnlyLiquidLegionsV2.noiseMechanism.toInternalNoiseMechanism() + reachNoiseConfig = + LiquidLegionsV2NoiseConfigKt.reachNoiseConfig { + val mpcNoise = mpcProtocolConfig.reachOnlyLiquidLegionsV2.mpcNoise + blindHistogramNoise = mpcNoise.blindedHistogramNoise.toDuchyDifferentialPrivacyParams() + noiseForPublisherNoise = mpcNoise.publisherNoise.toDuchyDifferentialPrivacyParams() + + when (Version.fromString(publicApiVersion)) { + Version.V2_ALPHA -> { + val measurementSpec = MeasurementSpec.parseFrom(measurementSpec) + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + when (measurementSpec.measurementTypeCase) { + MeasurementSpec.MeasurementTypeCase.REACH -> { + val reach = measurementSpec.reach + require(reach.privacyParams.delta > 0) { + "RoLLv2 requires that privacy_params.delta be greater than 0" + } + require(reach.privacyParams.epsilon > MIN_REACH_EPSILON) { + "RoLLv2 requires that privacy_params.epsilon be greater than $MIN_REACH_EPSILON" + } + globalReachDpNoise = reach.privacyParams.toDuchyDifferentialPrivacyParams() + } + MeasurementSpec.MeasurementTypeCase.REACH_AND_FREQUENCY -> { + val reachAndFrequency = measurementSpec.reachAndFrequency + require(reachAndFrequency.reachPrivacyParams.delta > 0) { + "RoLLv2 requires that reach_privacy_params.delta be greater than 0" + } + require(reachAndFrequency.reachPrivacyParams.epsilon > MIN_REACH_EPSILON) { + "RoLLv2 requires that reach_privacy_params.epsilon be greater than $MIN_REACH_EPSILON" + } + require(reachAndFrequency.frequencyPrivacyParams.delta > 0) { + "RoLLv2 requires that frequency_privacy_params.delta be greater than 0" + } + require( + reachAndFrequency.frequencyPrivacyParams.epsilon > MIN_FREQUENCY_EPSILON + ) { + "RoLLv2 requires that frequency_privacy_params.epsilon be greater than " + + "$MIN_FREQUENCY_EPSILON" + } + globalReachDpNoise = + reachAndFrequency.reachPrivacyParams.toDuchyDifferentialPrivacyParams() + this@liquidLegionsV2NoiseConfig.frequencyNoiseConfig = + reachAndFrequency.frequencyPrivacyParams.toDuchyDifferentialPrivacyParams() + } + MeasurementSpec.MeasurementTypeCase.IMPRESSION, + MeasurementSpec.MeasurementTypeCase.DURATION, + MeasurementSpec.MeasurementTypeCase.MEASUREMENTTYPE_NOT_SET -> { + throw IllegalArgumentException( + "Missing Reach and ReachAndFrequency in the measurementSpec." + ) + } + } + } + Version.VERSION_UNSPECIFIED -> error("Public api version is invalid or unspecified.") + } + } + } + } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel index 98a1b9bbf23..62fb65e625f 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel @@ -51,3 +51,41 @@ kt_jvm_library( "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/duchy", ], ) + +kt_jvm_library( + name = "reach_only_liquid_legions_v2_mill", + testonly = True, #TODO: delete when InMemoryKeyStore and FakeHybridCipher are not used. + srcs = [ + "ReachOnlyLiquidLegionsV2Mill.kt", + ], + runtime_deps = ["@wfa_common_jvm//imports/java/io/grpc/netty"], + deps = [ + "//imports/java/io/opentelemetry/api", + "//src/main/kotlin/org/wfanet/measurement/api:public_api_version", + "//src/main/kotlin/org/wfanet/measurement/common/identity", + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill:mill_base", + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto:reachonlyliquidlegionsv2encryption", + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/utils:computation_conversions", + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/utils:duchy_order", + "//src/main/kotlin/org/wfanet/measurement/duchy/db/computation", + "//src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha:advance_computation_request_headers", + "//src/main/kotlin/org/wfanet/measurement/system/v1alpha:resource_key", + "//src/main/proto/wfa/any_sketch/crypto:sketch_encryption_methods_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy:crypto_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy:differential_privacy_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/config:protocols_setup_config_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_sketch_parameter_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_kt_jvm_proto", + "//src/main/proto/wfa/measurement/system/v1alpha:computation_control_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/system/v1alpha:computation_participants_service_kt_jvm_grpc_proto", + "//src/main/swig/protocol/reachonlyliquidlegionsv2:reach_only_liquid_legions_v2_encryption_utility", + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/java/io/grpc:api", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/throttler", + "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/duchy", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt new file mode 100644 index 00000000000..63e16ab2be5 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt @@ -0,0 +1,763 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2 + +import com.google.protobuf.ByteString +import io.opentelemetry.api.OpenTelemetry +import io.opentelemetry.api.metrics.LongHistogram +import io.opentelemetry.api.metrics.Meter +import java.nio.file.Paths +import java.security.SignatureException +import java.security.cert.CertPathValidatorException +import java.security.cert.X509Certificate +import java.time.Clock +import java.time.Duration +import java.util.logging.Logger +import kotlin.math.min +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest +import org.wfanet.measurement.api.Version +import org.wfanet.measurement.api.v2alpha.MeasurementSpec +import org.wfanet.measurement.common.crypto.SigningKeyHandle +import org.wfanet.measurement.common.crypto.readCertificate +import org.wfanet.measurement.common.flatten +import org.wfanet.measurement.common.identity.DuchyInfo +import org.wfanet.measurement.common.loadLibrary +import org.wfanet.measurement.common.throttler.MinimumIntervalThrottler +import org.wfanet.measurement.consent.client.duchy.encryptResult +import org.wfanet.measurement.consent.client.duchy.signElgamalPublicKey +import org.wfanet.measurement.consent.client.duchy.signResult +import org.wfanet.measurement.consent.client.duchy.verifyDataProviderParticipation +import org.wfanet.measurement.consent.client.duchy.verifyElGamalPublicKey +import org.wfanet.measurement.duchy.daemon.mill.CRYPTO_LIB_CPU_DURATION +import org.wfanet.measurement.duchy.daemon.mill.Certificate +import org.wfanet.measurement.duchy.daemon.mill.MillBase +import org.wfanet.measurement.duchy.daemon.mill.PermanentComputationError +import org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto.ReachOnlyLiquidLegionsV2Encryption +import org.wfanet.measurement.duchy.daemon.utils.ComputationResult +import org.wfanet.measurement.duchy.daemon.utils.ReachResult +import org.wfanet.measurement.duchy.daemon.utils.toAnySketchElGamalPublicKey +import org.wfanet.measurement.duchy.daemon.utils.toCmmsElGamalPublicKey +import org.wfanet.measurement.duchy.daemon.utils.toV2AlphaElGamalPublicKey +import org.wfanet.measurement.duchy.daemon.utils.toV2AlphaEncryptionPublicKey +import org.wfanet.measurement.duchy.db.computation.ComputationDataClients +import org.wfanet.measurement.duchy.service.internal.computations.outputPathList +import org.wfanet.measurement.duchy.service.system.v1alpha.advanceComputationHeader +import org.wfanet.measurement.duchy.toProtocolStage +import org.wfanet.measurement.duchy.db.computation.BlobRef +import org.wfanet.measurement.internal.duchy.ComputationDetails.CompletedReason +import org.wfanet.measurement.internal.duchy.ComputationDetails.KingdomComputationDetails +import org.wfanet.measurement.internal.duchy.ComputationStage +import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt.ComputationStatsCoroutineStub +import org.wfanet.measurement.internal.duchy.ComputationToken +import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType +import org.wfanet.measurement.internal.duchy.ElGamalPublicKey +import org.wfanet.measurement.internal.duchy.RequisitionMetadata +import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsRequest +import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation.AGGREGATOR +import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation.NON_AGGREGATOR +import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation.UNKNOWN +import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation.UNRECOGNIZED +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.Parameters +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.globalReachDpNoiseBaseline +import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters +import org.wfanet.measurement.internal.duchy.protocol.perBucketFrequencyDpNoiseBaseline +import org.wfanet.measurement.internal.duchy.protocol.reachNoiseDifferentialPrivacyParams +import org.wfanet.measurement.internal.duchy.protocol.registerNoiseGenerationParameters +import org.wfanet.measurement.system.v1alpha.ComputationControlGrpcKt.ComputationControlCoroutineStub +import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey +import org.wfanet.measurement.system.v1alpha.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineStub +import org.wfanet.measurement.system.v1alpha.ComputationsGrpcKt +import org.wfanet.measurement.system.v1alpha.ConfirmComputationParticipantRequest +import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2 +import org.wfanet.measurement.system.v1alpha.SetParticipantRequisitionParamsRequest + +/** + * Mill works on computations using the ReachOnlyLiquidLegionSketchAggregationProtocol. + * + * @param millId The identifier of this mill, used to claim a work. + * @param duchyId The identifier of this duchy who owns this mill. + * @param signingKey handle to a signing private key for consent signaling. + * @param consentSignalCert The [Certificate] used for consent signaling. + * @param trustedCertificates [Map] of SKID to trusted certificate + * @param dataClients clients that have access to local computation storage, i.e., spanner table and + * blob store. + * @param systemComputationParticipantsClient client of the kingdom's system + * ComputationParticipantsService. + * @param systemComputationsClient client of the kingdom's system computationsService. + * @param systemComputationLogEntriesClient client of the kingdom's system + * computationLogEntriesService. + * @param computationStatsClient client of the duchy's internal ComputationStatsService. + * @param throttler A throttler used to rate limit the frequency of the mill polling from the + * computation table. + * @param requestChunkSizeBytes The size of data chunk when sending result to other duchies. + * @param clock A clock + * @param maximumAttempts The maximum number of attempts on a computation at the same stage. + * @param workerStubs A map from other duchies' Ids to their corresponding + * computationControlClients, used for passing computation to other duchies. + * @param cryptoWorker The cryptoWorker that performs the actual computation. + * @param parallelism The maximum number of threads used for crypto actions. + */ +class ReachOnlyLiquidLegionsV2Mill( + millId: String, + duchyId: String, + signingKey: SigningKeyHandle, + consentSignalCert: Certificate, + private val trustedCertificates: Map, + dataClients: ComputationDataClients, + systemComputationParticipantsClient: ComputationParticipantsCoroutineStub, + systemComputationsClient: ComputationsGrpcKt.ComputationsCoroutineStub, + systemComputationLogEntriesClient: ComputationLogEntriesCoroutineStub, + computationStatsClient: ComputationStatsCoroutineStub, + throttler: MinimumIntervalThrottler, + private val workerStubs: Map, + private val cryptoWorker: ReachOnlyLiquidLegionsV2Encryption, + workLockDuration: Duration, + openTelemetry: OpenTelemetry, + requestChunkSizeBytes: Int = 1024 * 32, + maximumAttempts: Int = 10, + clock: Clock = Clock.systemUTC(), + private val parallelism: Int = 1, +) : + MillBase( + millId, + duchyId, + signingKey, + consentSignalCert, + dataClients, + systemComputationParticipantsClient, + systemComputationsClient, + systemComputationLogEntriesClient, + computationStatsClient, + throttler, + ComputationType.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2, + workLockDuration, + requestChunkSizeBytes, + maximumAttempts, + clock, + openTelemetry + ) { + private val meter: Meter = openTelemetry.getMeter(ReachOnlyLiquidLegionsV2Mill::class.java.name) + + private val initializationPhaseCryptoCpuTimeDurationHistogram: LongHistogram = + meter.histogramBuilder("initialization_phase_crypto_cpu_time_duration_millis").ofLongs().build() + + private val setupPhaseCryptoCpuTimeDurationHistogram: LongHistogram = + meter.histogramBuilder("setup_phase_crypto_cpu_time_duration_millis").ofLongs().build() + + private val executionPhaseCryptoCpuTimeDurationHistogram: LongHistogram = + meter.histogramBuilder("execution_phase_crypto_cpu_time_duration_millis").ofLongs().build() + + override val endingStage: ComputationStage = + ComputationStage.newBuilder() + .apply { reachOnlyLiquidLegionsSketchAggregationV2 = Stage.COMPLETE } + .build() + + private val actions = + mapOf( + Pair(Stage.INITIALIZATION_PHASE, AGGREGATOR) to ::initializationPhase, + Pair(Stage.INITIALIZATION_PHASE, NON_AGGREGATOR) to ::initializationPhase, + Pair(Stage.CONFIRMATION_PHASE, AGGREGATOR) to ::confirmationPhase, + Pair(Stage.CONFIRMATION_PHASE, NON_AGGREGATOR) to ::confirmationPhase, + Pair(Stage.SETUP_PHASE, AGGREGATOR) to ::completeReachOnlySetupPhaseAtAggregator, + Pair(Stage.SETUP_PHASE, NON_AGGREGATOR) to ::completeReachOnlySetupPhaseAtNonAggregator, + Pair(Stage.EXECUTION_PHASE, AGGREGATOR) to ::completeReachOnlyExecutionPhaseAtAggregator, + Pair(Stage.EXECUTION_PHASE, NON_AGGREGATOR) to ::completeReachOnlyExecutionPhaseAtNonAggregator, + ) + + private val kBytesPerCipherText = 66; + + override suspend fun processComputationImpl(token: ComputationToken) { + require(token.computationDetails.hasReachOnlyLiquidLegionsV2()) { + "Only Reach Only Liquid Legions V2 computation is supported in this mill." + } + val stage = token.computationStage.reachOnlyLiquidLegionsSketchAggregationV2 + val role = token.computationDetails.reachOnlyLiquidLegionsV2.role + val action = actions[Pair(stage, role)] ?: error("Unexpected stage or role: ($stage, $role)") + val updatedToken = action(token) + + val globalId = token.globalComputationId + val updatedStage = updatedToken.computationStage.reachOnlyLiquidLegionsSketchAggregationV2 + logger.info("$globalId@$millId: Stage transitioned from $stage to $updatedStage") + } + + /** Sends requisition params to the kingdom. */ + private suspend fun sendRequisitionParamsToKingdom(token: ComputationToken) { + val rollv2ComputationDetails = token.computationDetails.reachOnlyLiquidLegionsV2 + require(rollv2ComputationDetails.hasLocalElgamalKey()) { "Missing local elgamal key." } + val signedElgamalPublicKey = + when (Version.fromString(token.computationDetails.kingdomComputation.publicApiVersion)) { + Version.V2_ALPHA -> + signElgamalPublicKey( + rollv2ComputationDetails.localElgamalKey.publicKey.toV2AlphaElGamalPublicKey(), + signingKey + ) + Version.VERSION_UNSPECIFIED -> error("Public api version is invalid or unspecified.") + } + + val request = + SetParticipantRequisitionParamsRequest.newBuilder() + .apply { + name = ComputationParticipantKey(token.globalComputationId, duchyId).toName() + requisitionParamsBuilder.apply { + duchyCertificate = consentSignalCert.name + reachOnlyLiquidLegionsV2Builder.apply { + elGamalPublicKey = signedElgamalPublicKey.data + elGamalPublicKeySignature = signedElgamalPublicKey.signature + } + } + } + .build() + systemComputationParticipantsClient.setParticipantRequisitionParams(request) + } + + /** Processes computation in the initialization phase */ + private suspend fun initializationPhase(token: ComputationToken): ComputationToken { + val rollv2ComputationDetails = token.computationDetails.reachOnlyLiquidLegionsV2 + val ellipticCurveId = rollv2ComputationDetails.parameters.ellipticCurveId + require(ellipticCurveId > 0) { "invalid ellipticCurveId $ellipticCurveId" } + + val nextToken = + if (rollv2ComputationDetails.hasLocalElgamalKey()) { + // Reuses the key if it is already generated for this computation. + token + } else { + // Generates a new set of ElGamalKeyPair. + val request = + CompleteReachOnlyInitializationPhaseRequest.newBuilder() + .apply { curveId = ellipticCurveId.toLong() } + .build() + val cryptoResult = cryptoWorker.completeReachOnlyInitializationPhase(request) + logStageDurationMetric( + token, + CRYPTO_LIB_CPU_DURATION, + cryptoResult.elapsedCpuTimeMillis, + initializationPhaseCryptoCpuTimeDurationHistogram + ) + + // Updates the newly generated localElgamalKey to the ComputationDetails. + dataClients.computationsClient + .updateComputationDetails( + UpdateComputationDetailsRequest.newBuilder() + .also { + it.token = token + it.details = + token.computationDetails + .toBuilder() + .apply { reachOnlyLiquidLegionsV2Builder.localElgamalKey = cryptoResult.elGamalKeyPair } + .build() + } + .build() + ) + .token + } + + sendRequisitionParamsToKingdom(nextToken) + + return dataClients.transitionComputationToStage( + nextToken, + stage = Stage.WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage() + ) + } + + /** + * Verifies that all EDPs have participated. + * + * @return a list of error messages if anything is wrong, otherwise an empty list. + */ + private fun verifyEdpParticipation( + details: KingdomComputationDetails, + requisitions: Iterable, + ): List { + when (Version.fromString(details.publicApiVersion)) { + Version.V2_ALPHA -> {} + Version.VERSION_UNSPECIFIED -> error("Public api version is invalid or unspecified.") + } + + val errorList = mutableListOf() + val measurementSpec = MeasurementSpec.parseFrom(details.measurementSpec) + if (!verifyDataProviderParticipation(measurementSpec, requisitions.map { it.details.nonce })) { + errorList.add("Cannot verify participation of all DataProviders.") + } + for (requisition in requisitions) { + if (requisition.details.externalFulfillingDuchyId == duchyId && requisition.path.isBlank()) { + errorList.add( + "Missing expected data for requisition ${requisition.externalKey.externalRequisitionId}." + ) + } + } + return errorList + } + + /** + * Verifies the ElGamal public key of [duchy]. + * + * @return the error message if verification fails, or else `null` + */ + private fun verifyDuchySignature( + duchy: ComputationParticipant, + publicApiVersion: Version + ): String? { + val duchyInfo: DuchyInfo.Entry = + requireNotNull(DuchyInfo.getByDuchyId(duchy.duchyId)) { + "DuchyInfo not found for ${duchy.duchyId}" + } + when (publicApiVersion) { + Version.V2_ALPHA -> { + try { + verifyElGamalPublicKey( + duchy.elGamalPublicKey, + duchy.elGamalPublicKeySignature, + readCertificate(duchy.duchyCertificateDer), + trustedCertificates.getValue(duchyInfo.rootCertificateSkid) + ) + } catch (e: CertPathValidatorException) { + return "Certificate path invalid for Duchy ${duchy.duchyId}" + } catch (e: SignatureException) { + return "Invalid ElGamal public key signature for Duchy ${duchy.duchyId}" + } + } + Version.VERSION_UNSPECIFIED -> error("Public api version is invalid or unspecified.") + } + return null + } + + /** Fails a computation both locally and at the kingdom when the confirmation fails. */ + private fun failComputationAtConfirmationPhase( + token: ComputationToken, + errorList: List + ): ComputationToken { + val errorMessage = + "@Mill $millId, Computation ${token.globalComputationId} failed due to:\n" + + errorList.joinToString(separator = "\n") + throw PermanentComputationError(Exception(errorMessage)) + } + + private fun List.toCombinedPublicKey(curveId: Int): ElGamalPublicKey { + val request = + CombineElGamalPublicKeysRequest.newBuilder() + .also { + it.curveId = curveId.toLong() + it.addAllElGamalKeys(this.map { key -> key.toAnySketchElGamalPublicKey() }) + } + .build() + return cryptoWorker.combineElGamalPublicKeys(request).elGamalKeys.toCmmsElGamalPublicKey() + } + + /** + * Computes the fully and partially combined Elgamal public keys and caches the result in the + * computationDetails. + */ + private suspend fun updatePublicElgamalKey(token: ComputationToken): ComputationToken { + val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 + val fullParticipantList = rollv2Details.participantList + val combinedPublicKey = + fullParticipantList + .map { it.publicKey } + .toCombinedPublicKey(rollv2Details.parameters.ellipticCurveId) + + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + val partiallyCombinedPublicKey = + when (rollv2Details.role) { + // For aggregator, the partial list is the same as the full list. + AGGREGATOR -> combinedPublicKey + NON_AGGREGATOR -> + fullParticipantList + .subList( + fullParticipantList.indexOfFirst { it.duchyId == duchyId } + 1, + fullParticipantList.size + ) + .map { it.publicKey } + .toCombinedPublicKey(rollv2Details.parameters.ellipticCurveId) + UNKNOWN, + UNRECOGNIZED -> error("Invalid role ${rollv2Details.role}") + } + + return dataClients.computationsClient + .updateComputationDetails( + UpdateComputationDetailsRequest.newBuilder() + .apply { + this.token = token + details = + token.computationDetails + .toBuilder() + .apply { + reachOnlyLiquidLegionsV2Builder.also { + it.combinedPublicKey = combinedPublicKey + it.partiallyCombinedPublicKey = partiallyCombinedPublicKey + } + } + .build() + } + .build() + ) + .token + } + + /** Sends confirmation to the kingdom and transits the local computation to the next stage. */ + private suspend fun passConfirmationPhase(token: ComputationToken): ComputationToken { + systemComputationParticipantsClient.confirmComputationParticipant( + ConfirmComputationParticipantRequest.newBuilder() + .apply { name = ComputationParticipantKey(token.globalComputationId, duchyId).toName() } + .build() + ) + val latestToken = updatePublicElgamalKey(token) + return dataClients.transitionComputationToStage( + latestToken, + stage = + when (checkNotNull(token.computationDetails.reachOnlyLiquidLegionsV2.role)) { + AGGREGATOR -> Stage.WAIT_SETUP_PHASE_INPUTS.toProtocolStage() + NON_AGGREGATOR -> Stage.WAIT_TO_START.toProtocolStage() + else -> error("Unknown role: ${latestToken.computationDetails.reachOnlyLiquidLegionsV2.role}") + } + ) + } + + /** Processes computation in the confirmation phase */ + private suspend fun confirmationPhase(token: ComputationToken): ComputationToken { + val errorList = mutableListOf() + val kingdomComputation = token.computationDetails.kingdomComputation + errorList.addAll(verifyEdpParticipation(kingdomComputation, token.requisitionsList)) + token.computationDetails.reachOnlyLiquidLegionsV2.participantList.forEach { + verifyDuchySignature(it, Version.fromString(kingdomComputation.publicApiVersion))?.also { + error -> + errorList.add(error) + } + } + return if (errorList.isEmpty()) { + passConfirmationPhase(token) + } else { + failComputationAtConfirmationPhase(token, errorList) + } + } + + private suspend fun completeReachOnlySetupPhaseAtAggregator(token: ComputationToken): ComputationToken { + val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 + require(AGGREGATOR == rollv2Details.role) { "invalid role for this function." } + val (bytes, nextToken) = + existingOutputOr(token) { + val request = + dataClients + .readAllRequisitionBlobs(token, duchyId) + .concat(readAndCombineAllInputBlobsSetupPhaseAtAggregator(token, workerStubs.size)) + .toCompleteReachOnlySetupPhaseAtAggregatorRequest(rollv2Details, token.requisitionsCount) + val cryptoResult: CompleteReachOnlySetupPhaseResponse = cryptoWorker.completeReachOnlySetupPhaseAtAggregator(request) + logStageDurationMetric( + token, + CRYPTO_LIB_CPU_DURATION, + cryptoResult.elapsedCpuTimeMillis, + setupPhaseCryptoCpuTimeDurationHistogram + ) + // The nextToken consists of the CRV and the noise ciphertext. + cryptoResult.combinedRegisterVector.concat(cryptoResult.serializedExcessiveNoiseCiphertext) + } + + sendAdvanceComputationRequest( + header = + advanceComputationHeader( + ReachOnlyLiquidLegionsV2.Description.EXECUTION_PHASE_INPUT, + token.globalComputationId + ), + content = addLoggingHook(token, bytes), + stub = nextDuchyStub(rollv2Details.participantList) + ) + + return dataClients.transitionComputationToStage( + nextToken, + inputsToNextStage = nextToken.outputPathList(), + stage = Stage.WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + ) + } + + private suspend fun completeReachOnlySetupPhaseAtNonAggregator(token: ComputationToken): ComputationToken { + val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 + require(NON_AGGREGATOR == rollv2Details.role) { "invalid role for this function." } + val (bytes, nextToken) = + existingOutputOr(token) { + val request = + dataClients + .readAllRequisitionBlobs(token, duchyId) + .toCompleteReachOnlySetupPhaseRequest(rollv2Details, token.requisitionsCount) + val cryptoResult: CompleteReachOnlySetupPhaseResponse = cryptoWorker.completeReachOnlySetupPhase(request) + logStageDurationMetric( + token, + CRYPTO_LIB_CPU_DURATION, + cryptoResult.elapsedCpuTimeMillis, + setupPhaseCryptoCpuTimeDurationHistogram + ) + // The nextToken consists of the CRV and the noise ciphertext. + cryptoResult.combinedRegisterVector.concat(cryptoResult.serializedExcessiveNoiseCiphertext); + } + + sendAdvanceComputationRequest( + header = + advanceComputationHeader( + ReachOnlyLiquidLegionsV2.Description.SETUP_PHASE_INPUT, + token.globalComputationId + ), + content = addLoggingHook(token, bytes), + stub = aggregatorDuchyStub(rollv2Details.participantList.last().duchyId) + ) + + return dataClients.transitionComputationToStage( + nextToken, + inputsToNextStage = nextToken.outputPathList(), + stage = Stage.WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + ) + } + + private suspend fun completeReachOnlyExecutionPhaseAtAggregator( + token: ComputationToken + ): ComputationToken { + val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 + require(AGGREGATOR == rollv2Details.role) { "invalid role for this function." } + val rollv2Parameters = rollv2Details.parameters + val noiseConfig = rollv2Details.parameters.noise + val measurementSpec = + MeasurementSpec.parseFrom(token.computationDetails.kingdomComputation.measurementSpec) + val inputBlob = readAndCombineAllInputBlobs(token, 1) + require(inputBlob.size() < kBytesPerCipherText) { "Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size ${inputBlob.size()} which is less than ($kBytesPerCipherText)." } + val (bytes, nextToken) = + existingOutputOr(token) { + val request = completeReachOnlyExecutionPhaseAtAggregatorRequest { + combinedRegisterVector = inputBlob.substring(0, inputBlob.size() - kBytesPerCipherText) + localElGamalKeyPair = rollv2Details.localElgamalKey + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + serializedExcessiveNoiseCiphertext = inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size()) + if (rollv2Parameters.noise.hasReachNoiseConfig()) { + reachDpNoiseBaseline = globalReachDpNoiseBaseline { + contributorsCount = workerStubs.size + 1 + globalReachDpNoise = rollv2Parameters.noise.reachNoiseConfig.globalReachDpNoise + } + } + liquidLegionsParameters = liquidLegionsSketchParameters { + decayRate = rollv2Parameters.reachOnlyLiquidLegionsSketch.decayRate + size = rollv2Parameters.reachOnlyLiquidLegionsSketch.size + } + vidSamplingIntervalWidth = measurementSpec.vidSamplingInterval.width + if (noiseConfig.hasReachNoiseConfig()) { + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = rollv2Details.combinedPublicKey + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + contributorsCount = workerStubs.size + 1 + totalSketchesCount = token.requisitionsCount + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = noiseConfig.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = noiseConfig.reachNoiseConfig.globalReachDpNoise + } + } + noiseMechanism = rollv2Details.parameters.noise.noiseMechanism + } + parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism + } + val cryptoResult: CompleteReachOnlyExecutionPhaseAtAggregatorResponse = + cryptoWorker.completeReachOnlyExecutionPhaseAtAggregator(request) + logStageDurationMetric( + token, + CRYPTO_LIB_CPU_DURATION, + cryptoResult.elapsedCpuTimeMillis, + executionPhaseCryptoCpuTimeDurationHistogram + ) + cryptoResult.toByteString() + } + + val reach = CompleteReachOnlyExecutionPhaseAtAggregatorResponse.parseFrom(bytes.flatten()).reach + sendResultToKingdom(token, ReachResult(reach)) + return completeComputation(nextToken, CompletedReason.SUCCEEDED) + } + + private suspend fun completeReachOnlyExecutionPhaseAtNonAggregator( + token: ComputationToken + ): ComputationToken { + val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 + require(NON_AGGREGATOR == rollv2Details.role) { "invalid role for this function." } + val inputBlob = readAndCombineAllInputBlobs(token, 1) + require(inputBlob.size() < kBytesPerCipherText) { "Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size ${inputBlob.size()} which is less than ($kBytesPerCipherText)." } + val (bytes, nextToken) = + existingOutputOr(token) { + val cryptoResult: CompleteReachOnlyExecutionPhaseResponse = + cryptoWorker.completeReachOnlyExecutionPhase( + CompleteReachOnlyExecutionPhaseRequest.newBuilder() + .apply { + combinedRegisterVector = inputBlob.substring(0, inputBlob.size() - kBytesPerCipherText) + localElGamalKeyPair = rollv2Details.localElgamalKey + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + serializedExcessiveNoiseCiphertext = inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size()) + parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism + } + .build() + ) + logStageDurationMetric( + token, + CRYPTO_LIB_CPU_DURATION, + cryptoResult.elapsedCpuTimeMillis, + executionPhaseCryptoCpuTimeDurationHistogram + ) + cryptoResult.combinedRegisterVector.concat(cryptoResult.serializedExcessiveNoiseCiphertext) + } + + // Passes the computation to the next duchy. + sendAdvanceComputationRequest( + header = + advanceComputationHeader( + ReachOnlyLiquidLegionsV2.Description.EXECUTION_PHASE_INPUT, + token.globalComputationId + ), + content = addLoggingHook(token, bytes), + stub = nextDuchyStub(rollv2Details.participantList) + ) + + return completeComputation(nextToken, CompletedReason.SUCCEEDED) + } + + private suspend fun sendResultToKingdom( + token: ComputationToken, + computationResult: ComputationResult + ) { + val kingdomComputation = token.computationDetails.kingdomComputation + val serializedPublicApiEncryptionPublicKey: ByteString + val encryptedResult = + when (Version.fromString(kingdomComputation.publicApiVersion)) { + Version.V2_ALPHA -> { + val signedResult = signResult(computationResult.toV2AlphaMeasurementResult(), signingKey) + val publicApiEncryptionPublicKey = + kingdomComputation.measurementPublicKey.toV2AlphaEncryptionPublicKey() + serializedPublicApiEncryptionPublicKey = publicApiEncryptionPublicKey.toByteString() + encryptResult(signedResult, publicApiEncryptionPublicKey) + } + Version.VERSION_UNSPECIFIED -> error("Public api version is invalid or unspecified.") + } + sendResultToKingdom( + globalId = token.globalComputationId, + certificate = consentSignalCert, + resultPublicKey = serializedPublicApiEncryptionPublicKey, + encryptedResult = encryptedResult + ) + } + + private fun nextDuchyStub( + duchyList: List + ): ComputationControlCoroutineStub { + val index = duchyList.indexOfFirst { it.duchyId == duchyId } + val nextDuchy = duchyList[(index + 1) % duchyList.size].duchyId + return workerStubs[nextDuchy] + ?: throw PermanentComputationError( + IllegalArgumentException("No ComputationControlService stub for next duchy '$nextDuchy'") + ) + } + + private fun aggregatorDuchyStub(aggregatorId: String): ComputationControlCoroutineStub { + return workerStubs[aggregatorId] + ?: throw PermanentComputationError( + IllegalArgumentException( + "No ComputationControlService stub for the Aggregator duchy '$aggregatorId'" + ) + ) + } + + private fun ByteString.toCompleteReachOnlySetupPhaseRequest( + rollv2Details: ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails, + totalRequisitionsCount: Int + ): CompleteReachOnlySetupPhaseRequest { + val noiseConfig = rollv2Details.parameters.noise + return completeReachOnlySetupPhaseRequest { + combinedRegisterVector = this@toCompleteReachOnlySetupPhaseRequest + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + if (noiseConfig.hasReachNoiseConfig()) { + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = rollv2Details.combinedPublicKey + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + contributorsCount = workerStubs.size + 1 + totalSketchesCount = totalRequisitionsCount + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = noiseConfig.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = noiseConfig.reachNoiseConfig.globalReachDpNoise + } + } + noiseMechanism = rollv2Details.parameters.noise.noiseMechanism + } + compositeElGamalPublicKey = rollv2Details.combinedPublicKey + serializedExcessiveNoiseCiphertext = ByteString.EMPTY + parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism + } + } + + private fun ByteString.toCompleteReachOnlySetupPhaseAtAggregatorRequest( + rollv2Details: ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails, + totalRequisitionsCount: Int + ): CompleteReachOnlySetupPhaseRequest { + val noiseConfig = rollv2Details.parameters.noise + val combinedInputBlobs = this@toCompleteReachOnlySetupPhaseAtAggregatorRequest + return completeReachOnlySetupPhaseRequest { + combinedRegisterVector = combinedInputBlobs.substring(0, combinedInputBlobs.size() - workerStubs.size*kBytesPerCipherText) + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + if (noiseConfig.hasReachNoiseConfig()) { + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = rollv2Details.combinedPublicKey + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + contributorsCount = workerStubs.size + 1 + totalSketchesCount = totalRequisitionsCount + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = noiseConfig.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = noiseConfig.reachNoiseConfig.globalReachDpNoise + } + } + noiseMechanism = rollv2Details.parameters.noise.noiseMechanism + } + compositeElGamalPublicKey = rollv2Details.combinedPublicKey + serializedExcessiveNoiseCiphertext = combinedInputBlobs.substring(combinedInputBlobs.size() - workerStubs.size*kBytesPerCipherText, combinedInputBlobs.size()) + parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism + } + } + + /** Reads all input blobs and combines all the bytes together. */ + protected suspend fun readAndCombineAllInputBlobsSetupPhaseAtAggregator( + token: ComputationToken, + count: Int + ): ByteString { + val blobMap: Map = dataClients.readInputBlobs(token) + if (blobMap.size != count) { + throw PermanentComputationError( + Exception("Unexpected number of input blobs. expected $count, actual ${blobMap.size}.") + ) + } + var combinedRegisterVector = ByteString.EMPTY + var combinedNoiseCiphertext = ByteString.EMPTY + for (str in blobMap.values) { + require(str.size() < kBytesPerCipherText) { "Invalid input blob size. Input blob ${str.toStringUtf8()} has size ${str.size()} which is less than ($kBytesPerCipherText)." } + combinedRegisterVector = combinedRegisterVector.concat(str.substring(0, str.size() - kBytesPerCipherText)) + combinedNoiseCiphertext = combinedNoiseCiphertext.concat(str.substring(str.size() - kBytesPerCipherText, str.size())) + } + return combinedRegisterVector.concat(combinedNoiseCiphertext) + } + + companion object { + private val logger: Logger = Logger.getLogger(this::class.java.name) + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel index 75ea8a3a617..312e9d83bec 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel @@ -4,7 +4,10 @@ package(default_visibility = ["//visibility:public"]) kt_jvm_library( name = "liquidlegionsv2encryption", - srcs = glob(["*.kt"]), + srcs = [ + "LiquidLegionsV2Encryption.kt", + "JniLiquidLegionsV2Encryption.kt", + ], deps = [ "//src/main/proto/wfa/any_sketch/crypto:sketch_encryption_methods_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_encryption_methods_kt_jvm_proto", @@ -14,3 +17,19 @@ kt_jvm_library( "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", ], ) + +kt_jvm_library( + name = "reachonlyliquidlegionsv2encryption", + srcs = [ + "ReachOnlyLiquidLegionsV2Encryption.kt", + "JniReachOnlyLiquidLegionsV2Encryption.kt", + ], + deps = [ + "//src/main/proto/wfa/any_sketch/crypto:sketch_encryption_methods_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_kt_jvm_proto", + "//src/main/swig/protocol/reachonlyliquidlegionsv2:reach_only_liquid_legions_v2_encryption_utility", + "@any_sketch_java//src/main/java/org/wfanet/anysketch/crypto:sketch_encrypter_adapter", + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2Encryption.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2Encryption.kt new file mode 100644 index 00000000000..49193561296 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2Encryption.kt @@ -0,0 +1,93 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto + +import java.nio.file.Paths +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysResponse +import org.wfanet.anysketch.crypto.SketchEncrypterAdapter +import org.wfanet.measurement.common.loadLibrary +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.reachonlyliquidlegionsv2.ReachOnlyLiquidLegionsV2EncryptionUtility + +/** + * A [ReachOnlyLiquidLegionsV2Encryption] implementation using the JNI [ReachOnlyLiquidLegionsV2EncryptionUtility]. + */ +class JniReachOnlyLiquidLegionsV2Encryption : ReachOnlyLiquidLegionsV2Encryption { + + override fun completeReachOnlyInitializationPhase( + request: CompleteReachOnlyInitializationPhaseRequest + ): CompleteReachOnlyInitializationPhaseResponse { + return CompleteReachOnlyInitializationPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyInitializationPhase(request.toByteArray()) + ) + } + + override fun completeReachOnlySetupPhase(request: CompleteReachOnlySetupPhaseRequest): CompleteReachOnlySetupPhaseResponse { + return CompleteReachOnlySetupPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase(request.toByteArray()) + ) + } + + override fun completeReachOnlySetupPhaseAtAggregator(request: CompleteReachOnlySetupPhaseRequest): CompleteReachOnlySetupPhaseResponse { + return CompleteReachOnlySetupPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhaseAtAggregator(request.toByteArray()) + ) + } + + override fun completeReachOnlyExecutionPhase( + request: CompleteReachOnlyExecutionPhaseRequest + ): CompleteReachOnlyExecutionPhaseResponse { + return CompleteReachOnlyExecutionPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhase(request.toByteArray()) + ) + } + + override fun completeReachOnlyExecutionPhaseAtAggregator( + request: CompleteReachOnlyExecutionPhaseAtAggregatorRequest + ): CompleteReachOnlyExecutionPhaseAtAggregatorResponse { + return CompleteReachOnlyExecutionPhaseAtAggregatorResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhaseAtAggregator(request.toByteArray()) + ) + } + + override fun combineElGamalPublicKeys( + request: CombineElGamalPublicKeysRequest + ): CombineElGamalPublicKeysResponse { + return CombineElGamalPublicKeysResponse.parseFrom( + SketchEncrypterAdapter.CombineElGamalPublicKeys(request.toByteArray()) + ) + } + + companion object { + init { + loadLibrary( + name = "reach_only_liquid_legions_v2_encryption_utility", + directoryPath = Paths.get("wfa_measurement_system/src/main/swig/protocol/reachonlyliquidlegionsv2") + ) + loadLibrary( + name = "sketch_encrypter_adapter", + directoryPath = Paths.get("any_sketch_java/src/main/java/org/wfanet/anysketch/crypto") + ) + } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2Encryption.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2Encryption.kt new file mode 100644 index 00000000000..733cb3cba81 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2Encryption.kt @@ -0,0 +1,58 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto + +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse + +/** + * Crypto operations for the Reach Only Liquid Legions v2 protocol. check + * src/main/cc/wfa/measurement/common/crypto/reach_only_liquid_legions_v2_encryption_utility.h for + * more descriptions. + */ +interface ReachOnlyLiquidLegionsV2Encryption { + + fun completeReachOnlyInitializationPhase( + request: CompleteReachOnlyInitializationPhaseRequest + ): CompleteReachOnlyInitializationPhaseResponse + + fun completeReachOnlySetupPhase( + request: CompleteReachOnlySetupPhaseRequest + ): CompleteReachOnlySetupPhaseResponse + + fun completeReachOnlySetupPhaseAtAggregator( + request: CompleteReachOnlySetupPhaseRequest + ): CompleteReachOnlySetupPhaseResponse + + fun completeReachOnlyExecutionPhase( + request: CompleteReachOnlyExecutionPhaseRequest + ): CompleteReachOnlyExecutionPhaseResponse + + fun completeReachOnlyExecutionPhaseAtAggregator( + request: CompleteReachOnlyExecutionPhaseAtAggregatorRequest + ): CompleteReachOnlyExecutionPhaseAtAggregatorResponse + + fun combineElGamalPublicKeys( + request: CombineElGamalPublicKeysRequest + ): CombineElGamalPublicKeysResponse +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationProtocolStageDetails.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationProtocolStageDetails.kt index 153b632a496..b82d96f23cc 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationProtocolStageDetails.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationProtocolStageDetails.kt @@ -36,6 +36,9 @@ object ComputationProtocolStageDetails : stage, computationDetails ) + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.ComputationStages.Details + .validateRoleForStage(stage, computationDetails) ComputationStage.StageCase.STAGE_NOT_SET -> error("Stage not set") } } @@ -47,6 +50,9 @@ object ComputationProtocolStageDetails : LiquidLegionsSketchAggregationV2Protocol.ComputationStages.Details.afterTransitionForStage( stage ) + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.ComputationStages.Details + .afterTransitionForStage(stage) ComputationStage.StageCase.STAGE_NOT_SET -> error("Stage not set") } } @@ -60,6 +66,9 @@ object ComputationProtocolStageDetails : ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> LiquidLegionsSketchAggregationV2Protocol.ComputationStages.Details .outputBlobNumbersForStage(stage, computationDetails) + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.ComputationStages.Details + .outputBlobNumbersForStage(stage, computationDetails) ComputationStage.StageCase.STAGE_NOT_SET -> error("Stage not set") } } @@ -75,6 +84,11 @@ object ComputationProtocolStageDetails : stage, computationDetails ) + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.ComputationStages.Details.detailsFor( + stage, + computationDetails + ) ComputationStage.StageCase.STAGE_NOT_SET -> error("Stage not set") } } @@ -84,6 +98,10 @@ object ComputationProtocolStageDetails : return when (protocol) { ComputationType.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> LiquidLegionsSketchAggregationV2Protocol.ComputationStages.Details.parseDetails(bytes) + ComputationType.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.ComputationStages.Details.parseDetails( + bytes + ) ComputationType.UNSPECIFIED, ComputationType.UNRECOGNIZED -> error("invalid protocol") } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationProtocolStages.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationProtocolStages.kt index a2256e5420d..22ea83db737 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationProtocolStages.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationProtocolStages.kt @@ -18,6 +18,7 @@ import org.wfanet.measurement.duchy.toProtocolStage import org.wfanet.measurement.internal.duchy.ComputationStage import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 +import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType.UNRECOGNIZED import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType.UNSPECIFIED @@ -30,6 +31,8 @@ object ComputationProtocolStages : return when (stage.stageCase) { ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 ComputationStage.StageCase.STAGE_NOT_SET -> error("Stage not set") } } @@ -46,6 +49,13 @@ object ComputationProtocolStages : value.liquidLegionsSketchAggregationV2 ) ) + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ComputationStageLongValues( + ComputationTypes.protocolEnumToLong(REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2), + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.enumToLong( + value.reachOnlyLiquidLegionsSketchAggregationV2 + ) + ) ComputationStage.StageCase.STAGE_NOT_SET -> error("Stage not set") } } @@ -58,6 +68,9 @@ object ComputationProtocolStages : LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> LiquidLegionsSketchAggregationV2Protocol.EnumStages.longToEnum(value.stage) .toProtocolStage() + REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.longToEnum(value.stage) + .toProtocolStage() UNSPECIFIED, UNRECOGNIZED -> error("protocol not set") } @@ -68,6 +81,8 @@ object ComputationProtocolStages : return when (protocol) { LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> LiquidLegionsSketchAggregationV2Protocol.ComputationStages.validInitialStages + REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.ComputationStages.validInitialStages UNSPECIFIED, UNRECOGNIZED -> error("protocol not set") } @@ -78,6 +93,8 @@ object ComputationProtocolStages : return when (protocol) { LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> LiquidLegionsSketchAggregationV2Protocol.ComputationStages.validTerminalStages + REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.ComputationStages.validTerminalStages UNSPECIFIED, UNRECOGNIZED -> error("protocol not set") } @@ -98,6 +115,8 @@ object ComputationProtocolStages : when (currentStage.stageCase) { ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> LiquidLegionsSketchAggregationV2Protocol.ComputationStages.validSuccessors + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.ComputationStages.validSuccessors ComputationStage.StageCase.STAGE_NOT_SET -> error("Stage not set") }.getOrDefault(currentStage, setOf()) } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ReachOnlyLiquidLegionsSketchAggregationV2Protocol.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ReachOnlyLiquidLegionsSketchAggregationV2Protocol.kt new file mode 100644 index 00000000000..53c842d46cb --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ReachOnlyLiquidLegionsSketchAggregationV2Protocol.kt @@ -0,0 +1,246 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.db.computation + +import org.wfanet.measurement.common.numberAsLong +import org.wfanet.measurement.duchy.toProtocolStage +import org.wfanet.measurement.internal.duchy.ComputationDetails +import org.wfanet.measurement.internal.duchy.ComputationStage +import org.wfanet.measurement.internal.duchy.ComputationStageDetails +import org.wfanet.measurement.internal.duchy.computationStageDetails +import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_INPUTS +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.stageDetails +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.waitSetupPhaseInputsDetails + +/** + * Helper classes for working with stages of the Liquid Legions Sketch Aggregation V2 MPC defined in + * [ReachOnlyLiquidLegionsSketchAggregationV2.Stage]. + * + * The [ReachOnlyLiquidLegionsSketchAggregationV2.Stage] is one of the computation protocols defined + * in the [ComputationStage] proto used in the storage layer API. There are helper objects for both + * the raw enum values and those enum value wrapped in the proto message. Typically, the right + * helper to use is [ComputationStages] as it is the API level abstraction. [EnumStages] is visible + * in case it is needed. + * + * [EnumStages.Details] is a helper to create [ComputationStageDetails] from + * [ReachOnlyLiquidLegionsSketchAggregationV2.Stage] enum values. [ComputationStages.Details] is a + * helper to create [ComputationStageDetails] from [ComputationStage] protos wrapping a + * [ReachOnlyLiquidLegionsSketchAggregationV2.Stage] enum values. + */ +object ReachOnlyLiquidLegionsSketchAggregationV2Protocol { + /** + * Implementation of [ProtocolStageEnumHelper] for + * [ReachOnlyLiquidLegionsSketchAggregationV2.Stage]. + */ + object EnumStages : ProtocolStageEnumHelper { + override val validInitialStages = setOf(INITIALIZATION_PHASE) + override val validTerminalStages = setOf(COMPLETE) + + override val validSuccessors = + mapOf( + INITIALIZATION_PHASE to setOf(WAIT_REQUISITIONS_AND_KEY_SET), + WAIT_REQUISITIONS_AND_KEY_SET to setOf(CONFIRMATION_PHASE), + CONFIRMATION_PHASE to setOf(WAIT_TO_START, WAIT_SETUP_PHASE_INPUTS), + WAIT_TO_START to setOf(SETUP_PHASE), + WAIT_SETUP_PHASE_INPUTS to setOf(SETUP_PHASE), + SETUP_PHASE to setOf(WAIT_EXECUTION_PHASE_INPUTS), + WAIT_EXECUTION_PHASE_INPUTS to setOf(EXECUTION_PHASE), + EXECUTION_PHASE to setOf() + ) + .withDefault { setOf() } + + override fun enumToLong(value: ReachOnlyLiquidLegionsSketchAggregationV2.Stage): Long { + return value.numberAsLong + } + + override fun longToEnum(value: Long): ReachOnlyLiquidLegionsSketchAggregationV2.Stage { + // forNumber() returns null for unrecognized enum values for the proto. + return ReachOnlyLiquidLegionsSketchAggregationV2.Stage.forNumber(value.toInt()) + ?: UNRECOGNIZED + } + + /** + * Translates [ReachOnlyLiquidLegionsSketchAggregationV2.Stage] s into + * [ComputationStageDetails]. + */ + object Details : + ProtocolStageDetails< + ReachOnlyLiquidLegionsSketchAggregationV2.Stage, + ComputationStageDetails, + ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails + > { + + override fun validateRoleForStage( + stage: ReachOnlyLiquidLegionsSketchAggregationV2.Stage, + details: ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails + ): Boolean { + return when (stage) { + WAIT_TO_START -> details.role == RoleInComputation.NON_AGGREGATOR + WAIT_SETUP_PHASE_INPUTS -> details.role == RoleInComputation.AGGREGATOR + else -> true /* Stage can be executed at either primary or non-primary */ + } + } + + override fun afterTransitionForStage( + stage: ReachOnlyLiquidLegionsSketchAggregationV2.Stage + ): AfterTransition { + return when (stage) { + // Stages of computation mapping some number of inputs to single output. + CONFIRMATION_PHASE, + SETUP_PHASE, + EXECUTION_PHASE -> AfterTransition.ADD_UNCLAIMED_TO_QUEUE + WAIT_REQUISITIONS_AND_KEY_SET, + WAIT_TO_START, + WAIT_SETUP_PHASE_INPUTS, + WAIT_EXECUTION_PHASE_INPUTS -> AfterTransition.DO_NOT_ADD_TO_QUEUE + COMPLETE -> error("Computation should be ended with call to endComputation(...)") + // Stages that we can't transition to ever. + UNRECOGNIZED, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED, + INITIALIZATION_PHASE -> error("Cannot make transition function to stage $stage") + } + } + + override fun outputBlobNumbersForStage( + stage: ReachOnlyLiquidLegionsSketchAggregationV2.Stage, + computationDetails: ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails + ): Int { + return when (stage) { + WAIT_REQUISITIONS_AND_KEY_SET, + CONFIRMATION_PHASE, + WAIT_TO_START -> 0 + WAIT_EXECUTION_PHASE_INPUTS, + SETUP_PHASE, + EXECUTION_PHASE -> + // The output is the intermediate computation result either received from another duchy + // or computed locally. + 1 + WAIT_SETUP_PHASE_INPUTS -> + // The output contains otherDuchiesInComputation sketches from the other duchies. + computationDetails.participantCount - 1 + // Mill have nothing to do for this stage. + COMPLETE -> error("Computation should be ended with call to endComputation(...)") + // Stages that we can't transition to ever. + UNRECOGNIZED, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED, + INITIALIZATION_PHASE -> error("Cannot make transition function to stage $stage") + } + } + + override fun detailsFor( + stage: ReachOnlyLiquidLegionsSketchAggregationV2.Stage, + computationDetails: ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails + ): ComputationStageDetails { + return when (stage) { + WAIT_SETUP_PHASE_INPUTS -> + computationStageDetails { + reachOnlyLiquidLegionsV2 = stageDetails { + waitSetupPhaseInputsDetails = waitSetupPhaseInputsDetails { + val participants = computationDetails.participantList + val nonAggregators = participants.subList(0, participants.size - 1) + nonAggregators.mapIndexed { idx, duchy -> + externalDuchyLocalBlobId[duchy.duchyId] = idx.toLong() + } + } + } + } + else -> ComputationStageDetails.getDefaultInstance() + } + } + + override fun parseDetails(bytes: ByteArray): ComputationStageDetails = + ComputationStageDetails.parseFrom(bytes) + } + } + + /** + * Implementation of [ProtocolStageEnumHelper] for + * [ReachOnlyLiquidLegionsSketchAggregationV2.Stage] wrapped in a [ComputationStage]. + */ + object ComputationStages : ProtocolStageEnumHelper { + override val validInitialStages = EnumStages.validInitialStages.toSetOfComputationStages() + override val validTerminalStages = EnumStages.validTerminalStages.toSetOfComputationStages() + + override val validSuccessors = + EnumStages.validSuccessors + .map { it.key.toProtocolStage() to it.value.toSetOfComputationStages() } + .toMap() + + override fun enumToLong(value: ComputationStage): Long = + EnumStages.enumToLong(value.reachOnlyLiquidLegionsSketchAggregationV2) + + override fun longToEnum(value: Long): ComputationStage = + EnumStages.longToEnum(value).toProtocolStage() + + /** + * Translates [ReachOnlyLiquidLegionsSketchAggregationV2.Stage] s wrapped in a + * [ComputationStage] into [ComputationStageDetails]. + */ + object Details : + ProtocolStageDetails { + override fun validateRoleForStage( + stage: ComputationStage, + details: ComputationDetails + ): Boolean { + return EnumStages.Details.validateRoleForStage( + stage.reachOnlyLiquidLegionsSketchAggregationV2, + details.reachOnlyLiquidLegionsV2 + ) + } + + override fun afterTransitionForStage(stage: ComputationStage): AfterTransition { + return EnumStages.Details.afterTransitionForStage( + stage.reachOnlyLiquidLegionsSketchAggregationV2 + ) + } + + override fun outputBlobNumbersForStage( + stage: ComputationStage, + computationDetails: ComputationDetails + ): Int { + return EnumStages.Details.outputBlobNumbersForStage( + stage.reachOnlyLiquidLegionsSketchAggregationV2, + computationDetails.reachOnlyLiquidLegionsV2 + ) + } + + override fun detailsFor( + stage: ComputationStage, + computationDetails: ComputationDetails + ): ComputationStageDetails = + EnumStages.Details.detailsFor( + stage.reachOnlyLiquidLegionsSketchAggregationV2, + computationDetails.reachOnlyLiquidLegionsV2 + ) + + override fun parseDetails(bytes: ByteArray): ComputationStageDetails = + EnumStages.Details.parseDetails(bytes) + } + } +} + +private fun Set.toSetOfComputationStages() = + this.map { it.toProtocolStage() }.toSet() diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt index c44a54e7fd6..030079a880d 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt @@ -351,6 +351,8 @@ private constructor( when (it.computationStage.stageCase) { ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> ComputationType.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ComputationType.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 else -> error("Computation type for $it is unknown") }, stage = it.computationStage, diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ProtocolStages.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ProtocolStages.kt index 1af30e1d058..4510c236c4e 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ProtocolStages.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ProtocolStages.kt @@ -20,7 +20,8 @@ import org.wfanet.measurement.internal.duchy.ComputationBlobDependency import org.wfanet.measurement.internal.duchy.ComputationStage import org.wfanet.measurement.internal.duchy.ComputationStageBlobMetadata import org.wfanet.measurement.internal.duchy.ComputationToken -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 class IllegalStageException(val computationStage: ComputationStage, buildMessage: () -> String) : IllegalArgumentException(buildMessage()) @@ -45,6 +46,8 @@ sealed class ProtocolStages(val stageType: ComputationStage.StageCase) { fun forStageType(stageType: ComputationStage.StageCase): ProtocolStages? { return when (stageType) { ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> LiquidLegionsV2Stages() + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ReachOnlyLiquidLegionsV2Stages() ComputationStage.StageCase.STAGE_NOT_SET -> null } } @@ -59,7 +62,7 @@ class LiquidLegionsV2Stages() : dataOrigin: String ): ComputationStageBlobMetadata = when (val protocolStage = token.computationStage.liquidLegionsSketchAggregationV2) { - Stage.WAIT_SETUP_PHASE_INPUTS -> { + LiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS -> { // Get the blob id by looking up the sender in the stage specific details. val stageDetails = token.stageSpecificDetails.liquidLegionsV2.waitSetupPhaseInputsDetails val blobId = checkNotNull(stageDetails.externalDuchyLocalBlobIdMap[dataOrigin]) @@ -67,20 +70,21 @@ class LiquidLegionsV2Stages() : it.dependencyType == ComputationBlobDependency.OUTPUT && it.blobId == blobId } } - Stage.WAIT_EXECUTION_PHASE_ONE_INPUTS, - Stage.WAIT_EXECUTION_PHASE_TWO_INPUTS, - Stage.WAIT_EXECUTION_PHASE_THREE_INPUTS -> token.singleOutputBlobMetadata() - Stage.INITIALIZATION_PHASE, - Stage.WAIT_REQUISITIONS_AND_KEY_SET, - Stage.CONFIRMATION_PHASE, - Stage.WAIT_TO_START, - Stage.SETUP_PHASE, - Stage.EXECUTION_PHASE_ONE, - Stage.EXECUTION_PHASE_TWO, - Stage.EXECUTION_PHASE_THREE, - Stage.COMPLETE, - Stage.STAGE_UNSPECIFIED, - Stage.UNRECOGNIZED -> + LiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_ONE_INPUTS, + LiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_TWO_INPUTS, + LiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_THREE_INPUTS -> + token.singleOutputBlobMetadata() + LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE, + LiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET, + LiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE, + LiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START, + LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE, + LiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE_ONE, + LiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE_TWO, + LiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE_THREE, + LiquidLegionsSketchAggregationV2.Stage.COMPLETE, + LiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED, + LiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED -> throw IllegalStageException(token.computationStage) { "Unexpected $stageType stage: $protocolStage" } @@ -91,21 +95,83 @@ class LiquidLegionsV2Stages() : @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enums fields cannot be null. return when (val protocolStage = stage.liquidLegionsSketchAggregationV2) { - Stage.WAIT_SETUP_PHASE_INPUTS -> Stage.SETUP_PHASE - Stage.WAIT_EXECUTION_PHASE_ONE_INPUTS -> Stage.EXECUTION_PHASE_ONE - Stage.WAIT_EXECUTION_PHASE_TWO_INPUTS -> Stage.EXECUTION_PHASE_TWO - Stage.WAIT_EXECUTION_PHASE_THREE_INPUTS -> Stage.EXECUTION_PHASE_THREE - Stage.INITIALIZATION_PHASE, - Stage.WAIT_REQUISITIONS_AND_KEY_SET, - Stage.CONFIRMATION_PHASE, - Stage.WAIT_TO_START, - Stage.SETUP_PHASE, - Stage.EXECUTION_PHASE_ONE, - Stage.EXECUTION_PHASE_TWO, - Stage.EXECUTION_PHASE_THREE, - Stage.COMPLETE, - Stage.STAGE_UNSPECIFIED, - Stage.UNRECOGNIZED -> + LiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS -> + LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE + LiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_ONE_INPUTS -> + LiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE_ONE + LiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_TWO_INPUTS -> + LiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE_TWO + LiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_THREE_INPUTS -> + LiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE_THREE + LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE, + LiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET, + LiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE, + LiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START, + LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE, + LiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE_ONE, + LiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE_TWO, + LiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE_THREE, + LiquidLegionsSketchAggregationV2.Stage.COMPLETE, + LiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED, + LiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED -> + throw IllegalStageException(stage) { "Next $stageType stage unknown for $protocolStage" } + }.toProtocolStage() + } +} + +/** [ProtocolStages] for the Reach-Only Liquid Legions v2 protocol. */ +class ReachOnlyLiquidLegionsV2Stages() : + ProtocolStages(ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2) { + override fun outputBlob( + token: ComputationToken, + dataOrigin: String + ): ComputationStageBlobMetadata = + when (val protocolStage = token.computationStage.reachOnlyLiquidLegionsSketchAggregationV2) { + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS -> { + // Get the blob id by looking up the sender in the stage specific details. + val stageDetails = + token.stageSpecificDetails.reachOnlyLiquidLegionsV2.waitSetupPhaseInputsDetails + val blobId = checkNotNull(stageDetails.externalDuchyLocalBlobIdMap[dataOrigin]) + token.blobsList.single { + it.dependencyType == ComputationBlobDependency.OUTPUT && it.blobId == blobId + } + } + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_INPUTS -> + token.singleOutputBlobMetadata() + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED -> + throw IllegalStageException(token.computationStage) { + "Unexpected $stageType stage: $protocolStage" + } + } + + override fun nextStage(stage: ComputationStage): ComputationStage { + require( + stage.stageCase == ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 + ) + + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enums fields cannot be null. + return when (val protocolStage = stage.reachOnlyLiquidLegionsSketchAggregationV2) { + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS -> + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_INPUTS -> + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED -> throw IllegalStageException(stage) { "Next $stageType stage unknown for $protocolStage" } }.toProtocolStage() } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt index 1807c4d3f1a..f4649e0d879 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt @@ -66,6 +66,7 @@ import org.wfanet.measurement.internal.duchy.RecordRequisitionBlobPathResponse import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsRequest import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsResponse import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 import org.wfanet.measurement.internal.duchy.purgeComputationsResponse import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey @@ -374,6 +375,9 @@ class ComputationsService( ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> token.computationStage.liquidLegionsSketchAggregationV2 == LiquidLegionsSketchAggregationV2.Stage.COMPLETE + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + token.computationStage.reachOnlyLiquidLegionsSketchAggregationV2 == + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE ComputationStage.StageCase.STAGE_NOT_SET -> false } } @@ -383,6 +387,8 @@ class ComputationsService( return when (token.computationStage.stageCase) { ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> LiquidLegionsSketchAggregationV2.Stage.COMPLETE.toProtocolStage() + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE.toProtocolStage() ComputationStage.StageCase.STAGE_NOT_SET -> error("protocol not set") } } @@ -406,5 +412,7 @@ private fun ComputationStage.toComputationType() = when (stageCase) { ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> ComputationType.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ComputationType.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 else -> failGrpc { "Computation type for $this is unknown" } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/AdvanceComputationRequestHeaders.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/AdvanceComputationRequestHeaders.kt index 762c9fdd8f1..bff9ba6a5ab 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/AdvanceComputationRequestHeaders.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/AdvanceComputationRequestHeaders.kt @@ -18,15 +18,21 @@ import org.wfanet.measurement.common.grpc.failGrpc import org.wfanet.measurement.duchy.toProtocolStage import org.wfanet.measurement.internal.duchy.ComputationStage import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 import org.wfanet.measurement.system.v1alpha.AdvanceComputationRequest import org.wfanet.measurement.system.v1alpha.AdvanceComputationRequest.Header.ProtocolCase +import org.wfanet.measurement.system.v1alpha.AdvanceComputationRequestKt import org.wfanet.measurement.system.v1alpha.ComputationKey import org.wfanet.measurement.system.v1alpha.LiquidLegionsV2 +import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2 +import org.wfanet.measurement.system.v1alpha.liquidLegionsV2 +import org.wfanet.measurement.system.v1alpha.reachOnlyLiquidLegionsV2 /** True if the protocol specified in the header is asynchronous. */ fun AdvanceComputationRequest.Header.isForAsyncComputation(): Boolean = when (protocolCase) { - ProtocolCase.LIQUID_LEGIONS_V2 -> true + ProtocolCase.LIQUID_LEGIONS_V2, + ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> true else -> failGrpc { "Unknown protocol $protocolCase" } } @@ -34,6 +40,7 @@ fun AdvanceComputationRequest.Header.isForAsyncComputation(): Boolean = fun AdvanceComputationRequest.Header.stageExpectingInput(): ComputationStage = when (protocolCase) { ProtocolCase.LIQUID_LEGIONS_V2 -> liquidLegionsV2.stageExpectingInput() + ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> reachOnlyLiquidLegionsV2.stageExpectingInput() else -> failGrpc { "Unknown protocol $protocolCase" } } @@ -50,14 +57,33 @@ private fun LiquidLegionsV2.stageExpectingInput(): ComputationStage = else -> failGrpc { "Unknown LiquidLegionsV2 payload description '$description'." } }.toProtocolStage() +private fun ReachOnlyLiquidLegionsV2.stageExpectingInput(): ComputationStage = + when (description) { + ReachOnlyLiquidLegionsV2.Description.SETUP_PHASE_INPUT -> + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS + ReachOnlyLiquidLegionsV2.Description.EXECUTION_PHASE_INPUT -> + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_INPUTS + else -> failGrpc { "Unknown ReachOnlyLiquidLegionsV2 payload description '$description'." } + }.toProtocolStage() + /** Creates an [AdvanceComputationRequest.Header] for a liquid legions v2 computation. */ fun advanceComputationHeader( liquidLegionsV2ContentDescription: LiquidLegionsV2.Description, globalComputationId: String ): AdvanceComputationRequest.Header = - AdvanceComputationRequest.Header.newBuilder() - .apply { - name = ComputationKey(globalComputationId).toName() - liquidLegionsV2Builder.description = liquidLegionsV2ContentDescription + AdvanceComputationRequestKt.header { + name = ComputationKey(globalComputationId).toName() + liquidLegionsV2 = liquidLegionsV2 { description = liquidLegionsV2ContentDescription } + } + +/** Creates an [AdvanceComputationRequest.Header] for a reach-only liquid legions v2 computation. */ +fun advanceComputationHeader( + reachOnlyLiquidLegionsV2ContentDescription: ReachOnlyLiquidLegionsV2.Description, + globalComputationId: String +): AdvanceComputationRequest.Header = + AdvanceComputationRequestKt.header { + name = ComputationKey(globalComputationId).toName() + reachOnlyLiquidLegionsV2 = reachOnlyLiquidLegionsV2 { + description = reachOnlyLiquidLegionsV2ContentDescription } - .build() + } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/RequisitionsService.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/RequisitionsService.kt index 465cf83d572..eecc89b5c94 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/RequisitionsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/RequisitionsService.kt @@ -391,6 +391,13 @@ private fun DuchyValue.toDuchyEntryValue(externalDuchyId: String): DuchyEntry.Va signature = value.liquidLegionsV2.elGamalPublicKeySignature } } + DuchyValue.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> reachOnlyLiquidLegionsV2 = + liquidLegionsV2 { + elGamalPublicKey = signedData { + data = value.reachOnlyLiquidLegionsV2.elGamalPublicKey + signature = value.reachOnlyLiquidLegionsV2.elGamalPublicKeySignature + } + } DuchyValue.ProtocolCase.PROTOCOL_NOT_SET -> {} } } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationParticipantsService.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationParticipantsService.kt index ff2265c8891..7573a55f434 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationParticipantsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationParticipantsService.kt @@ -22,11 +22,13 @@ import org.wfanet.measurement.common.grpc.grpcRequireNotNull import org.wfanet.measurement.common.identity.DuchyIdentity import org.wfanet.measurement.common.identity.apiIdToExternalId import org.wfanet.measurement.common.identity.duchyIdentityFromContext +import org.wfanet.measurement.internal.kingdom.ComputationParticipantKt.liquidLegionsV2Details import org.wfanet.measurement.internal.kingdom.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineStub as InternalComputationParticipantsCoroutineStub import org.wfanet.measurement.internal.kingdom.ConfirmComputationParticipantRequest as InternalConfirmComputationParticipantRequest import org.wfanet.measurement.internal.kingdom.FailComputationParticipantRequest as InternalFailComputationParticipantRequest import org.wfanet.measurement.internal.kingdom.MeasurementLogEntry import org.wfanet.measurement.internal.kingdom.SetParticipantRequisitionParamsRequest as InternalSetParticipantRequisitionParamsRequest +import org.wfanet.measurement.internal.kingdom.setParticipantRequisitionParamsRequest as internalSetParticipantRequisitionParamsRequest import org.wfanet.measurement.system.v1alpha.ComputationParticipant import org.wfanet.measurement.system.v1alpha.ComputationParticipant.RequisitionParams.ProtocolCase import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey @@ -73,25 +75,28 @@ class ComputationParticipantsService( grpcRequire(computationParticipantKey.duchyId == duchyCertificateKey.duchyId) { "The owners of the computation_participant and certificate don't match." } - return InternalSetParticipantRequisitionParamsRequest.newBuilder() - .apply { - externalComputationId = apiIdToExternalId(computationParticipantKey.computationId) - externalDuchyId = computationParticipantKey.duchyId - externalDuchyCertificateId = apiIdToExternalId(duchyCertificateKey.certificateId) - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (requisitionParams.protocolCase) { - ProtocolCase.LIQUID_LEGIONS_V2 -> { - val llv2 = requisitionParams.liquidLegionsV2 - liquidLegionsV2Builder.apply { - elGamalPublicKey = llv2.elGamalPublicKey - elGamalPublicKeySignature = llv2.elGamalPublicKeySignature - } + return internalSetParticipantRequisitionParamsRequest { + externalComputationId = apiIdToExternalId(computationParticipantKey.computationId) + externalDuchyId = computationParticipantKey.duchyId + externalDuchyCertificateId = apiIdToExternalId(duchyCertificateKey.certificateId) + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + when (requisitionParams.protocolCase) { + ProtocolCase.LIQUID_LEGIONS_V2 -> { + liquidLegionsV2 = liquidLegionsV2Details { + elGamalPublicKey = requisitionParams.liquidLegionsV2.elGamalPublicKey + elGamalPublicKeySignature = requisitionParams.liquidLegionsV2.elGamalPublicKeySignature } - ProtocolCase.PROTOCOL_NOT_SET -> - failGrpc { "protocol not set in the requisition_params." } } + ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> { + reachOnlyLiquidLegionsV2 = liquidLegionsV2Details { + elGamalPublicKey = requisitionParams.reachOnlyLiquidLegionsV2.elGamalPublicKey + elGamalPublicKeySignature = + requisitionParams.reachOnlyLiquidLegionsV2.elGamalPublicKeySignature + } + } + ProtocolCase.PROTOCOL_NOT_SET -> failGrpc { "protocol not set in the requisition_params." } } - .build() + } } private fun ConfirmComputationParticipantRequest.toInternalRequest(): diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ProtoConversions.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ProtoConversions.kt index 161c068a747..f400f2a46d7 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ProtoConversions.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ProtoConversions.kt @@ -39,6 +39,7 @@ import org.wfanet.measurement.system.v1alpha.ComputationLogEntry import org.wfanet.measurement.system.v1alpha.ComputationLogEntryKey import org.wfanet.measurement.system.v1alpha.ComputationParticipant import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt import org.wfanet.measurement.system.v1alpha.DifferentialPrivacyParams import org.wfanet.measurement.system.v1alpha.Requisition import org.wfanet.measurement.system.v1alpha.RequisitionKey @@ -109,10 +110,17 @@ fun InternalComputationParticipant.toSystemComputationParticipant(): Computation duchyCertificateDer = duchyCertificate.details.x509Der } if (details.hasLiquidLegionsV2()) { - liquidLegionsV2Builder.apply { - elGamalPublicKey = details.liquidLegionsV2.elGamalPublicKey - elGamalPublicKeySignature = details.liquidLegionsV2.elGamalPublicKeySignature - } + liquidLegionsV2 = + ComputationParticipantKt.RequisitionParamsKt.liquidLegionsV2 { + elGamalPublicKey = details.liquidLegionsV2.elGamalPublicKey + elGamalPublicKeySignature = details.liquidLegionsV2.elGamalPublicKeySignature + } + } else if (details.hasReachOnlyLiquidLegionsV2()) { + reachOnlyLiquidLegionsV2 = + ComputationParticipantKt.RequisitionParamsKt.liquidLegionsV2 { + elGamalPublicKey = details.reachOnlyLiquidLegionsV2.elGamalPublicKey + elGamalPublicKeySignature = details.reachOnlyLiquidLegionsV2.elGamalPublicKeySignature + } } } if (hasFailureLogEntry()) { @@ -250,27 +258,27 @@ fun buildMpcProtocolConfig( mpcProtocolConfig { reachOnlyLiquidLegionsV2 = liquidLegionsV2 { sketchParams = liquidLegionsSketchParams { - decayRate = protocolConfig.liquidLegionsV2.sketchParams.decayRate - maxSize = protocolConfig.liquidLegionsV2.sketchParams.maxSize + decayRate = protocolConfig.reachOnlyLiquidLegionsV2.sketchParams.decayRate + maxSize = protocolConfig.reachOnlyLiquidLegionsV2.sketchParams.maxSize } mpcNoise = mpcNoise { blindedHistogramNoise = - duchyProtocolConfig.liquidLegionsV2.mpcNoise.blindedHistogramNoise + duchyProtocolConfig.reachOnlyLiquidLegionsV2.mpcNoise.blindedHistogramNoise .toSystemDifferentialPrivacyParams() publisherNoise = - duchyProtocolConfig.liquidLegionsV2.mpcNoise.noiseForPublisherNoise + duchyProtocolConfig.reachOnlyLiquidLegionsV2.mpcNoise.noiseForPublisherNoise .toSystemDifferentialPrivacyParams() } - ellipticCurveId = protocolConfig.liquidLegionsV2.ellipticCurveId + ellipticCurveId = protocolConfig.reachOnlyLiquidLegionsV2.ellipticCurveId // Use `GEOMETRIC` for unspecified InternalNoiseMechanism for old Measurements. noiseMechanism = if ( - protocolConfig.liquidLegionsV2.noiseMechanism == + protocolConfig.reachOnlyLiquidLegionsV2.noiseMechanism == InternalProtocolConfig.NoiseMechanism.NOISE_MECHANISM_UNSPECIFIED ) { NoiseMechanism.GEOMETRIC } else { - protocolConfig.liquidLegionsV2.noiseMechanism.toSystemNoiseMechanism() + protocolConfig.reachOnlyLiquidLegionsV2.noiseMechanism.toSystemNoiseMechanism() } } } diff --git a/src/main/proto/wfa/measurement/internal/duchy/BUILD.bazel b/src/main/proto/wfa/measurement/internal/duchy/BUILD.bazel index cbb0c4a6f47..870dc9ed02d 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/BUILD.bazel +++ b/src/main/proto/wfa/measurement/internal/duchy/BUILD.bazel @@ -31,6 +31,7 @@ proto_library( strip_import_prefix = IMPORT_PREFIX, deps = [ "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_proto", ], ) @@ -52,6 +53,7 @@ proto_library( deps = [ ":crypto_proto", "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_proto", ], ) diff --git a/src/main/proto/wfa/measurement/internal/duchy/computation_details.proto b/src/main/proto/wfa/measurement/internal/duchy/computation_details.proto index 0b9ea0f032c..89397f13c36 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/computation_details.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/computation_details.proto @@ -18,6 +18,7 @@ package wfa.measurement.internal.duchy; import "wfa/measurement/internal/duchy/crypto.proto"; import "wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto"; +import "wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto"; option java_package = "org.wfanet.measurement.internal.duchy"; option java_multiple_files = true; @@ -62,6 +63,10 @@ message ComputationDetails { // Details specific to the liquidLegionV2 protocol. protocol.LiquidLegionsSketchAggregationV2.ComputationDetails liquid_legions_v2 = 4; + + // Details specific to the reachOnlyLiquidLegionV2 protocol. + protocol.ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails + reach_only_liquid_legions_v2 = 5; } } @@ -89,5 +94,9 @@ message ComputationStageDetails { // Details specific to the liquidLegionV2 protocol. protocol.LiquidLegionsSketchAggregationV2.StageDetails liquid_legions_v2 = 1; + + // Details specific to the reachOnlyLiquidLegionV2 protocol. + protocol.ReachOnlyLiquidLegionsSketchAggregationV2.StageDetails + reach_only_liquid_legions_v2 = 2; } } diff --git a/src/main/proto/wfa/measurement/internal/duchy/computation_protocols.proto b/src/main/proto/wfa/measurement/internal/duchy/computation_protocols.proto index 8076c460008..73c238f5433 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/computation_protocols.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/computation_protocols.proto @@ -17,6 +17,7 @@ syntax = "proto3"; package wfa.measurement.internal.duchy; import "wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto"; +import "wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto"; option java_package = "org.wfanet.measurement.internal.duchy"; option java_multiple_files = true; @@ -31,6 +32,11 @@ message ComputationStage { // Stage of a Liquid Legions sketch aggregation multi party computation. protocol.LiquidLegionsSketchAggregationV2.Stage liquid_legions_sketch_aggregation_v2 = 1; + + // Stage of Reach-Only Liquid Legions sketch aggregation multi party + // computation. + protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage + reach_only_liquid_legions_sketch_aggregation_v2 = 2; } } @@ -44,7 +50,10 @@ message ComputationTypeEnum { // Not set intentionally. UNSPECIFIED = 0; - // Aggregation of Liquid Legion sketches V2. + // Aggregation of Liquid Legions sketches V2. LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 = 1; + + // Aggregation of Reach-Only Liquid Legions V2. + REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 = 2; } } diff --git a/src/main/proto/wfa/measurement/internal/duchy/config/protocols_setup_config.proto b/src/main/proto/wfa/measurement/internal/duchy/config/protocols_setup_config.proto index f7e79668e5b..9cc3adc413b 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/config/protocols_setup_config.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/config/protocols_setup_config.proto @@ -24,9 +24,13 @@ option java_multiple_files = true; message ProtocolsSetupConfig { // LiquidLegionsV2 specific configuration LiquidLegionsV2SetupConfig liquid_legions_v2 = 1; + + // ReachOnlyLiquidLegionsV2 specific configuration + LiquidLegionsV2SetupConfig reach_only_liquid_legions_v2 = 2; } -// LiquidLegionsV2 specific configuration +// LiquidLegionsV2 protocols specific configuration. Also used for reach-only +// llv2. message LiquidLegionsV2SetupConfig { enum RoleInComputation { // Never set intentionally diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_v2_encryption_methods.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_v2_encryption_methods.proto index bb54ff78f69..ae53d7aa2d7 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_v2_encryption_methods.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_v2_encryption_methods.proto @@ -25,7 +25,7 @@ option java_package = "org.wfanet.measurement.internal.duchy.protocol"; option java_multiple_files = true; // Proto messages wrapping the input arguments or output results of the liquid -// legion v2 (three round) MPC protocol encryption methods, which are to be +// legions v2 (three round) MPC protocol encryption methods, which are to be // called via JNI in the Mill. Note that these protos contain sensitive data, // e.g., private keys used in the ciphers. So they SHOULD NOT be written into // any logs or leave the running process. diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_v2_noise_config.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_v2_noise_config.proto index c46bc0d5fac..9ef947f8d25 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_v2_noise_config.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_v2_noise_config.proto @@ -22,7 +22,7 @@ option java_package = "org.wfanet.measurement.internal.duchy.protocol"; option java_multiple_files = true; // Configuration for various noises added by the MPC workers in the -// LiquidLegionV2 protocol. +// LiquidLegionV2 protocols. Also used by reach-only protocol. message LiquidLegionsV2NoiseConfig { message ReachNoiseConfig { // DP params for the blind histogram noise register. @@ -44,6 +44,7 @@ message LiquidLegionsV2NoiseConfig { // Differential privacy parameters for noise tuples. // Same value is used for both (0, R, R) and (R, R, R) tuples. + // Ignored by reach-only protocol. DifferentialPrivacyParams frequency_noise_config = 2; // The mechanism used to generate noise in computations. diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto index 37ec2855620..0e237e30d7a 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto @@ -99,9 +99,10 @@ message ReachOnlyLiquidLegionsSketchAggregationV2 { // Parameters used in this computation. message Parameters { - // Parameters used for liquidLegions sketch creation and estimation. - LiquidLegionsSketchParameters liquid_legions_sketch = 1; - // Noise parameters selected for the LiquidLegionV2 MPC protocol. + // Parameters used for reachOnlyLiquidLegions sketch creation and + // estimation. + LiquidLegionsSketchParameters reach_only_liquid_legions_sketch = 1; + // Noise parameters selected for the ReachOnlyLiquidLegionV2 MPC protocol. LiquidLegionsV2NoiseConfig noise = 2; // ID of the OpenSSL built-in elliptic curve. For example, 415 for the // prime256v1 curve. Required. Immutable. diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto index 743baf1187e..93a5523ca32 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.proto @@ -26,10 +26,10 @@ option java_package = "org.wfanet.measurement.internal.duchy.protocol"; option java_multiple_files = true; // Proto messages wrapping the input arguments or output results of the reach -// only liquid legion v2 MPC protocol encryption methods, which are to be called -// via JNI in the Mill. Note that these protos contain sensitive data, e.g., -// private keys used in the ciphers. So they SHOULD NOT be written into any -// logs or leave the running process. +// only liquid legions v2 MPC protocol encryption methods, which are to be +// called via JNI in the Mill. Note that these protos contain sensitive data, +// e.g., private keys used in the ciphers. So they SHOULD NOT be written into +// any logs or leave the running process. // The request to complete work in the initialization phase. message CompleteReachOnlyInitializationPhaseRequest { diff --git a/src/main/proto/wfa/measurement/internal/kingdom/computation_participant.proto b/src/main/proto/wfa/measurement/internal/kingdom/computation_participant.proto index d486cd6077f..d5413ad311c 100644 --- a/src/main/proto/wfa/measurement/internal/kingdom/computation_participant.proto +++ b/src/main/proto/wfa/measurement/internal/kingdom/computation_participant.proto @@ -53,6 +53,7 @@ message ComputationParticipant { } State state = 6; + // Details of Liquid Legions V2 protocols. message LiquidLegionsV2Details { // Serialized `ElGamalPublicKey` message from public API. bytes el_gamal_public_key = 1; @@ -62,6 +63,8 @@ message ComputationParticipant { message Details { oneof protocol { LiquidLegionsV2Details liquid_legions_v2 = 1; + + LiquidLegionsV2Details reach_only_liquid_legions_v2 = 2; } } Details details = 7; diff --git a/src/main/proto/wfa/measurement/internal/kingdom/computation_participants_service.proto b/src/main/proto/wfa/measurement/internal/kingdom/computation_participants_service.proto index 92b636705ed..e40960556ec 100644 --- a/src/main/proto/wfa/measurement/internal/kingdom/computation_participants_service.proto +++ b/src/main/proto/wfa/measurement/internal/kingdom/computation_participants_service.proto @@ -48,6 +48,9 @@ message SetParticipantRequisitionParamsRequest { oneof protocol { ComputationParticipant.LiquidLegionsV2Details liquid_legions_v2 = 4; + + ComputationParticipant.LiquidLegionsV2Details reach_only_liquid_legions_v2 = + 5; } } diff --git a/src/main/proto/wfa/measurement/system/v1alpha/computation_control_service.proto b/src/main/proto/wfa/measurement/system/v1alpha/computation_control_service.proto index 1212910a8f1..0590e33cccb 100644 --- a/src/main/proto/wfa/measurement/system/v1alpha/computation_control_service.proto +++ b/src/main/proto/wfa/measurement/system/v1alpha/computation_control_service.proto @@ -48,6 +48,9 @@ message AdvanceComputationRequest { oneof protocol { // The LiquidLegionsV2 (three-round) protocol. LiquidLegionsV2 liquid_legions_v2 = 2; + + // The ReachOnlyLiquidLegionsV2 (one-round) protocol. + ReachOnlyLiquidLegionsV2 reach_only_liquid_legions_v2 = 3; } } @@ -87,6 +90,24 @@ message LiquidLegionsV2 { Description description = 1; } +// Parameters for the Reach-Only Liquid Legions v2 protocol. +// +// (-- api-linter: core::0123::resource-annotation=disabled +// aip.dev/not-precedent: This is not a resource message. --) +message ReachOnlyLiquidLegionsV2 { + // The description of the data in the payload. + enum Description { + // The data type is unknown. This is never set intentionally. + DESCRIPTION_UNSPECIFIED = 0; + // The input for the setup phase. + SETUP_PHASE_INPUT = 1; + // The input for the execution phase. + EXECUTION_PHASE_INPUT = 2; + } + // Payload data description + Description description = 1; +} + // Response message for the `AdvanceComputation` method. message AdvanceComputationResponse { // (-- Deliberately empty. --) diff --git a/src/main/proto/wfa/measurement/system/v1alpha/computation_participant.proto b/src/main/proto/wfa/measurement/system/v1alpha/computation_participant.proto index fa0435f899b..9610acc4552 100644 --- a/src/main/proto/wfa/measurement/system/v1alpha/computation_participant.proto +++ b/src/main/proto/wfa/measurement/system/v1alpha/computation_participant.proto @@ -82,7 +82,7 @@ message ComputationParticipant { // Duchy's root certificate. bytes duchy_certificate_der = 2 [(google.api.field_behavior) = OUTPUT_ONLY]; - // Parameters for the Liquid Legions v2 protocol. + // Parameters for the Liquid Legions v2 protocols. message LiquidLegionsV2 { // Serialized `ElGamalPublicKey` message from public API. bytes el_gamal_public_key = 1 [(google.api.field_behavior) = REQUIRED]; @@ -96,6 +96,9 @@ message ComputationParticipant { oneof protocol { // Requisition parameters for the Liquid Legions v2 protocol. LiquidLegionsV2 liquid_legions_v2 = 3; + + // Requisition parameters for the Reach-Only Liquid Legions v2 protocol. + LiquidLegionsV2 reach_only_liquid_legions_v2 = 4; } } // Parameters needed for `Requisition` to be made available in public API. diff --git a/src/main/swig/protocol/reachonlyliquidlegionsv2/BUILD.bazel b/src/main/swig/protocol/reachonlyliquidlegionsv2/BUILD.bazel new file mode 100644 index 00000000000..fd7a57bb8a8 --- /dev/null +++ b/src/main/swig/protocol/reachonlyliquidlegionsv2/BUILD.bazel @@ -0,0 +1,16 @@ +load("@wfa_rules_swig//java:defs.bzl", "java_wrap_cc") + +package(default_visibility = [ + "//src/main/kotlin/org/wfanet/measurement:__subpackages__", + "//src/test/kotlin/org/wfanet/measurement:__subpackages__", +]) + +java_wrap_cc( + name = "reach_only_liquid_legions_v2_encryption_utility", + src = "reach_only_liquid_legions_v2_encryption_utility.swig", + module = "ReachOnlyLiquidLegionsV2EncryptionUtility", + package = "org.wfanet.measurement.internal.duchy.protocol.reachonlyliquidlegionsv2", + deps = [ + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:reach_only_liquid_legions_v2_encryption_utility_wrapper", + ], +) diff --git a/src/main/swig/protocol/reachonlyliquidlegionsv2/README.md b/src/main/swig/protocol/reachonlyliquidlegionsv2/README.md new file mode 100644 index 00000000000..c3bfa55b45c --- /dev/null +++ b/src/main/swig/protocol/reachonlyliquidlegionsv2/README.md @@ -0,0 +1,22 @@ +# Liquid Legions V2 Encryption Utility Java Library + +The ReachOnlyLiquidLegionsV2EncryptionUtility java class is auto-generated from +the reach_only_liquid_legions_v2_encryption_utility.swig definition. The +implementation of the methods is written in c++ and the source codes are under +src/main/cc/measurement/crypto. We create a swig wrapper on the library and call +into the library via JNI in our java code. + +## Possible errors when using the JNI java library. + +### swig uninstalled + +To keep the library updated, each time when the java library is built, it would +run a swig command (provided in the BUILD.bazel rule) to re-generate all the +swig wrapper files using the latest c++ codes. As a result, the swig software is +required to build the java library. Install swig before building the package. + +For example, in a system using apt-get, run the following command to get swig: + +```shell +sudo apt-get install swig +``` diff --git a/src/main/swig/protocol/reachonlyliquidlegionsv2/reach_only_liquid_legions_v2_encryption_utility.swig b/src/main/swig/protocol/reachonlyliquidlegionsv2/reach_only_liquid_legions_v2_encryption_utility.swig new file mode 100644 index 00000000000..0f793c38ff2 --- /dev/null +++ b/src/main/swig/protocol/reachonlyliquidlegionsv2/reach_only_liquid_legions_v2_encryption_utility.swig @@ -0,0 +1,67 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +%include "exception.i" +%{ +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h" +%} + +// Convert C++ ::absl::StatusOr to a Java byte +// array. +%typemap(jni) ::absl::StatusOr "jbyteArray" +%typemap(jtype) ::absl::StatusOr "byte[]" +%typemap(jstype) ::absl::StatusOr "byte[]" +%typemap(out) ::absl::StatusOr { + if ($1.ok()) { + $result = JCALL1(NewByteArray, jenv, $1.value().length()); + JCALL4(SetByteArrayRegion, jenv, $result, 0, + $1.value().length(), (const jbyte*) $1.value().data()); + } else { + SWIG_exception(SWIG_RuntimeError, $1.status().ToString().c_str()); + } +} +%typemap(javaout) ::absl::StatusOr { + return $jnicall; +} + +// Convert Java byte array to C++ const std::string& for any const std::string& +// input parameter. +%typemap(jni) const std::string& "jbyteArray" +%typemap(jtype) const std::string& "byte[]" +%typemap(jstype) const std::string& "byte[]" +%typemap(javain) const std::string& "$javainput" +%typemap(in) const std::string& { + jsize temp_length = JCALL1(GetArrayLength, jenv, $input); + jbyte* temp_bytes = JCALL2(GetByteArrayElements, jenv, $input, 0); + $1 = new std::string((const char*) temp_bytes, temp_length); + JCALL3(ReleaseByteArrayElements, jenv, $input, temp_bytes, JNI_ABORT); +} +// Convert C++ std::string to a Java byte array. +%clear std::string; +%typemap(jni) std::string "jbyteArray" +%typemap(jtype) std::string "byte[]" +%typemap(jstype) std::string "byte[]" +%typemap(out) std::string { + $result = JCALL1(NewByteArray, jenv, $1.length()); + JCALL4(SetByteArrayRegion, jenv, $result, 0, + $1.length(), (const jbyte*) $1.data()); +} +%typemap(javaout) std::string { + return $jnicall; +} + +// Follow Java style for function names. +%rename("%(lowercamelcase)s",%$isfunction) ""; + +%include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h" diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt index 1e60c457138..1a30c0141f2 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt @@ -49,6 +49,7 @@ import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reach import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reachAndFrequency import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams as cmmsDifferentialPrivacyParams +import org.wfanet.measurement.api.v2alpha.elGamalPublicKey import org.wfanet.measurement.api.v2alpha.encryptionPublicKey import org.wfanet.measurement.api.v2alpha.measurementSpec import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule @@ -67,7 +68,6 @@ import org.wfanet.measurement.duchy.service.internal.computations.newPassThrough import org.wfanet.measurement.duchy.storage.ComputationStore import org.wfanet.measurement.duchy.storage.RequisitionStore import org.wfanet.measurement.duchy.toProtocolStage -import org.wfanet.measurement.internal.duchy.ComputationDetails import org.wfanet.measurement.internal.duchy.ComputationDetailsKt.kingdomComputationDetails import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoroutineImplBase as InternalComputationsCoroutineImplBase import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoroutineStub as InternalComputationsCoroutineStub @@ -79,27 +79,26 @@ import org.wfanet.measurement.internal.duchy.GetComputationTokenRequest import org.wfanet.measurement.internal.duchy.computationDetails import org.wfanet.measurement.internal.duchy.computationToken import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation -import org.wfanet.measurement.internal.duchy.config.ProtocolsSetupConfig +import org.wfanet.measurement.internal.duchy.config.liquidLegionsV2SetupConfig +import org.wfanet.measurement.internal.duchy.config.protocolsSetupConfig import org.wfanet.measurement.internal.duchy.deleteComputationRequest import org.wfanet.measurement.internal.duchy.differentialPrivacyParams as duchyDifferentialPrivacyParams +import org.wfanet.measurement.internal.duchy.elGamalPublicKey as internalElgamalPublicKey import org.wfanet.measurement.internal.duchy.getComputationTokenResponse import org.wfanet.measurement.internal.duchy.getContinuationTokenResponse -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2Kt -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfig import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfigKt.reachNoiseConfig +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsV2NoiseConfig import org.wfanet.measurement.internal.duchy.setContinuationTokenRequest import org.wfanet.measurement.storage.StorageClient import org.wfanet.measurement.storage.testing.InMemoryStorageClient import org.wfanet.measurement.system.v1alpha.Computation +import org.wfanet.measurement.system.v1alpha.Computation.MpcProtocolConfig import org.wfanet.measurement.system.v1alpha.Computation.MpcProtocolConfig.NoiseMechanism as SystemNoiseMechanism import org.wfanet.measurement.system.v1alpha.ComputationKey import org.wfanet.measurement.system.v1alpha.ComputationKt.MpcProtocolConfigKt.LiquidLegionsV2Kt.liquidLegionsSketchParams @@ -112,6 +111,8 @@ import org.wfanet.measurement.system.v1alpha.ComputationLogEntry import org.wfanet.measurement.system.v1alpha.ComputationParticipant as SystemComputationParticipant import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.RequisitionParamsKt +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.requisitionParams import org.wfanet.measurement.system.v1alpha.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineImplBase as SystemComputationParticipantsCoroutineImplBase import org.wfanet.measurement.system.v1alpha.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineStub as SystemComputationParticipantsCoroutineStub import org.wfanet.measurement.system.v1alpha.ComputationsGrpcKt.ComputationsCoroutineImplBase as SystemComputationsCoroutineImplBase @@ -120,6 +121,7 @@ import org.wfanet.measurement.system.v1alpha.FailComputationParticipantRequest import org.wfanet.measurement.system.v1alpha.Requisition import org.wfanet.measurement.system.v1alpha.StreamActiveComputationsResponse import org.wfanet.measurement.system.v1alpha.computation +import org.wfanet.measurement.system.v1alpha.computationParticipant as systemComputationParticipant import org.wfanet.measurement.system.v1alpha.computationParticipant import org.wfanet.measurement.system.v1alpha.copy import org.wfanet.measurement.system.v1alpha.differentialPrivacyParams as systemDifferentialPrivacyParams @@ -176,7 +178,7 @@ private val PUBLIC_API_REACH_ONLY_MEASUREMENT_SPEC = measurementSpec { private val SERIALIZED_REACH_ONLY_MEASUREMENT_SPEC: ByteString = PUBLIC_API_REACH_ONLY_MEASUREMENT_SPEC.toByteString() -private val MPC_PROTOCOL_CONFIG = mpcProtocolConfig { +private val LLV2_MPC_PROTOCOL_CONFIG = mpcProtocolConfig { liquidLegionsV2 = liquidLegionsV2 { sketchParams = liquidLegionsSketchParams { decayRate = 12.0 @@ -198,38 +200,79 @@ private val MPC_PROTOCOL_CONFIG = mpcProtocolConfig { } } +private val RO_LLV2_MPC_PROTOCOL_CONFIG = mpcProtocolConfig { + reachOnlyLiquidLegionsV2 = liquidLegionsV2 { + sketchParams = liquidLegionsSketchParams { + decayRate = 12.0 + maxSize = 100_000 + } + mpcNoise = mpcNoise { + blindedHistogramNoise = systemDifferentialPrivacyParams { + epsilon = 3.1 + delta = 3.2 + } + publisherNoise = systemDifferentialPrivacyParams { + epsilon = 4.1 + delta = 4.2 + } + } + ellipticCurveId = 415 + noiseMechanism = SystemNoiseMechanism.GEOMETRIC + } +} + private const val AGGREGATOR_DUCHY_ID = "aggregator_duchy" private const val AGGREGATOR_HERALD_ID = "aggregator_herald" private const val NON_AGGREGATOR_DUCHY_ID = "worker_duchy" private const val NON_AGGREGATOR_HERALD_ID = "worker_herald" -private val AGGREGATOR_PROTOCOLS_SETUP_CONFIG = - ProtocolsSetupConfig.newBuilder() - .apply { - liquidLegionsV2Builder.apply { - role = RoleInComputation.AGGREGATOR - externalAggregatorDuchyId = DUCHY_ONE - } +private val AGGREGATOR_PROTOCOLS_SETUP_CONFIG = protocolsSetupConfig { + liquidLegionsV2 = liquidLegionsV2SetupConfig { + role = RoleInComputation.AGGREGATOR + externalAggregatorDuchyId = DUCHY_ONE + } + reachOnlyLiquidLegionsV2 = liquidLegionsV2SetupConfig { + role = RoleInComputation.AGGREGATOR + externalAggregatorDuchyId = DUCHY_ONE + } +} + +private val NON_AGGREGATOR_PROTOCOLS_SETUP_CONFIG = protocolsSetupConfig { + liquidLegionsV2 = liquidLegionsV2SetupConfig { + role = RoleInComputation.NON_AGGREGATOR + externalAggregatorDuchyId = DUCHY_ONE + } + reachOnlyLiquidLegionsV2 = liquidLegionsV2SetupConfig { + role = RoleInComputation.NON_AGGREGATOR + externalAggregatorDuchyId = DUCHY_ONE + } +} + +private val LLV2_AGGREGATOR_COMPUTATION_DETAILS = computationDetails { + liquidLegionsV2 = + LiquidLegionsSketchAggregationV2Kt.computationDetails { role = RoleInComputation.AGGREGATOR } +} + +private val LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS = computationDetails { + liquidLegionsV2 = + LiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR } - .build() -private val NON_AGGREGATOR_PROTOCOLS_SETUP_CONFIG = - ProtocolsSetupConfig.newBuilder() - .apply { - liquidLegionsV2Builder.apply { - role = RoleInComputation.NON_AGGREGATOR - externalAggregatorDuchyId = DUCHY_ONE - } +} + +private val RO_LLV2_AGGREGATOR_COMPUTATION_DETAILS = computationDetails { + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.AGGREGATOR } - .build() +} -private val AGGREGATOR_COMPUTATION_DETAILS = - ComputationDetails.newBuilder() - .apply { liquidLegionsV2Builder.apply { role = RoleInComputation.AGGREGATOR } } - .build() -private val NON_AGGREGATOR_COMPUTATION_DETAILS = - ComputationDetails.newBuilder() - .apply { liquidLegionsV2Builder.apply { role = RoleInComputation.NON_AGGREGATOR } } - .build() +private val RO_LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS = computationDetails { + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + } +} private const val COMPUTATION_GLOBAL_ID = "123" @@ -358,7 +401,7 @@ class HeraldTest { } @Test - fun `syncStatuses creates new computations`() = runTest { + fun `syncStatuses creates a new llv2 computations`() = runTest { val confirmingKnown = buildComputationAtKingdom("1", Computation.State.PENDING_REQUISITION_PARAMS) @@ -376,8 +419,8 @@ class HeraldTest { fakeComputationDatabase.addComputation( globalId = confirmingKnown.key.computationId, - stage = INITIALIZATION_PHASE.toProtocolStage(), - computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + stage = LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), + computationDetails = LLV2_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newInputBlobMetadata(0L, "input-blob"), newEmptyOutputBlobMetadata(1L)) ) @@ -393,9 +436,9 @@ class HeraldTest { ) .containsExactly( confirmingKnown.key.computationId.toLong(), - INITIALIZATION_PHASE.toProtocolStage(), + LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), confirmingUnknown.key.computationId.toLong(), - INITIALIZATION_PHASE.toProtocolStage() + LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage() ) assertThat( @@ -419,42 +462,43 @@ class HeraldTest { liquidLegionsV2 = LiquidLegionsSketchAggregationV2Kt.computationDetails { role = RoleInComputation.AGGREGATOR - parameters = parameters { - maximumFrequency = 10 - liquidLegionsSketch = liquidLegionsSketchParameters { - decayRate = 12.0 - size = 100_000L - } - noise = liquidLegionsV2NoiseConfig { - reachNoiseConfig = reachNoiseConfig { - blindHistogramNoise = duchyDifferentialPrivacyParams { - epsilon = 3.1 - delta = 3.2 - } - noiseForPublisherNoise = duchyDifferentialPrivacyParams { - epsilon = 4.1 - delta = 4.2 + parameters = + LiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters { + maximumFrequency = 10 + liquidLegionsSketch = liquidLegionsSketchParameters { + decayRate = 12.0 + size = 100_000L + } + noise = liquidLegionsV2NoiseConfig { + reachNoiseConfig = reachNoiseConfig { + blindHistogramNoise = duchyDifferentialPrivacyParams { + epsilon = 3.1 + delta = 3.2 + } + noiseForPublisherNoise = duchyDifferentialPrivacyParams { + epsilon = 4.1 + delta = 4.2 + } + globalReachDpNoise = duchyDifferentialPrivacyParams { + epsilon = 1.1 + delta = 1.2 + } } - globalReachDpNoise = duchyDifferentialPrivacyParams { - epsilon = 1.1 - delta = 1.2 + frequencyNoiseConfig = duchyDifferentialPrivacyParams { + epsilon = 2.1 + delta = 2.2 } + noiseMechanism = LiquidLegionsV2NoiseConfig.NoiseMechanism.GEOMETRIC } - frequencyNoiseConfig = duchyDifferentialPrivacyParams { - epsilon = 2.1 - delta = 2.2 - } - noiseMechanism = LiquidLegionsV2NoiseConfig.NoiseMechanism.GEOMETRIC + ellipticCurveId = 415 } - ellipticCurveId = 415 - } } } ) } @Test - fun `syncStatuses creates new computations for reach-only`() = runTest { + fun `syncStatuses creates new a llv2 computations for reach-only`() = runTest { val confirmingKnown = buildComputationAtKingdom( "1", @@ -477,8 +521,8 @@ class HeraldTest { fakeComputationDatabase.addComputation( globalId = confirmingKnown.key.computationId, - stage = INITIALIZATION_PHASE.toProtocolStage(), - computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + stage = LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), + computationDetails = LLV2_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newInputBlobMetadata(0L, "input-blob"), newEmptyOutputBlobMetadata(1L)) ) @@ -494,9 +538,9 @@ class HeraldTest { ) .containsExactly( confirmingKnown.key.computationId.toLong(), - INITIALIZATION_PHASE.toProtocolStage(), + LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), confirmingUnknown.key.computationId.toLong(), - INITIALIZATION_PHASE.toProtocolStage() + LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage() ) assertThat( @@ -520,38 +564,139 @@ class HeraldTest { liquidLegionsV2 = LiquidLegionsSketchAggregationV2Kt.computationDetails { role = RoleInComputation.AGGREGATOR - parameters = parameters { - maximumFrequency = 10 - liquidLegionsSketch = liquidLegionsSketchParameters { - decayRate = 12.0 - size = 100_000L - } - noise = liquidLegionsV2NoiseConfig { - noiseMechanism = LiquidLegionsV2NoiseConfig.NoiseMechanism.GEOMETRIC - reachNoiseConfig = reachNoiseConfig { - blindHistogramNoise = duchyDifferentialPrivacyParams { - epsilon = 3.1 - delta = 3.2 - } - noiseForPublisherNoise = duchyDifferentialPrivacyParams { - epsilon = 4.1 - delta = 4.2 + parameters = + LiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters { + maximumFrequency = 10 + liquidLegionsSketch = liquidLegionsSketchParameters { + decayRate = 12.0 + size = 100_000L + } + noise = liquidLegionsV2NoiseConfig { + noiseMechanism = LiquidLegionsV2NoiseConfig.NoiseMechanism.GEOMETRIC + reachNoiseConfig = reachNoiseConfig { + blindHistogramNoise = duchyDifferentialPrivacyParams { + epsilon = 3.1 + delta = 3.2 + } + noiseForPublisherNoise = duchyDifferentialPrivacyParams { + epsilon = 4.1 + delta = 4.2 + } + globalReachDpNoise = duchyDifferentialPrivacyParams { + epsilon = 1.1 + delta = 1.2 + } } - globalReachDpNoise = duchyDifferentialPrivacyParams { - epsilon = 1.1 - delta = 1.2 + } + ellipticCurveId = 415 + } + } + } + ) + } + + @Test + fun `syncStatuses creates new a rollv2 computations for reach-only`() = runTest { + val confirmingKnown = + buildComputationAtKingdom( + "1", + Computation.State.PENDING_REQUISITION_PARAMS, + serializedMeasurementSpec = SERIALIZED_REACH_ONLY_MEASUREMENT_SPEC, + mpcProtocolConfig = RO_LLV2_MPC_PROTOCOL_CONFIG + ) + + val systemApiRequisitions1 = + REACH_ONLY_REQUISITION_1.toSystemRequisition("2", Requisition.State.UNFULFILLED) + val systemApiRequisitions2 = + REACH_ONLY_REQUISITION_2.toSystemRequisition("2", Requisition.State.UNFULFILLED) + val confirmingUnknown = + buildComputationAtKingdom( + "2", + Computation.State.PENDING_REQUISITION_PARAMS, + listOf(systemApiRequisitions1, systemApiRequisitions2), + serializedMeasurementSpec = SERIALIZED_REACH_ONLY_MEASUREMENT_SPEC, + mpcProtocolConfig = RO_LLV2_MPC_PROTOCOL_CONFIG + ) + mockStreamActiveComputationsToReturn(confirmingKnown, confirmingUnknown) + + fakeComputationDatabase.addComputation( + globalId = confirmingKnown.key.computationId, + stage = + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), + computationDetails = RO_LLV2_AGGREGATOR_COMPUTATION_DETAILS, + blobs = listOf(newInputBlobMetadata(0L, "input-blob"), newEmptyOutputBlobMetadata(1L)) + ) + + aggregatorHerald.syncStatuses() + + verifyBlocking(continuationTokensService, atLeastOnce()) { + setContinuationToken(eq(setContinuationTokenRequest { this.token = "2" })) + } + assertThat( + fakeComputationDatabase.mapValues { (_, fakeComputation) -> + fakeComputation.computationStage + } + ) + .containsExactly( + confirmingKnown.key.computationId.toLong(), + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), + confirmingUnknown.key.computationId.toLong(), + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage() + ) + + assertThat( + fakeComputationDatabase[confirmingUnknown.key.computationId.toLong()]?.requisitionsList + ) + .containsExactly( + REACH_ONLY_REQUISITION_1.toRequisitionMetadata(Requisition.State.UNFULFILLED), + REACH_ONLY_REQUISITION_2.toRequisitionMetadata(Requisition.State.UNFULFILLED) + ) + assertThat( + fakeComputationDatabase[confirmingUnknown.key.computationId.toLong()]?.computationDetails + ) + .isEqualTo( + computationDetails { + blobsStoragePrefix = "computation-blob-storage/2" + kingdomComputation = kingdomComputationDetails { + publicApiVersion = PUBLIC_API_VERSION + measurementSpec = SERIALIZED_REACH_ONLY_MEASUREMENT_SPEC + measurementPublicKey = PUBLIC_API_ENCRYPTION_PUBLIC_KEY.toDuchyEncryptionPublicKey() + } + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.AGGREGATOR + parameters = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters { + reachOnlyLiquidLegionsSketch = liquidLegionsSketchParameters { + decayRate = 12.0 + size = 100_000L + } + noise = liquidLegionsV2NoiseConfig { + noiseMechanism = LiquidLegionsV2NoiseConfig.NoiseMechanism.GEOMETRIC + reachNoiseConfig = reachNoiseConfig { + blindHistogramNoise = duchyDifferentialPrivacyParams { + epsilon = 3.1 + delta = 3.2 + } + noiseForPublisherNoise = duchyDifferentialPrivacyParams { + epsilon = 4.1 + delta = 4.2 + } + globalReachDpNoise = duchyDifferentialPrivacyParams { + epsilon = 1.1 + delta = 1.2 + } } } + ellipticCurveId = 415 } - ellipticCurveId = 415 - } } } ) } @Test - fun `syncStatuses update llv2 computations in WAIT_REQUISITIONS_AND_KEY_SET`() = runTest { + fun `syncStatuses confirms participants for llv2 computations`() = runTest { val globalId = "123456" val systemApiRequisitions1 = REQUISITION_1.toSystemRequisition(globalId, Requisition.State.FULFILLED, DUCHY_ONE) @@ -636,8 +781,9 @@ class HeraldTest { fakeComputationDatabase.addComputation( globalId = globalId, - stage = WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage(), - computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + stage = + LiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage(), + computationDetails = LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, requisitions = listOf( REQUISITION_1.toRequisitionMetadata(Requisition.State.UNFULFILLED), @@ -659,11 +805,11 @@ class HeraldTest { val duchyComputationToken = fakeComputationDatabase.readComputationToken(globalId)!! assertThat(duchyComputationToken.computationStage) - .isEqualTo(CONFIRMATION_PHASE.toProtocolStage()) + .isEqualTo(LiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE.toProtocolStage()) assertThat(duchyComputationToken.computationDetails.liquidLegionsV2.participantList) .isEqualTo( mutableListOf( - ComputationParticipant.newBuilder() + LiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant.newBuilder() .apply { duchyId = DUCHY_THREE publicKeyBuilder.apply { @@ -675,7 +821,7 @@ class HeraldTest { duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_3") } .build(), - ComputationParticipant.newBuilder() + LiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant.newBuilder() .apply { duchyId = DUCHY_TWO publicKeyBuilder.apply { @@ -687,7 +833,7 @@ class HeraldTest { duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_2") } .build(), - ComputationParticipant.newBuilder() + LiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant.newBuilder() .apply { duchyId = DUCHY_ONE publicKeyBuilder.apply { @@ -709,7 +855,188 @@ class HeraldTest { } @Test - fun `syncStatuses starts computations in wait_to_start`() = runTest { + fun `syncStatuses confirms participants for rollv2 computations`() = runTest { + val globalId = "123456" + val systemApiRequisitions1 = + REQUISITION_1.toSystemRequisition(globalId, Requisition.State.FULFILLED, DUCHY_ONE) + val systemApiRequisitions2 = + REQUISITION_2.toSystemRequisition(globalId, Requisition.State.FULFILLED, DUCHY_TWO) + val v2alphaApiElgamalPublicKey1 = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_1") + element = ByteString.copyFromUtf8("element_1") + } + val v2alphaApiElgamalPublicKey2 = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_2") + element = ByteString.copyFromUtf8("element_2") + } + val v2alphaApiElgamalPublicKey3 = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_3") + element = ByteString.copyFromUtf8("element_3") + } + val systemComputationParticipant1 = systemComputationParticipant { + name = ComputationParticipantKey(globalId, DUCHY_ONE).toName() + requisitionParams = requisitionParams { + duchyCertificate = "duchyCertificate_1" + duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_1") + reachOnlyLiquidLegionsV2 = + RequisitionParamsKt.liquidLegionsV2 { + elGamalPublicKey = v2alphaApiElgamalPublicKey1.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("elGamalPublicKeySignature_1") + } + } + } + val systemComputationParticipant2 = systemComputationParticipant { + name = ComputationParticipantKey(globalId, DUCHY_TWO).toName() + requisitionParams = requisitionParams { + duchyCertificate = "duchyCertificate_2" + duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_2") + reachOnlyLiquidLegionsV2 = + RequisitionParamsKt.liquidLegionsV2 { + elGamalPublicKey = v2alphaApiElgamalPublicKey2.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("elGamalPublicKeySignature_2") + } + } + } + val systemComputationParticipant3 = systemComputationParticipant { + name = ComputationParticipantKey(globalId, DUCHY_THREE).toName() + requisitionParams = requisitionParams { + duchyCertificate = "duchyCertificate_3" + duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_3") + reachOnlyLiquidLegionsV2 = + RequisitionParamsKt.liquidLegionsV2 { + elGamalPublicKey = v2alphaApiElgamalPublicKey3.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("elGamalPublicKeySignature_3") + } + } + } + val waitingRequisitionsAndKeySet = + buildComputationAtKingdom( + globalId, + Computation.State.PENDING_PARTICIPANT_CONFIRMATION, + listOf(systemApiRequisitions1, systemApiRequisitions2), + listOf( + systemComputationParticipant1, + systemComputationParticipant2, + systemComputationParticipant3 + ) + ) + + mockStreamActiveComputationsToReturn(waitingRequisitionsAndKeySet) + + fakeComputationDatabase.addComputation( + globalId = globalId, + stage = + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET + .toProtocolStage(), + computationDetails = RO_LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, + requisitions = + listOf( + REQUISITION_1.toRequisitionMetadata(Requisition.State.UNFULFILLED), + REQUISITION_2.toRequisitionMetadata(Requisition.State.UNFULFILLED) + ) + ) + + aggregatorHerald.syncStatuses() + + verifyBlocking(continuationTokensService, atLeastOnce()) { + setContinuationToken( + eq( + setContinuationTokenRequest { + this.token = waitingRequisitionsAndKeySet.continuationToken() + } + ) + ) + } + + val duchyComputationToken = fakeComputationDatabase.readComputationToken(globalId)!! + assertThat(duchyComputationToken.computationStage) + .isEqualTo( + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE.toProtocolStage() + ) + assertThat(duchyComputationToken.computationDetails.reachOnlyLiquidLegionsV2.participantList) + .isEqualTo( + listOf( + ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.computationParticipant { + duchyId = DUCHY_THREE + publicKey = internalElgamalPublicKey { + generator = ByteString.copyFromUtf8("generator_3") + element = ByteString.copyFromUtf8("element_3") + } + elGamalPublicKey = v2alphaApiElgamalPublicKey3.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("elGamalPublicKeySignature_3") + duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_3") + }, + ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.computationParticipant { + duchyId = DUCHY_TWO + publicKey = internalElgamalPublicKey { + generator = ByteString.copyFromUtf8("generator_2") + element = ByteString.copyFromUtf8("element_2") + } + elGamalPublicKey = v2alphaApiElgamalPublicKey2.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("elGamalPublicKeySignature_2") + duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_2") + }, + ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.computationParticipant { + duchyId = DUCHY_ONE + publicKey = internalElgamalPublicKey { + generator = ByteString.copyFromUtf8("generator_1") + element = ByteString.copyFromUtf8("element_1") + } + elGamalPublicKey = v2alphaApiElgamalPublicKey1.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("elGamalPublicKeySignature_1") + duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_1") + } + ) + ) + assertThat(duchyComputationToken.requisitionsList) + .containsExactly( + REQUISITION_1.toRequisitionMetadata(Requisition.State.FULFILLED, DUCHY_ONE), + REQUISITION_2.toRequisitionMetadata(Requisition.State.FULFILLED, DUCHY_TWO) + ) + } + + @Test + fun `syncStatuses starts llv2 computations`() = runTest { + val waitingToStart = + buildComputationAtKingdom(COMPUTATION_GLOBAL_ID, Computation.State.PENDING_COMPUTATION) + val addingNoise = buildComputationAtKingdom("231313", Computation.State.PENDING_COMPUTATION) + mockStreamActiveComputationsToReturn(waitingToStart, addingNoise) + + fakeComputationDatabase.addComputation( + globalId = waitingToStart.key.computationId, + stage = LiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START.toProtocolStage(), + computationDetails = LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, + blobs = listOf(newPassThroughBlobMetadata(0L, "local-copy-of-sketches")) + ) + + fakeComputationDatabase.addComputation( + globalId = addingNoise.key.computationId, + stage = LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage(), + computationDetails = LLV2_AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf(newInputBlobMetadata(0L, "inputs-to-add-noise"), newEmptyOutputBlobMetadata(1L)) + ) + + aggregatorHerald.syncStatuses() + + verifyBlocking(continuationTokensService, atLeastOnce()) { + setContinuationToken(eq(setContinuationTokenRequest { this.token = "231313" })) + } + assertThat( + fakeComputationDatabase.mapValues { (_, fakeComputation) -> + fakeComputation.computationStage + } + ) + .containsExactly( + waitingToStart.key.computationId.toLong(), + LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage(), + addingNoise.key.computationId.toLong(), + LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage() + ) + } + + @Test + fun `syncStatuses starts rollv2 computations`() = runTest { val waitingToStart = buildComputationAtKingdom(COMPUTATION_GLOBAL_ID, Computation.State.PENDING_COMPUTATION) val addingNoise = buildComputationAtKingdom("231313", Computation.State.PENDING_COMPUTATION) @@ -717,15 +1044,15 @@ class HeraldTest { fakeComputationDatabase.addComputation( globalId = waitingToStart.key.computationId, - stage = WAIT_TO_START.toProtocolStage(), - computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + stage = ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START.toProtocolStage(), + computationDetails = RO_LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newPassThroughBlobMetadata(0L, "local-copy-of-sketches")) ) fakeComputationDatabase.addComputation( globalId = addingNoise.key.computationId, - stage = SETUP_PHASE.toProtocolStage(), - computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + stage = ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage(), + computationDetails = RO_LLV2_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newInputBlobMetadata(0L, "inputs-to-add-noise"), newEmptyOutputBlobMetadata(1L)) ) @@ -742,9 +1069,9 @@ class HeraldTest { ) .containsExactly( waitingToStart.key.computationId.toLong(), - SETUP_PHASE.toProtocolStage(), + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage(), addingNoise.key.computationId.toLong(), - SETUP_PHASE.toProtocolStage() + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage() ) } @@ -769,8 +1096,8 @@ class HeraldTest { } fakeComputationDatabase.addComputation( globalId = computation.key.computationId, - stage = INITIALIZATION_PHASE.toProtocolStage(), - computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + stage = LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), + computationDetails = LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newInputBlobMetadata(0L, "local-copy-of-sketches")) ) @@ -786,22 +1113,23 @@ class HeraldTest { ) .containsExactly( computation.key.computationId.toLong(), - INITIALIZATION_PHASE.toProtocolStage() + LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage() ) // Update the state. fakeComputationDatabase.remove(computation.key.computationId.toLong()) fakeComputationDatabase.addComputation( globalId = computation.key.computationId, - stage = WAIT_TO_START.toProtocolStage(), - computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + stage = LiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START.toProtocolStage(), + computationDetails = LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newPassThroughBlobMetadata(0L, "local-copy-of-sketches")) ) // Verify that next attempt succeeds. syncResult.await() val finalComputation = assertNotNull(fakeComputationDatabase[computation.key.computationId.toLong()]) - assertThat(finalComputation.computationStage).isEqualTo(SETUP_PHASE.toProtocolStage()) + assertThat(finalComputation.computationStage) + .isEqualTo(LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage()) } @Test @@ -825,8 +1153,8 @@ class HeraldTest { fakeComputationDatabase.addComputation( globalId = computation.key.computationId, - stage = INITIALIZATION_PHASE.toProtocolStage(), - computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + stage = LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), + computationDetails = LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newInputBlobMetadata(0L, "local-copy-of-sketches")) ) @@ -952,7 +1280,7 @@ class HeraldTest { token = computationToken { globalComputationId = request.globalComputationId localComputationId = request.globalComputationId.toLong() - computationDetails = AGGREGATOR_COMPUTATION_DETAILS + computationDetails = LLV2_AGGREGATOR_COMPUTATION_DETAILS } } } @@ -1015,6 +1343,7 @@ class HeraldTest { systemApiRequisitions: List = listOf(), systemComputationParticipant: List = listOf(), serializedMeasurementSpec: ByteString = SERIALIZED_MEASUREMENT_SPEC, + mpcProtocolConfig: MpcProtocolConfig = LLV2_MPC_PROTOCOL_CONFIG ): Computation { return computation { name = ComputationKey(globalId).toName() @@ -1023,7 +1352,7 @@ class HeraldTest { state = stateAtKingdom requisitions += systemApiRequisitions computationParticipants += systemComputationParticipant - mpcProtocolConfig = MPC_PROTOCOL_CONFIG + this.mpcProtocolConfig = mpcProtocolConfig } } diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel index 1ba87e7a446..9a59539dc81 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel @@ -29,3 +29,33 @@ kt_jvm_test( "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/measurementconsumer", ], ) + +kt_jvm_test( + name = "ReachOnlyLiquidLegionsV2MillTest", + srcs = ["ReachOnlyLiquidLegionsV2MillTest.kt"], + test_class = "org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto.ReachOnlyLiquidLegionsV2MillTest", + deps = [ + "//imports/java/io/opentelemetry/api", + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2:reach_only_liquid_legions_v2_mill", + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/testing", + "//src/main/kotlin/org/wfanet/measurement/duchy/db/computation", + "//src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing", + "//src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations", + "//src/main/kotlin/org/wfanet/measurement/duchy/storage:computation_store", + "//src/main/kotlin/org/wfanet/measurement/duchy/storage:requisition_store", + "//src/main/kotlin/org/wfanet/measurement/system/v1alpha:resource_key", + "//src/main/proto/wfa/measurement/api/v2alpha:crypto_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_kt_jvm_proto", + "//src/main/swig/protocol/reachonlyliquidlegionsv2:reach_only_liquid_legions_v2_encryption_utility", + "@wfa_common_jvm//imports/kotlin/com/google/protobuf/kotlin", + "@wfa_common_jvm//imports/kotlin/kotlin/test", + "@wfa_common_jvm//imports/kotlin/org/mockito/kotlin", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto:hashing", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/testing", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/storage/filesystem:client", + "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/common:key_handles", + "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/measurementconsumer", + ], +) diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt new file mode 100644 index 00000000000..c8065eec7a1 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt @@ -0,0 +1,1827 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto + +import com.google.common.truth.Truth.assertThat +import com.google.common.truth.extensions.proto.ProtoTruth.assertThat +import com.google.protobuf.ByteString +import com.google.protobuf.kotlin.toByteString +import io.grpc.Status +import io.opentelemetry.api.GlobalOpenTelemetry +import java.security.cert.X509Certificate +import java.time.Clock +import java.time.Duration +import java.util.Base64 +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import org.junit.rules.TemporaryFolder +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.mockito.kotlin.UseConstructor +import org.mockito.kotlin.any +import org.mockito.kotlin.mock +import org.mockito.kotlin.whenever +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysResponse +import org.wfanet.measurement.api.v2alpha.ElGamalPublicKey as V2AlphaElGamalPublicKey +import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reach +import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.vidSamplingInterval +import org.wfanet.measurement.api.v2alpha.measurementSpec +import org.wfanet.measurement.common.crypto.SigningKeyHandle +import org.wfanet.measurement.common.crypto.readCertificate +import org.wfanet.measurement.common.crypto.readPrivateKey +import org.wfanet.measurement.common.crypto.testing.TestData +import org.wfanet.measurement.common.crypto.tink.TinkPrivateKeyHandle +import org.wfanet.measurement.common.flatten +import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule +import org.wfanet.measurement.common.grpc.testing.mockService +import org.wfanet.measurement.common.identity.DuchyInfo +import org.wfanet.measurement.common.testing.chainRulesSequentially +import org.wfanet.measurement.common.testing.verifyProtoArgument +import org.wfanet.measurement.common.throttler.MinimumIntervalThrottler +import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey +import org.wfanet.measurement.duchy.daemon.mill.Certificate +import org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.ReachOnlyLiquidLegionsV2Mill +import org.wfanet.measurement.duchy.daemon.testing.TestRequisition +import org.wfanet.measurement.duchy.daemon.utils.toDuchyEncryptionPublicKey +import org.wfanet.measurement.duchy.db.computation.ComputationDataClients +import org.wfanet.measurement.duchy.db.computation.testing.FakeComputationsDatabase +import org.wfanet.measurement.duchy.service.internal.computations.ComputationsService +import org.wfanet.measurement.duchy.service.internal.computations.newEmptyOutputBlobMetadata +import org.wfanet.measurement.duchy.service.internal.computations.newInputBlobMetadata +import org.wfanet.measurement.duchy.storage.ComputationBlobContext +import org.wfanet.measurement.duchy.storage.ComputationStore +import org.wfanet.measurement.duchy.storage.RequisitionBlobContext +import org.wfanet.measurement.duchy.storage.RequisitionStore +import org.wfanet.measurement.duchy.toProtocolStage +import org.wfanet.measurement.internal.duchy.ComputationBlobDependency +import org.wfanet.measurement.internal.duchy.ComputationDetails.CompletedReason +import org.wfanet.measurement.internal.duchy.ComputationDetailsKt.kingdomComputationDetails +import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt.ComputationStatsCoroutineImplBase +import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt.ComputationStatsCoroutineStub +import org.wfanet.measurement.internal.duchy.ComputationToken +import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoroutineStub +import org.wfanet.measurement.internal.duchy.ElGamalKeyPair +import org.wfanet.measurement.internal.duchy.ElGamalPublicKey +import org.wfanet.measurement.internal.duchy.computationDetails +import org.wfanet.measurement.internal.duchy.computationStageBlobMetadata +import org.wfanet.measurement.internal.duchy.computationToken +import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation +import org.wfanet.measurement.internal.duchy.copy +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.Parameters +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_INPUTS +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfig +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorResponse +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.copy +import org.wfanet.measurement.internal.duchy.protocol.globalReachDpNoiseBaseline +import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters +import org.wfanet.measurement.internal.duchy.protocol.reachNoiseDifferentialPrivacyParams +import org.wfanet.measurement.internal.duchy.protocol.registerNoiseGenerationParameters +import org.wfanet.measurement.storage.Store.Blob +import org.wfanet.measurement.storage.filesystem.FileSystemStorageClient +import org.wfanet.measurement.system.v1alpha.AdvanceComputationRequest +import org.wfanet.measurement.system.v1alpha.AdvanceComputationResponse +import org.wfanet.measurement.system.v1alpha.Computation +import org.wfanet.measurement.system.v1alpha.ComputationControlGrpcKt.ComputationControlCoroutineImplBase +import org.wfanet.measurement.system.v1alpha.ComputationControlGrpcKt.ComputationControlCoroutineStub +import org.wfanet.measurement.system.v1alpha.ComputationKey +import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineImplBase as SystemComputationLogEntriesCoroutineImplBase +import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub as SystemComputationLogEntriesCoroutineStub +import org.wfanet.measurement.system.v1alpha.ComputationLogEntry +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey +import org.wfanet.measurement.system.v1alpha.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineImplBase as SystemComputationParticipantsCoroutineImplBase +import org.wfanet.measurement.system.v1alpha.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineStub as SystemComputationParticipantsCoroutineStub +import org.wfanet.measurement.system.v1alpha.ComputationsGrpcKt.ComputationsCoroutineImplBase as SystemComputationsCoroutineImplBase +import org.wfanet.measurement.system.v1alpha.ComputationsGrpcKt.ComputationsCoroutineStub as SystemComputationsCoroutineStub +import org.wfanet.measurement.system.v1alpha.ConfirmComputationParticipantRequest +import org.wfanet.measurement.system.v1alpha.FailComputationParticipantRequest +import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2 +import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2.Description.EXECUTION_PHASE_INPUT +import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2.Description.SETUP_PHASE_INPUT +import org.wfanet.measurement.system.v1alpha.Requisition +import org.wfanet.measurement.system.v1alpha.SetComputationResultRequest +import org.wfanet.measurement.system.v1alpha.SetParticipantRequisitionParamsRequest +import org.wfanet.measurement.system.v1alpha.setComputationResultRequest + +private const val PUBLIC_API_VERSION = "v2alpha" + +private const val WORKER_COUNT = 3 +private const val MILL_ID = "a nice mill" +private const val DUCHY_ONE_NAME = "DUCHY_ONE" +private const val DUCHY_TWO_NAME = "DUCHY_TWO" +private const val DUCHY_THREE_NAME = "DUCHY_THREE" +private const val DECAY_RATE = 12.0 +private const val SKETCH_SIZE = 100_000L +private const val CURVE_ID = 415L // NID_X9_62_prime256v1 +private const val PARALLELISM = 2 + +private const val LOCAL_ID = 1234L +private const val GLOBAL_ID = LOCAL_ID.toString() + +private const val CIPHERTEXT_SIZE = 66 +private const val NOISE_CIPHERTEXT = "abcdefghijklmnopqrstuvwxyz0123456abcdefghijklmnopqrstuvwxyz0123456" +private val SERIALIZED_NOISE_CIPHERTEXT = ByteString.copyFromUtf8(NOISE_CIPHERTEXT) + +private val DUCHY_ONE_KEY_PAIR = + ElGamalKeyPair.newBuilder() + .apply { + publicKeyBuilder.apply { + generator = ByteString.copyFromUtf8("generator_1") + element = ByteString.copyFromUtf8("element_1") + } + secretKey = ByteString.copyFromUtf8("secret_key_1") + } + .build() +private val DUCHY_TWO_PUBLIC_KEY = + ElGamalPublicKey.newBuilder() + .apply { + generator = ByteString.copyFromUtf8("generator_2") + element = ByteString.copyFromUtf8("element_2") + } + .build() +private val DUCHY_THREE_PUBLIC_KEY = + ElGamalPublicKey.newBuilder() + .apply { + generator = ByteString.copyFromUtf8("generator_3") + element = ByteString.copyFromUtf8("element_3") + } + .build() +private val COMBINED_PUBLIC_KEY = + ElGamalPublicKey.newBuilder() + .apply { + generator = ByteString.copyFromUtf8("generator_1_generator_2_generator_3") + element = ByteString.copyFromUtf8("element_1_element_2_element_3") + } + .build() +private val PARTIALLY_COMBINED_PUBLIC_KEY = + ElGamalPublicKey.newBuilder() + .apply { + generator = ByteString.copyFromUtf8("generator_2_generator_3") + element = ByteString.copyFromUtf8("element_2_element_3") + } + .build() + +private val TEST_NOISE_CONFIG = + LiquidLegionsV2NoiseConfig.newBuilder() + .apply { + reachNoiseConfigBuilder.apply { + blindHistogramNoiseBuilder.apply { + epsilon = 1.0 + delta = 2.0 + } + noiseForPublisherNoiseBuilder.apply { + epsilon = 3.0 + delta = 4.0 + } + globalReachDpNoiseBuilder.apply { + epsilon = 5.0 + delta = 6.0 + } + } + } + .build() + +private val ROLLV2_PARAMETERS = + Parameters.newBuilder() + .apply { + reachOnlyLiquidLegionsSketchBuilder.apply { + decayRate = DECAY_RATE + size = SKETCH_SIZE + } + noise = TEST_NOISE_CONFIG + ellipticCurveId = CURVE_ID.toInt() + } + .build() + +// In the test, use the same set of cert and encryption key for all parties. +private const val CONSENT_SIGNALING_CERT_NAME = "Just a name" +private val CONSENT_SIGNALING_CERT_DER = + TestData.FIXED_SERVER_CERT_DER_FILE.readBytes().toByteString() +private val CONSENT_SIGNALING_PRIVATE_KEY_DER = + TestData.FIXED_SERVER_KEY_DER_FILE.readBytes().toByteString() +private val CONSENT_SIGNALING_ROOT_CERT: X509Certificate = + readCertificate(TestData.FIXED_CA_CERT_PEM_FILE) +private val ENCRYPTION_PRIVATE_KEY = TinkPrivateKeyHandle.generateEcies() +private val ENCRYPTION_PUBLIC_KEY: org.wfanet.measurement.api.v2alpha.EncryptionPublicKey = + ENCRYPTION_PRIVATE_KEY.publicKey.toEncryptionPublicKey() + +/** A public Key used for consent signaling check. */ +private val CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY = + V2AlphaElGamalPublicKey.newBuilder() + .apply { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + .build() +/** A pre-computed signature of the CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY. */ +private val CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE = + ByteString.copyFrom( + Base64.getDecoder() + .decode( + "MEUCIB2HWi/udHE9YlCH6n9lnGY12I9F1ra1SRWoJrIOXGiAAiEAm90wrJRqABFC5sjej+" + + "zjSBOMHcmFOEpfW9tXaCla6Qs=" + ) + ) + +private val COMPUTATION_PARTICIPANT_1 = + ComputationParticipant.newBuilder() + .apply { + duchyId = DUCHY_ONE_NAME + publicKey = DUCHY_ONE_KEY_PAIR.publicKey + elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() + elGamalPublicKeySignature = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE + duchyCertificateDer = CONSENT_SIGNALING_CERT_DER + } + .build() +private val COMPUTATION_PARTICIPANT_2 = + ComputationParticipant.newBuilder() + .apply { + duchyId = DUCHY_TWO_NAME + publicKey = DUCHY_TWO_PUBLIC_KEY + elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() + elGamalPublicKeySignature = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE + duchyCertificateDer = CONSENT_SIGNALING_CERT_DER + } + .build() +private val COMPUTATION_PARTICIPANT_3 = + ComputationParticipant.newBuilder() + .apply { + duchyId = DUCHY_THREE_NAME + publicKey = DUCHY_THREE_PUBLIC_KEY + elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() + elGamalPublicKeySignature = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE + duchyCertificateDer = CONSENT_SIGNALING_CERT_DER + } + .build() + +private val TEST_REQUISITION_1 = TestRequisition("111") { SERIALIZED_MEASUREMENT_SPEC } +private val TEST_REQUISITION_2 = TestRequisition("222") { SERIALIZED_MEASUREMENT_SPEC } +private val TEST_REQUISITION_3 = TestRequisition("333") { SERIALIZED_MEASUREMENT_SPEC } + +private val MEASUREMENT_SPEC = measurementSpec { + nonceHashes += TEST_REQUISITION_1.nonceHash + nonceHashes += TEST_REQUISITION_2.nonceHash + nonceHashes += TEST_REQUISITION_3.nonceHash + reach = reach {} +} +private val SERIALIZED_MEASUREMENT_SPEC: ByteString = MEASUREMENT_SPEC.toByteString() + +private val MEASUREMENT_SPEC_WITH_VID_SAMPLING_WIDTH = measurementSpec { + nonceHashes += TEST_REQUISITION_1.nonceHash + nonceHashes += TEST_REQUISITION_2.nonceHash + nonceHashes += TEST_REQUISITION_3.nonceHash + reach = reach {} + vidSamplingInterval = vidSamplingInterval { width = 0.5f } +} + +private val SERIALIZED_MEASUREMENT_SPEC_WITH_VID_SAMPLING_WIDTH = + MEASUREMENT_SPEC_WITH_VID_SAMPLING_WIDTH.toByteString() + +private val REQUISITION_1 = + TEST_REQUISITION_1.toRequisitionMetadata(Requisition.State.FULFILLED, DUCHY_ONE_NAME).copy { + path = RequisitionBlobContext(GLOBAL_ID, externalKey.externalRequisitionId).blobKey + } +private val REQUISITION_2 = + TEST_REQUISITION_2.toRequisitionMetadata(Requisition.State.FULFILLED, DUCHY_TWO_NAME) +private val REQUISITION_3 = + TEST_REQUISITION_3.toRequisitionMetadata(Requisition.State.FULFILLED, DUCHY_THREE_NAME) +private val REQUISITIONS = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3) + +private val AGGREGATOR_COMPUTATION_DETAILS = computationDetails { + kingdomComputation = kingdomComputationDetails { + publicApiVersion = PUBLIC_API_VERSION + measurementPublicKey = ENCRYPTION_PUBLIC_KEY.toDuchyEncryptionPublicKey() + measurementSpec = SERIALIZED_MEASUREMENT_SPEC + } + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3, COMPUTATION_PARTICIPANT_1) + combinedPublicKey = COMBINED_PUBLIC_KEY + // partiallyCombinedPublicKey and combinedPublicKey are the same at the aggregator. + partiallyCombinedPublicKey = COMBINED_PUBLIC_KEY + localElgamalKey = DUCHY_ONE_KEY_PAIR + } +} + +private val NON_AGGREGATOR_COMPUTATION_DETAILS = computationDetails { + kingdomComputation = kingdomComputationDetails { + publicApiVersion = PUBLIC_API_VERSION + measurementPublicKey = ENCRYPTION_PUBLIC_KEY.toDuchyEncryptionPublicKey() + measurementSpec = SERIALIZED_MEASUREMENT_SPEC + } + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_1, COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3) + combinedPublicKey = COMBINED_PUBLIC_KEY + partiallyCombinedPublicKey = PARTIALLY_COMBINED_PUBLIC_KEY + localElgamalKey = DUCHY_ONE_KEY_PAIR + } +} + +@RunWith(JUnit4::class) +class ReachOnlyLiquidLegionsV2MillTest { + private val mockReachOnlyLiquidLegionsComputationControl: ComputationControlCoroutineImplBase = + mockService { + onBlocking { advanceComputation(any()) } + .thenAnswer { + val request: Flow = it.getArgument(0) + computationControlRequests = runBlocking { request.toList() } + AdvanceComputationResponse.getDefaultInstance() + } + } + private val mockSystemComputations: SystemComputationsCoroutineImplBase = mockService() + private val mockComputationParticipants: SystemComputationParticipantsCoroutineImplBase = + mockService() + private val mockComputationLogEntries: SystemComputationLogEntriesCoroutineImplBase = + mockService() + private val mockComputationStats: ComputationStatsCoroutineImplBase = mockService() + private val mockCryptoWorker: ReachOnlyLiquidLegionsV2Encryption = + mock(useConstructor = UseConstructor.parameterless()) { + on { combineElGamalPublicKeys(any()) } + .thenAnswer { + val cryptoRequest: CombineElGamalPublicKeysRequest = it.getArgument(0) + CombineElGamalPublicKeysResponse.newBuilder() + .apply { + elGamalKeysBuilder.apply { + generator = + ByteString.copyFromUtf8( + cryptoRequest.elGamalKeysList + .sortedBy { key -> key.generator.toStringUtf8() } + .joinToString(separator = "_") { key -> key.generator.toStringUtf8() } + ) + element = + ByteString.copyFromUtf8( + cryptoRequest.elGamalKeysList + .sortedBy { key -> key.element.toStringUtf8() } + .joinToString(separator = "_") { key -> key.element.toStringUtf8() } + ) + } + } + .build() + } + } + private val fakeComputationDb = FakeComputationsDatabase() + + private lateinit var computationDataClients: ComputationDataClients + private lateinit var computationStore: ComputationStore + private lateinit var requisitionStore: RequisitionStore + + private val tempDirectory = TemporaryFolder() + + private val grpcTestServerRule = GrpcTestServerRule { + val storageClient = FileSystemStorageClient(tempDirectory.root) + computationStore = ComputationStore(storageClient) + requisitionStore = RequisitionStore(storageClient) + computationDataClients = + ComputationDataClients.forTesting( + ComputationsCoroutineStub(channel), + computationStore, + requisitionStore + ) + addService(mockReachOnlyLiquidLegionsComputationControl) + addService(mockSystemComputations) + addService(mockComputationLogEntries) + addService(mockComputationParticipants) + addService(mockComputationStats) + addService( + ComputationsService( + fakeComputationDb, + systemComputationLogEntriesStub, + computationStore, + requisitionStore, + DUCHY_THREE_NAME, + Clock.systemUTC() + ) + ) + } + + @get:Rule val ruleChain = chainRulesSequentially(tempDirectory, grpcTestServerRule) + + private val workerStub: ComputationControlCoroutineStub by lazy { + ComputationControlCoroutineStub(grpcTestServerRule.channel) + } + + private val systemComputationStub: SystemComputationsCoroutineStub by lazy { + SystemComputationsCoroutineStub(grpcTestServerRule.channel) + } + + private val systemComputationLogEntriesStub: SystemComputationLogEntriesCoroutineStub by lazy { + SystemComputationLogEntriesCoroutineStub(grpcTestServerRule.channel) + } + + private val systemComputationParticipantsStub: + SystemComputationParticipantsCoroutineStub by lazy { + SystemComputationParticipantsCoroutineStub(grpcTestServerRule.channel) + } + + private val computationStatsStub: ComputationStatsCoroutineStub by lazy { + ComputationStatsCoroutineStub(grpcTestServerRule.channel) + } + + private lateinit var computationControlRequests: List + + // Just use the same workerStub for all other duchies, since it is not relevant to this test. + private val workerStubs = mapOf(DUCHY_TWO_NAME to workerStub, DUCHY_THREE_NAME to workerStub) + + private lateinit var aggregatorMill: ReachOnlyLiquidLegionsV2Mill + private lateinit var nonAggregatorMill: ReachOnlyLiquidLegionsV2Mill + + private fun buildAdvanceComputationRequests( + globalComputationId: String, + description: ReachOnlyLiquidLegionsV2.Description, + vararg chunkContents: String + ): List { + val header = + AdvanceComputationRequest.newBuilder() + .apply { + headerBuilder.apply { + name = ComputationKey(globalComputationId).toName() + reachOnlyLiquidLegionsV2Builder.description = description + } + } + .build() + val body = + chunkContents.asList().map { + AdvanceComputationRequest.newBuilder() + .apply { bodyChunkBuilder.apply { partialData = ByteString.copyFromUtf8(it) } } + .build() + } + return listOf(header) + body + } + + @Before + fun initializeMill() = runBlocking { + val throttler = MinimumIntervalThrottler(Clock.systemUTC(), Duration.ofSeconds(60)) + DuchyInfo.setForTest(setOf(DUCHY_ONE_NAME, DUCHY_TWO_NAME, DUCHY_THREE_NAME)) + val csX509Certificate = readCertificate(CONSENT_SIGNALING_CERT_DER) + val csSigningKey = + SigningKeyHandle( + csX509Certificate, + readPrivateKey(CONSENT_SIGNALING_PRIVATE_KEY_DER, csX509Certificate.publicKey.algorithm) + ) + val csCertificate = Certificate(CONSENT_SIGNALING_CERT_NAME, csX509Certificate) + val trustedCertificates = + DuchyInfo.entries.values.associateBy( + { it.rootCertificateSkid }, + { CONSENT_SIGNALING_ROOT_CERT } + ) + + aggregatorMill = + ReachOnlyLiquidLegionsV2Mill( + millId = MILL_ID, + duchyId = DUCHY_ONE_NAME, + signingKey = csSigningKey, + consentSignalCert = csCertificate, + trustedCertificates = trustedCertificates, + dataClients = computationDataClients, + systemComputationParticipantsClient = systemComputationParticipantsStub, + systemComputationsClient = systemComputationStub, + systemComputationLogEntriesClient = systemComputationLogEntriesStub, + computationStatsClient = computationStatsStub, + throttler = throttler, + workerStubs = workerStubs, + cryptoWorker = mockCryptoWorker, + workLockDuration = Duration.ofMinutes(5), + openTelemetry = GlobalOpenTelemetry.get(), + requestChunkSizeBytes = 20, + maximumAttempts = 2, + parallelism = PARALLELISM + ) + nonAggregatorMill = + ReachOnlyLiquidLegionsV2Mill( + millId = MILL_ID, + duchyId = DUCHY_ONE_NAME, + signingKey = csSigningKey, + consentSignalCert = csCertificate, + trustedCertificates = trustedCertificates, + dataClients = computationDataClients, + systemComputationParticipantsClient = systemComputationParticipantsStub, + systemComputationsClient = systemComputationStub, + systemComputationLogEntriesClient = systemComputationLogEntriesStub, + computationStatsClient = computationStatsStub, + throttler = throttler, + workerStubs = workerStubs, + cryptoWorker = mockCryptoWorker, + workLockDuration = Duration.ofMinutes(5), + openTelemetry = GlobalOpenTelemetry.get(), + requestChunkSizeBytes = 20, + maximumAttempts = 2, + parallelism = PARALLELISM + ) + } + + @Test + fun `exceeding max attempt should fail the computation`() = runBlocking { + // Stage 0. preparing the database and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = INITIALIZATION_PHASE.toProtocolStage() + ) + .build() + + val initialComputationDetails = + NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() + .apply { + reachOnlyLiquidLegionsV2Builder.apply { + parametersBuilder.ellipticCurveId = CURVE_ID.toInt() + clearPartiallyCombinedPublicKey() + clearCombinedPublicKey() + clearLocalElgamalKey() + } + } + .build() + + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = initialComputationDetails, + requisitions = REQUISITIONS + ) + + whenever(mockCryptoWorker.completeReachOnlyInitializationPhase(any())).thenAnswer { + CompleteReachOnlyInitializationPhaseResponse.newBuilder() + .apply { + elGamalKeyPairBuilder.apply { + publicKeyBuilder.apply { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + .build() + } + + // This will result in TRANSIENT gRPC failure. + whenever(mockComputationParticipants.setParticipantRequisitionParams(any())) + .thenThrow(Status.UNKNOWN.asRuntimeException()) + + // First attempt fails, which doesn't change the computation stage. + nonAggregatorMill.pollAndProcessNextComputation() + + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = INITIALIZATION_PHASE.toProtocolStage() + version = 3 // claimTask + updateComputationDetails + enqueueComputation + computationDetails = + initialComputationDetails + .toBuilder() + .apply { + reachOnlyLiquidLegionsV2Builder.localElgamalKeyBuilder.apply { + publicKeyBuilder.apply { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + .build() + addAllRequisitions(REQUISITIONS) + } + .build() + ) + // Second attempt fails, which doesn't change the computation stage. + nonAggregatorMill.pollAndProcessNextComputation() + + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 2 + computationStage = INITIALIZATION_PHASE.toProtocolStage() + version = 5 // claimTask + updateComputationDetails + enqueueComputation + computationDetails = + initialComputationDetails + .toBuilder() + .apply { + reachOnlyLiquidLegionsV2Builder.localElgamalKeyBuilder.apply { + publicKeyBuilder.apply { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + .build() + addAllRequisitions(REQUISITIONS) + } + .build() + ) + + // Third attempt fails, which will fail the computation. + nonAggregatorMill.pollAndProcessNextComputation() + + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 3 + computationStage = COMPLETE.toProtocolStage() + version = 8 // claimTask + updateComputationDetails + enqueueComputation + claimTask + + // EndComputation + computationDetails = + initialComputationDetails + .toBuilder() + .apply { + endingState = CompletedReason.FAILED + reachOnlyLiquidLegionsV2Builder.localElgamalKeyBuilder.apply { + publicKeyBuilder.apply { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + .build() + addAllRequisitions(REQUISITIONS) + } + .build() + ) + } + + @Test + fun `initialization phase`() = runBlocking { + // Stage 0. preparing the database and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = INITIALIZATION_PHASE.toProtocolStage() + ) + .build() + + val initialComputationDetails = + NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() + .apply { + reachOnlyLiquidLegionsV2Builder.apply { + parametersBuilder.ellipticCurveId = CURVE_ID.toInt() + clearPartiallyCombinedPublicKey() + clearCombinedPublicKey() + clearLocalElgamalKey() + } + } + .build() + + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = initialComputationDetails, + requisitions = REQUISITIONS + ) + + var cryptoRequest = CompleteReachOnlyInitializationPhaseRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeReachOnlyInitializationPhase(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + CompleteReachOnlyInitializationPhaseResponse.newBuilder() + .apply { + elGamalKeyPairBuilder.apply { + publicKeyBuilder.apply { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + .build() + } + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage() + version = 3 // claimTask + updateComputationDetails + transitionStage + computationDetails = + initialComputationDetails + .toBuilder() + .apply { + reachOnlyLiquidLegionsV2Builder.localElgamalKeyBuilder.apply { + publicKeyBuilder.apply { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + .build() + addAllRequisitions(REQUISITIONS) + } + .build() + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::setParticipantRequisitionParams + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + SetParticipantRequisitionParamsRequest.newBuilder() + .apply { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + requisitionParamsBuilder.apply { + duchyCertificate = CONSENT_SIGNALING_CERT_NAME + reachOnlyLiquidLegionsV2Builder.apply { + elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() + } + } + } + .build() + ) + + assertThat(cryptoRequest) + .isEqualTo( + CompleteReachOnlyInitializationPhaseRequest.newBuilder().apply { curveId = CURVE_ID }.build() + ) + } + + @Test + fun `confirmation phase, failed due to missing local requisition`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val requisition1 = REQUISITION_1 + // requisition2 is fulfilled at Duchy One, but doesn't have path set. + val requisition2 = + REQUISITION_2.copy { details = details.copy { externalFulfillingDuchyId = DUCHY_ONE_NAME } } + val computationDetailsWithoutPublicKey = + AGGREGATOR_COMPUTATION_DETAILS.toBuilder() + .apply { reachOnlyLiquidLegionsV2Builder.clearCombinedPublicKey().clearPartiallyCombinedPublicKey() } + .build() + fakeComputationDb.addComputation( + globalId = GLOBAL_ID, + stage = CONFIRMATION_PHASE.toProtocolStage(), + computationDetails = computationDetailsWithoutPublicKey, + requisitions = listOf(requisition1, requisition2) + ) + + whenever(mockComputationLogEntries.createComputationLogEntry(any())) + .thenReturn(ComputationLogEntry.getDefaultInstance()) + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = + computationDetailsWithoutPublicKey + .toBuilder() + .apply { endingState = CompletedReason.FAILED } + .build() + addAllRequisitions(listOf(requisition1, requisition2)) + } + .build() + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::failComputationParticipant + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + FailComputationParticipantRequest.newBuilder() + .apply { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failureBuilder.apply { + participantChildReferenceId = MILL_ID + errorMessage = + "PERMANENT error: java.lang.Exception: @Mill a nice mill, Computation 1234 " + + "failed due to:\n" + + "Cannot verify participation of all DataProviders.\n" + + "Missing expected data for requisition 222." + stageAttemptBuilder.apply { + stage = CONFIRMATION_PHASE.number + stageName = CONFIRMATION_PHASE.name + attemptNumber = 1 + } + } + } + .build() + ) + } + + @Test + fun `confirmation phase, passed at non-aggregator`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val computationDetailsWithoutPublicKey = + NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() + .apply { reachOnlyLiquidLegionsV2Builder.clearCombinedPublicKey().clearPartiallyCombinedPublicKey() } + .build() + fakeComputationDb.addComputation( + globalId = GLOBAL_ID, + stage = CONFIRMATION_PHASE.toProtocolStage(), + computationDetails = computationDetailsWithoutPublicKey, + requisitions = REQUISITIONS + ) + + whenever(mockComputationLogEntries.createComputationLogEntry(any())) + .thenReturn(ComputationLogEntry.getDefaultInstance()) + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_TO_START.toProtocolStage() + version = 3 // claimTask + updateComputationDetail + transitionStage + computationDetails = + NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() + .apply { + reachOnlyLiquidLegionsV2Builder.apply { + combinedPublicKey = COMBINED_PUBLIC_KEY + partiallyCombinedPublicKey = PARTIALLY_COMBINED_PUBLIC_KEY + } + } + .build() + addAllRequisitions(REQUISITIONS) + } + .build() + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::confirmComputationParticipant + ) + .isEqualTo( + ConfirmComputationParticipantRequest.newBuilder() + .apply { name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() } + .build() + ) + } + + @Test + fun `confirmation phase, passed at aggregator`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val computationDetailsWithoutPublicKey = + AGGREGATOR_COMPUTATION_DETAILS.toBuilder() + .apply { reachOnlyLiquidLegionsV2Builder.clearCombinedPublicKey().clearPartiallyCombinedPublicKey() } + .build() + fakeComputationDb.addComputation( + globalId = GLOBAL_ID, + stage = CONFIRMATION_PHASE.toProtocolStage(), + computationDetails = computationDetailsWithoutPublicKey, + requisitions = REQUISITIONS + ) + + whenever(mockComputationLogEntries.createComputationLogEntry(any())) + .thenReturn(ComputationLogEntry.getDefaultInstance()) + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_SETUP_PHASE_INPUTS.toProtocolStage() + version = 3 // claimTask + updateComputationDetails + transitionStage + addAllBlobs(listOf(newEmptyOutputBlobMetadata(0), newEmptyOutputBlobMetadata(1))) + stageSpecificDetailsBuilder.apply { + reachOnlyLiquidLegionsV2Builder.waitSetupPhaseInputsDetailsBuilder.apply { + putExternalDuchyLocalBlobId("DUCHY_TWO", 0L) + putExternalDuchyLocalBlobId("DUCHY_THREE", 1L) + } + } + computationDetails = + AGGREGATOR_COMPUTATION_DETAILS.toBuilder() + .apply { + reachOnlyLiquidLegionsV2Builder.apply { + combinedPublicKey = COMBINED_PUBLIC_KEY + partiallyCombinedPublicKey = COMBINED_PUBLIC_KEY + } + } + .build() + addAllRequisitions(REQUISITIONS) + } + .build() + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::confirmComputationParticipant + ) + .isEqualTo( + ConfirmComputationParticipantRequest.newBuilder() + .apply { name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() } + .build() + ) + } + + @Test + fun `confirmation phase, failed due to invalid nonce and ElGamal key signature`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val computationDetailsWithoutInvalidDuchySignature = + AGGREGATOR_COMPUTATION_DETAILS.toBuilder() + .apply { + reachOnlyLiquidLegionsV2Builder.apply { + participantBuilderList[0].apply { + elGamalPublicKeySignature = ByteString.copyFromUtf8("An invalid signature") + } + } + } + .build() + val requisitionWithInvalidNonce = REQUISITION_1.copy { details = details.copy { nonce = 404L } } + fakeComputationDb.addComputation( + globalId = GLOBAL_ID, + stage = CONFIRMATION_PHASE.toProtocolStage(), + computationDetails = computationDetailsWithoutInvalidDuchySignature, + requisitions = listOf(requisitionWithInvalidNonce) + ) + + whenever(mockComputationLogEntries.createComputationLogEntry(any())) + .thenReturn(ComputationLogEntry.getDefaultInstance()) + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = + computationDetailsWithoutInvalidDuchySignature + .toBuilder() + .apply { endingState = CompletedReason.FAILED } + .build() + addAllRequisitions(listOf(requisitionWithInvalidNonce)) + } + .build() + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::failComputationParticipant + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + FailComputationParticipantRequest.newBuilder() + .apply { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failureBuilder.apply { + participantChildReferenceId = MILL_ID + errorMessage = + "PERMANENT error: java.lang.Exception: @Mill a nice mill, Computation 1234 " + + "failed due to:\n" + + "Cannot verify participation of all DataProviders.\n" + + "Invalid ElGamal public key signature for Duchy $DUCHY_TWO_NAME" + stageAttemptBuilder.apply { + stage = CONFIRMATION_PHASE.number + stageName = CONFIRMATION_PHASE.name + attemptNumber = 1 + } + } + } + .build() + ) + } + + @Test + fun `setup phase at non-aggregator using calculated result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = SETUP_PHASE.toProtocolStage() + ) + .build() + val requisitionBlobContext = + RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) + val calculatedBlobContext = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) + requisitionStore.writeString(requisitionBlobContext, "local_requisition") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3), + blobs = listOf(newEmptyOutputBlobMetadata(calculatedBlobContext.blobId)) + ) + + var cryptoRequest = CompleteReachOnlySetupPhaseRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeReachOnlySetupPhase(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + val postFix = ByteString.copyFromUtf8("-completeReachOnlySetupPhase") + CompleteReachOnlySetupPhaseResponse.newBuilder() + .apply { + combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) + serializedExcessiveNoiseCiphertext = ByteString.copyFromUtf8("-encryptedNoise") + } + .build() + } + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + val blobKey = calculatedBlobContext.blobKey + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + addBlobsBuilder().apply { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0L + path = blobKey + } + addBlobsBuilder().apply { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1L + } + version = 3 // claimTask + writeOutputBlob + transitionStage + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS + addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + .build() + ) + + assertThat(computationStore.get(blobKey)?.readToString()) + .isEqualTo("local_requisition-completeReachOnlySetupPhase-encryptedNoise") + + assertThat(computationControlRequests) + .containsExactlyElementsIn( + buildAdvanceComputationRequests( + GLOBAL_ID, + SETUP_PHASE_INPUT, + "local_requisition-co", + "mpleteReachOnlySetup", + "Phase-encryptedNoise" + ) + ) + .inOrder() + + assertThat(cryptoRequest) + .isEqualTo( + completeReachOnlySetupPhaseRequest { + combinedRegisterVector = ByteString.copyFromUtf8("local_requisition") + curveId = CURVE_ID + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = COMBINED_PUBLIC_KEY + curveId = CURVE_ID + contributorsCount = WORKER_COUNT + totalSketchesCount = REQUISITIONS.size + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = TEST_NOISE_CONFIG.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = TEST_NOISE_CONFIG.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = TEST_NOISE_CONFIG.reachNoiseConfig.globalReachDpNoise + } + } + compositeElGamalPublicKey = COMBINED_PUBLIC_KEY + parallelism = PARALLELISM + } + ) + } + + @Test + fun `setup phase at aggregator using calculated result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = SETUP_PHASE.toProtocolStage() + ) + .build() + val requisitionBlobContext = + RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) + requisitionStore.writeString(requisitionBlobContext, "local_requisition_") + val inputBlob0Context = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 0L) + computationStore.writeString(inputBlob0Context, "duchy_2_sketch_" + NOISE_CIPHERTEXT) + val inputBlob1Context = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) + computationStore.writeString(inputBlob1Context, "duchy_3_sketch_" + NOISE_CIPHERTEXT) + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + newInputBlobMetadata(0L, inputBlob0Context.blobKey), + newInputBlobMetadata(1L, inputBlob1Context.blobKey), + newEmptyOutputBlobMetadata(3L) + ), + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3) + ) + + var cryptoRequest = CompleteReachOnlySetupPhaseRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeReachOnlySetupPhaseAtAggregator(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + val postFix = ByteString.copyFromUtf8("-completeReachOnlySetupPhase") + CompleteReachOnlySetupPhaseResponse.newBuilder() + .apply { + combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) + serializedExcessiveNoiseCiphertext = ByteString.copyFromUtf8("-encryptedNoise") + } + .build() + } + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + val blobKey = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 3L).blobKey + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + addBlobsBuilder().apply { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0 + path = blobKey + } + addBlobsBuilder().apply { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1 + } + version = 3 // claimTask + writeOutputBlob + transitionStage + computationDetails = AGGREGATOR_COMPUTATION_DETAILS + addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + .build() + ) + + assertThat(computationStore.get(blobKey)?.readToString()) + .isEqualTo("local_requisition_duchy_2_sketch_duchy_3_sketch_-completeReachOnlySetupPhase-encryptedNoise") + + assertThat(computationControlRequests) + .containsExactlyElementsIn( + buildAdvanceComputationRequests( + GLOBAL_ID, + EXECUTION_PHASE_INPUT, + "local_requisition_du", + "chy_2_sketch_duchy_3", + "_sketch_-completeRea", + "chOnlySetupPhase-enc", + "ryptedNoise" + ) + ) + .inOrder() + + assertThat(cryptoRequest) + .isEqualTo( + completeReachOnlySetupPhaseRequest { + combinedRegisterVector = + ByteString.copyFromUtf8("local_requisition_duchy_2_sketch_duchy_3_sketch_") + curveId = CURVE_ID + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = COMBINED_PUBLIC_KEY + curveId = CURVE_ID + contributorsCount = WORKER_COUNT + totalSketchesCount = REQUISITIONS.size + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = TEST_NOISE_CONFIG.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = TEST_NOISE_CONFIG.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = TEST_NOISE_CONFIG.reachNoiseConfig.globalReachDpNoise + } + } + compositeElGamalPublicKey = COMBINED_PUBLIC_KEY + serializedExcessiveNoiseCiphertext = SERIALIZED_NOISE_CIPHERTEXT.concat(SERIALIZED_NOISE_CIPHERTEXT) + parallelism = PARALLELISM + } + ) + } + + @Test + fun `setup phase at aggregator, failed due to invalid input blob size`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = SETUP_PHASE.toProtocolStage() + ) + .build() + val requisitionBlobContext = + RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) + requisitionStore.writeString(requisitionBlobContext, "local_requisition_") + val inputBlob0Context = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 0L) + computationStore.writeString(inputBlob0Context, "duchy_2_sketch_") + val inputBlob1Context = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) + computationStore.writeString(inputBlob1Context, "duchy_3_sketch_" + NOISE_CIPHERTEXT) + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = + AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + newInputBlobMetadata(0L, inputBlob0Context.blobKey), + newInputBlobMetadata(1L, inputBlob1Context.blobKey), + newEmptyOutputBlobMetadata(3L) + ), + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3) + ) + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = + AGGREGATOR_COMPUTATION_DETAILS + .toBuilder() + .apply { + endingState = CompletedReason.FAILED + } + .build() + addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + .build() + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::failComputationParticipant + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + FailComputationParticipantRequest.newBuilder() + .apply { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failureBuilder.apply { + participantChildReferenceId = MILL_ID + errorMessage = + "PERMANENT error: java.lang.IllegalStateException: Invalid input blob size. Input" + + " blob duchy_2_sketch_ has size 15 which is less than (66)." + stageAttemptBuilder.apply { + stage = SETUP_PHASE.number + stageName = SETUP_PHASE.name + attemptNumber = 1 + } + } + } + .build() + ) + } + + + + @Test + fun `execution phase at non-aggregator, failed due to invalid input blob size`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = EXECUTION_PHASE.toProtocolStage() + ) + .build() + val inputBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val calculatedBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + computationStore.writeString(inputBlobContext, "data") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), + newEmptyOutputBlobMetadata(calculatedBlobContext.blobId) + ), + requisitions = REQUISITIONS + ) + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = + NON_AGGREGATOR_COMPUTATION_DETAILS + .toBuilder() + .apply { + endingState = CompletedReason.FAILED + } + .build() + addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + .build() + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::failComputationParticipant + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + FailComputationParticipantRequest.newBuilder() + .apply { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failureBuilder.apply { + participantChildReferenceId = MILL_ID + errorMessage = + "PERMANENT error: Invalid input blob size. Input blob data has size 4 which is less than (66)." + stageAttemptBuilder.apply { + stage = EXECUTION_PHASE.number + stageName = EXECUTION_PHASE.name + attemptNumber = 1 + } + } + } + .build() + ) + } + + @Test + fun `execution phase at non-aggregator using cached result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = EXECUTION_PHASE.toProtocolStage() + ) + .build() + val inputBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + computationStore.writeString(inputBlobContext, "sketch" + NOISE_CIPHERTEXT) + val cachedBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + computationStore.writeString(cachedBlobContext, "cached result") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = + NON_AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), + cachedBlobContext.toMetadata(ComputationBlobDependency.OUTPUT) + ), + requisitions = REQUISITIONS + ) + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = + NON_AGGREGATOR_COMPUTATION_DETAILS + .toBuilder() + .apply { endingState = CompletedReason.SUCCEEDED } + .build() + requisitions += REQUISITIONS + } + ) + + assertThat(computationControlRequests) + .containsExactlyElementsIn( + buildAdvanceComputationRequests(GLOBAL_ID, EXECUTION_PHASE_INPUT, "cached result") + ) + .inOrder() + } + + @Test + fun `execution phase at non-aggregator using calculated result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = EXECUTION_PHASE.toProtocolStage() + ) + .build() + val inputBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val calculatedBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + computationStore.writeString(inputBlobContext, "data" + NOISE_CIPHERTEXT) + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), + newEmptyOutputBlobMetadata(calculatedBlobContext.blobId) + ), + requisitions = REQUISITIONS + ) + + var cryptoRequest = CompleteReachOnlyExecutionPhaseRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeReachOnlyExecutionPhase(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + val postFix = ByteString.copyFromUtf8("-completeReachOnlyExecutionPhase") + CompleteReachOnlyExecutionPhaseResponse.newBuilder() + .apply { + combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) + serializedExcessiveNoiseCiphertext = ByteString.copyFromUtf8("-partiallyDecryptedNoise") + } + .build() + } + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + val blobKey = calculatedBlobContext.blobKey + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 3 // claimTask + writeOutputBlob + transitionStage + computationDetails = + NON_AGGREGATOR_COMPUTATION_DETAILS + .toBuilder() + .apply { endingState = CompletedReason.SUCCEEDED } + .build() + addAllRequisitions(REQUISITIONS) + } + .build() + ) + assertThat(computationStore.get(blobKey)?.readToString()) + .isEqualTo("data-completeReachOnlyExecutionPhase-partiallyDecryptedNoise") + + assertThat(cryptoRequest) + .isEqualTo( + completeReachOnlyExecutionPhaseRequest { + combinedRegisterVector = ByteString.copyFromUtf8("data") + localElGamalKeyPair = DUCHY_ONE_KEY_PAIR + curveId = CURVE_ID + serializedExcessiveNoiseCiphertext = SERIALIZED_NOISE_CIPHERTEXT + parallelism = PARALLELISM + } + ) + } + + @Test + fun `execution phase at aggregator, failed due to invalid input blob size`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = EXECUTION_PHASE.toProtocolStage() + ) + .build() + val inputBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val calculatedBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + computationStore.writeString(inputBlobContext, "data") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), + newEmptyOutputBlobMetadata(calculatedBlobContext.blobId) + ), + requisitions = REQUISITIONS + ) + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = + AGGREGATOR_COMPUTATION_DETAILS + .toBuilder() + .apply { + endingState = CompletedReason.FAILED + } + .build() + addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + .build() + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::failComputationParticipant + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + FailComputationParticipantRequest.newBuilder() + .apply { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failureBuilder.apply { + participantChildReferenceId = MILL_ID + errorMessage = + "PERMANENT error: Invalid input blob size. Input blob data has size 4 which is less than (66)." + stageAttemptBuilder.apply { + stage = EXECUTION_PHASE.number + stageName = EXECUTION_PHASE.name + attemptNumber = 1 + } + } + } + .build() + ) + } + + @Test + fun `execution phase at aggregator using cached result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = EXECUTION_PHASE.toProtocolStage() + ) + .build() + val inputBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + computationStore.writeString(inputBlobContext, "sketch" + NOISE_CIPHERTEXT) + val cachedBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + computationStore.writeString(cachedBlobContext, "cached result") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = + AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), + cachedBlobContext.toMetadata(ComputationBlobDependency.OUTPUT) + ), + requisitions = REQUISITIONS + ) + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = + AGGREGATOR_COMPUTATION_DETAILS + .toBuilder() + .apply { endingState = CompletedReason.SUCCEEDED } + .build() + requisitions += REQUISITIONS + } + ) + } + + @Test + fun `execution phase at aggregator using calculated result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = EXECUTION_PHASE.toProtocolStage() + ) + .build() + val computationDetailsWithVidSamplingWidth = + AGGREGATOR_COMPUTATION_DETAILS.copy { + kingdomComputation = + kingdomComputation.copy { + measurementSpec = SERIALIZED_MEASUREMENT_SPEC_WITH_VID_SAMPLING_WIDTH + } + } + val inputBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val calculatedBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + computationStore.writeString(inputBlobContext, "data" + NOISE_CIPHERTEXT) + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = computationDetailsWithVidSamplingWidth, + blobs = + listOf( + inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), + newEmptyOutputBlobMetadata(calculatedBlobContext.blobId) + ), + requisitions = REQUISITIONS + ) + val testReach = 123L + var cryptoRequest = CompleteReachOnlyExecutionPhaseAtAggregatorRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeReachOnlyExecutionPhaseAtAggregator(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + CompleteReachOnlyExecutionPhaseAtAggregatorResponse.newBuilder() + .apply { reach = testReach } + .build() + } + var systemComputationResult = SetComputationResultRequest.getDefaultInstance() + whenever(mockSystemComputations.setComputationResult(any())).thenAnswer { + systemComputationResult = it.getArgument(0) + Computation.getDefaultInstance() + } + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + val blobKey = calculatedBlobContext.blobKey + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 3 // claimTask + writeOutputBlob + transitionStage + computationDetails = + computationDetailsWithVidSamplingWidth + .toBuilder() + .apply { endingState = CompletedReason.SUCCEEDED } + .build() + addAllRequisitions(REQUISITIONS) + } + .build() + ) + assertThat(computationStore.get(blobKey)?.readToString()).isNotEmpty() + + assertThat(systemComputationResult.name).isEqualTo("computations/$GLOBAL_ID") + // The signature is non-deterministic, so we only verity the encryption is not empty. + assertThat(systemComputationResult.encryptedResult).isNotEmpty() + assertThat(systemComputationResult) + .comparingExpectedFieldsOnly() + .isEqualTo( + setComputationResultRequest { + name = "computations/$GLOBAL_ID" + aggregatorCertificate = CONSENT_SIGNALING_CERT_NAME + resultPublicKey = ENCRYPTION_PUBLIC_KEY.toByteString() + } + ) + + assertThat(cryptoRequest) + .isEqualTo( + completeReachOnlyExecutionPhaseAtAggregatorRequest { + combinedRegisterVector = ByteString.copyFromUtf8("data") + localElGamalKeyPair = DUCHY_ONE_KEY_PAIR + curveId = CURVE_ID + serializedExcessiveNoiseCiphertext = SERIALIZED_NOISE_CIPHERTEXT + liquidLegionsParameters = liquidLegionsSketchParameters { + decayRate = DECAY_RATE + size = SKETCH_SIZE + } + reachDpNoiseBaseline = globalReachDpNoiseBaseline { + contributorsCount = WORKER_COUNT + globalReachDpNoise = TEST_NOISE_CONFIG.reachNoiseConfig.globalReachDpNoise + } + vidSamplingIntervalWidth = 0.5f + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = COMBINED_PUBLIC_KEY + curveId = CURVE_ID + contributorsCount = WORKER_COUNT + totalSketchesCount = REQUISITIONS.size + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = TEST_NOISE_CONFIG.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = TEST_NOISE_CONFIG.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = TEST_NOISE_CONFIG.reachNoiseConfig.globalReachDpNoise + } + } + parallelism = PARALLELISM + } + ) + } +} + +private fun ComputationBlobContext.toMetadata(dependencyType: ComputationBlobDependency) = + computationStageBlobMetadata { + blobId = this@toMetadata.blobId + path = blobKey + this.dependencyType = dependencyType + } + +private suspend fun Blob.readToString(): String = read().flatten().toStringUtf8() + +private suspend fun ComputationStore.writeString( + context: ComputationBlobContext, + content: String +): Blob = write(context, ByteString.copyFromUtf8(content)) + +private suspend fun RequisitionStore.writeString( + context: RequisitionBlobContext, + content: String +): Blob = write(context, ByteString.copyFromUtf8(content)) diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel index ff0ea7a100a..64d03902a74 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel @@ -30,3 +30,35 @@ kt_jvm_test( "@wfa_common_jvm//imports/kotlin/kotlin/test", ], ) + +kt_jvm_test( + name = "ReachOnlyLiquidLegionsV2EncryptionUtilityTest", + srcs = ["ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt"], + test_class = "org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto.ReachOnlyLiquidLegionsV2EncryptionUtilityTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/utils:computation_conversions", + "//src/main/proto/wfa/any_sketch:sketch_kt_jvm_proto", + "//src/main/proto/wfa/any_sketch/crypto:sketch_encryption_methods_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_kt_jvm_proto", + "//src/main/swig/protocol/reachonlyliquidlegionsv2:reach_only_liquid_legions_v2_encryption_utility", + "@any_sketch_java//src/main/java/org/wfanet/anysketch/crypto:sketch_encrypter_adapter", + "@any_sketch_java//src/main/java/org/wfanet/estimation:estimators", + "@wfa_common_jvm//imports/java/com/google/common/truth", + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//imports/kotlin/kotlin/test", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + ], +) + +kt_jvm_test( + name = "JniReachOnlyLiquidLegionsV2EncryptionTest", + srcs = ["JniReachOnlyLiquidLegionsV2EncryptionTest.kt"], + test_class = "org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto.JniReachOnlyLiquidLegionsV2EncryptionTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto:reachonlyliquidlegionsv2encryption", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_kt_jvm_proto", + "@wfa_common_jvm//imports/java/com/google/common/truth", + "@wfa_common_jvm//imports/kotlin/kotlin/test", + ], +) diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2EncryptionTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2EncryptionTest.kt new file mode 100644 index 00000000000..20676ac6984 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2EncryptionTest.kt @@ -0,0 +1,43 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto + +import com.google.common.truth.Truth.assertThat +import kotlin.test.assertFailsWith +import org.junit.Test +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseRequest + +class JniReachOnlyLiquidLegionsV2EncryptionTest { + + @Test + fun `check JNI lib is loaded successfully`() { + // Send an invalid request and check if we can get the error thrown inside JNI. + val e1 = + assertFailsWith(RuntimeException::class) { + JniReachOnlyLiquidLegionsV2Encryption() + .completeReachOnlySetupPhase(CompleteReachOnlySetupPhaseRequest.getDefaultInstance()) + } + assertThat(e1.message).contains("ECGroup::CreateGroup() - Could not create group.") + + // Send an invalid request and check if we can get the error thrown inside JNI. + val e2 = + assertFailsWith(RuntimeException::class) { + JniReachOnlyLiquidLegionsV2Encryption() + .combineElGamalPublicKeys(CombineElGamalPublicKeysRequest.getDefaultInstance()) + } + assertThat(e2.message).contains("Keys cannot be empty") + } +} diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt new file mode 100644 index 00000000000..7eb5c1768be --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt @@ -0,0 +1,305 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto + +import com.google.common.truth.Truth.assertThat +import com.google.protobuf.ByteString +import java.nio.file.Paths +import kotlin.test.assertFailsWith +import kotlin.test.assertEquals +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.wfanet.anysketch.Sketch +import org.wfanet.anysketch.SketchConfig.ValueSpec.Aggregator +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysResponse +import org.wfanet.anysketch.crypto.EncryptSketchRequest +import org.wfanet.anysketch.crypto.EncryptSketchRequest.DestroyedRegisterStrategy.FLAGGED_KEY +import org.wfanet.anysketch.crypto.EncryptSketchResponse +import org.wfanet.anysketch.crypto.SketchEncrypterAdapter +import org.wfanet.estimation.Estimators +import org.wfanet.measurement.common.loadLibrary +import org.wfanet.measurement.duchy.daemon.utils.toAnySketchElGamalPublicKey +import org.wfanet.measurement.duchy.daemon.utils.toCmmsElGamalPublicKey +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters +import org.wfanet.measurement.internal.duchy.protocol.reachonlyliquidlegionsv2.ReachOnlyLiquidLegionsV2EncryptionUtility + +@RunWith(JUnit4::class) +class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { + + private fun createEmptyReachOnlyLiquidLegionsSketch(): Sketch.Builder { + return Sketch.newBuilder() + } + + private fun Sketch.Builder.addRegister(index: Long) { + addRegistersBuilder().also { + it.index = index + } + } + + // Helper function to go through the entire Liquid Legions V2 protocol using the input data. + // The final relative_frequency_distribution map are returned. + private fun goThroughEntireMpcProtocol( + encrypted_sketch: ByteString + ): CompleteReachOnlyExecutionPhaseAtAggregatorResponse { + // Setup phase at Duchy 1 (NON_AGGREGATOR). Duchy 1 receives all the sketches. + val completeReachOnlySetupPhaseRequest1 = completeReachOnlySetupPhaseRequest { + combinedRegisterVector = encrypted_sketch + curveId = CURVE_ID + compositeElGamalPublicKey = CLIENT_EL_GAMAL_KEYS + parallelism = PARALLELISM + } + val completeReachOnlySetupPhaseResponse1 = + CompleteReachOnlySetupPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase( + completeReachOnlySetupPhaseRequest1.toByteArray()) + ) + // Setup phase at Duchy 2 (NON_AGGREGATOR). Duchy 2 does not receive any sketche. + val completeReachOnlySetupPhaseRequest2 = completeReachOnlySetupPhaseRequest { + curveId = CURVE_ID + compositeElGamalPublicKey = CLIENT_EL_GAMAL_KEYS + parallelism = PARALLELISM + } + val completeReachOnlySetupPhaseResponse2 = + CompleteReachOnlySetupPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase( + completeReachOnlySetupPhaseRequest2.toByteArray()) + ) + // Setup phase at Duchy 3 (AGGREGATOR). Aggregator receives the combined register vector and + // the concatenated excessive noise ciphertexts. + val completeReachOnlySetupPhaseRequest3 = completeReachOnlySetupPhaseRequest { + combinedRegisterVector = + completeReachOnlySetupPhaseResponse1.combinedRegisterVector.concat( + completeReachOnlySetupPhaseResponse2.combinedRegisterVector) + curveId = CURVE_ID + compositeElGamalPublicKey = CLIENT_EL_GAMAL_KEYS + serializedExcessiveNoiseCiphertext = + completeReachOnlySetupPhaseResponse1.serializedExcessiveNoiseCiphertext.concat( + completeReachOnlySetupPhaseResponse2.serializedExcessiveNoiseCiphertext) + parallelism = PARALLELISM + } + val completeReachOnlySetupPhaseResponse3 = + CompleteReachOnlySetupPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase( + completeReachOnlySetupPhaseRequest3.toByteArray()) + ) + + // Execution phase at duchy 1 (non-aggregator). + val completeReachOnlyExecutionPhaseRequest1 = completeReachOnlyExecutionPhaseRequest { + combinedRegisterVector = completeReachOnlySetupPhaseResponse3.combinedRegisterVector + localElGamalKeyPair = DUCHY_1_EL_GAMAL_KEYS + curveId = CURVE_ID + serializedExcessiveNoiseCiphertext = + completeReachOnlySetupPhaseResponse3.serializedExcessiveNoiseCiphertext + parallelism = PARALLELISM + } + val completeReachOnlyExecutionPhaseResponse1 = + CompleteReachOnlyExecutionPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhase( + completeReachOnlyExecutionPhaseRequest1.toByteArray() + ) + ) + + // Execution phase at duchy 2 (non-aggregator). + val completeReachOnlyExecutionPhaseRequest2 = completeReachOnlyExecutionPhaseRequest { + combinedRegisterVector = completeReachOnlyExecutionPhaseResponse1.combinedRegisterVector + localElGamalKeyPair = DUCHY_2_EL_GAMAL_KEYS + curveId = CURVE_ID + serializedExcessiveNoiseCiphertext = + completeReachOnlyExecutionPhaseResponse1.serializedExcessiveNoiseCiphertext + parallelism = PARALLELISM + } + val completeReachOnlyExecutionPhaseResponse2 = + CompleteReachOnlyExecutionPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhase( + completeReachOnlyExecutionPhaseRequest2.toByteArray() + ) + ) + + // Execution phase at duchy 3 (aggregator). + val completeReachOnlyExecutionPhaseAtAggregatorRequest = + completeReachOnlyExecutionPhaseAtAggregatorRequest { + combinedRegisterVector = completeReachOnlyExecutionPhaseResponse2.combinedRegisterVector + localElGamalKeyPair = DUCHY_3_EL_GAMAL_KEYS + curveId = CURVE_ID + serializedExcessiveNoiseCiphertext = + completeReachOnlyExecutionPhaseResponse2.serializedExcessiveNoiseCiphertext + parallelism = PARALLELISM + liquidLegionsParameters = liquidLegionsSketchParameters { + decayRate = DECAY_RATE + size = LIQUID_LEGIONS_SIZE + } + vidSamplingIntervalWidth = VID_SAMPLING_INTERVAL_WIDTH + parallelism = PARALLELISM + } + return CompleteReachOnlyExecutionPhaseAtAggregatorResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhaseAtAggregator( + completeReachOnlyExecutionPhaseAtAggregatorRequest.toByteArray() + ) + ) + } + + @Test + fun endToEnd_basicBehavior() { + val rawSketch = + createEmptyReachOnlyLiquidLegionsSketch() + .apply { + addRegister(index = 1L) + addRegister(index = 2L) + addRegister(index = 2L) + addRegister(index = 4L) + addRegister(index = 5L) + } + .build() + val request = + EncryptSketchRequest.newBuilder() + .apply { + sketch = rawSketch + curveId = CURVE_ID + maximumValue = MAX_COUNTER_VALUE + elGamalKeys = CLIENT_EL_GAMAL_KEYS.toAnySketchElGamalPublicKey() + destroyedRegisterStrategy = FLAGGED_KEY + } + .build() + val response = + EncryptSketchResponse.parseFrom(SketchEncrypterAdapter.EncryptSketch(request.toByteArray())) + val encryptedSketch = response.encryptedSketch + val result = goThroughEntireMpcProtocol(encryptedSketch).reach + val expectedResult = Estimators.EstimateCardinalityLiquidLegions( + DECAY_RATE, LIQUID_LEGIONS_SIZE, 4, VID_SAMPLING_INTERVAL_WIDTH.toDouble()) + assertEquals(expectedResult, result) + } + + @Test + fun `completeReachOnlySetupPhase fails with invalid request message`() { + val exception = + assertFailsWith(RuntimeException::class) { + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase("something not a proto".toByteArray()) + } + assertThat(exception).hasMessageThat().contains("Failed to parse") + } + + @Test + fun `completeReachOnlyExecutionPhase fails with invalid request message`() { + val exception = + assertFailsWith(RuntimeException::class) { + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhase( + "something not a proto".toByteArray() + ) + } + assertThat(exception).hasMessageThat().contains("Failed to parse") + } + + @Test + fun `completeReachOnlyExecutionPhaseAtAggregator fails with invalid request message`() { + val exception = + assertFailsWith(RuntimeException::class) { + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhaseAtAggregator( + "something not a proto".toByteArray() + ) + } + assertThat(exception).hasMessageThat().contains("Failed to parse") + } + + companion object { + init { + loadLibrary( + "reach_only_liquid_legions_v2_encryption_utility", + Paths.get("wfa_measurement_system/src/main/swig/protocol/reachonlyliquidlegionsv2") + ) + loadLibrary( + "sketch_encrypter_adapter", + Paths.get("any_sketch_java/src/main/java/org/wfanet/anysketch/crypto") + ) + loadLibrary( + "estimators", + Paths.get("any_sketch_java/src/main/java/org/wfanet/estimation") + ) + } + + private const val DECAY_RATE = 12.0 + private const val LIQUID_LEGIONS_SIZE = 100_000L + private const val MAXIMUM_FREQUENCY = 10 + private const val VID_SAMPLING_INTERVAL_WIDTH = 0.5f + + private const val CURVE_ID = 415L // NID_X9_62_prime256v1 + private const val PARALLELISM = 3 + private const val MAX_COUNTER_VALUE = 10 + + private val COMPLETE_INITIALIZATION_REQUEST = + CompleteReachOnlyInitializationPhaseRequest.newBuilder().apply { curveId = CURVE_ID }.build() + private val DUCHY_1_EL_GAMAL_KEYS = + CompleteReachOnlyInitializationPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyInitializationPhase( + COMPLETE_INITIALIZATION_REQUEST.toByteArray() + ) + ) + .elGamalKeyPair + private val DUCHY_2_EL_GAMAL_KEYS = + CompleteReachOnlyInitializationPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyInitializationPhase( + COMPLETE_INITIALIZATION_REQUEST.toByteArray() + ) + ) + .elGamalKeyPair + private val DUCHY_3_EL_GAMAL_KEYS = + CompleteReachOnlyInitializationPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyInitializationPhase( + COMPLETE_INITIALIZATION_REQUEST.toByteArray() + ) + ) + .elGamalKeyPair + + private val CLIENT_EL_GAMAL_KEYS = + CombineElGamalPublicKeysResponse.parseFrom( + SketchEncrypterAdapter.CombineElGamalPublicKeys( + CombineElGamalPublicKeysRequest.newBuilder() + .apply { + curveId = CURVE_ID + addElGamalKeys(DUCHY_1_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) + addElGamalKeys(DUCHY_2_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) + addElGamalKeys(DUCHY_3_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) + } + .build() + .toByteArray() + ) + ) + .elGamalKeys + .toCmmsElGamalPublicKey() + private val DUCHY_2_3_COMBINED_EL_GAMAL_KEYS = + CombineElGamalPublicKeysResponse.parseFrom( + SketchEncrypterAdapter.CombineElGamalPublicKeys( + CombineElGamalPublicKeysRequest.newBuilder() + .apply { + curveId = CURVE_ID + addElGamalKeys(DUCHY_2_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) + addElGamalKeys(DUCHY_3_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) + } + .build() + .toByteArray() + ) + ) + .elGamalKeys + .toCmmsElGamalPublicKey() + } +} diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/BUILD.bazel index 1ae13f4c696..df6ea8f832d 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/BUILD.bazel @@ -53,6 +53,18 @@ kt_jvm_test( ], ) +kt_jvm_test( + name = "ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesTest", + srcs = ["ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesTest.kt"], + test_class = "org.wfanet.measurement.duchy.db.computation.ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/db/computation", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_kt_jvm_proto", + "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//imports/kotlin/kotlin/test", + ], +) + kt_jvm_test( name = "LiquidLegionsSketchAggregationV2ProtocolEnumStagesDetailsTest", srcs = ["LiquidLegionsSketchAggregationV2ProtocolEnumStagesDetailsTest.kt"], @@ -66,3 +78,17 @@ kt_jvm_test( "@wfa_common_jvm//imports/kotlin/kotlin/test", ], ) + +kt_jvm_test( + name = "ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesDetailsTest", + srcs = ["ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesDetailsTest.kt"], + test_class = "org.wfanet.measurement.duchy.db.computation.ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesDetailsTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/db/computation", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_kt_jvm_proto", + "@wfa_common_jvm//imports/java/com/google/common/truth", + "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", + "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//imports/kotlin/kotlin/test", + ], +) diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationsEnumHelperTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationsEnumHelperTest.kt index 27a94f94330..094cd8cfc65 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationsEnumHelperTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationsEnumHelperTest.kt @@ -18,7 +18,9 @@ import kotlin.test.assertEquals import kotlin.test.assertFails import org.junit.Test import org.wfanet.measurement.internal.duchy.ComputationStage +import org.wfanet.measurement.internal.duchy.computationStage import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 class ComputationsEnumHelperTest { @@ -39,6 +41,24 @@ class ComputationsEnumHelperTest { } } + @Test + fun `reachOnlyLiquidLegionsSketchAggregationV2 round trip conversion should get the same stage`() { + for (stage in ReachOnlyLiquidLegionsSketchAggregationV2.Stage.values()) { + if (stage != ReachOnlyLiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED) { + val computationStage = computationStage { + reachOnlyLiquidLegionsSketchAggregationV2 = stage + } + assertEquals( + computationStage, + ComputationProtocolStages.longValuesToComputationStageEnum( + ComputationProtocolStages.computationStageEnumToLongValues(computationStage) + ), + "protocolEnumToLong and longToProtocolEnum were not inverses for $stage" + ) + } + } + } + @Test fun `longValuesToComputationStageEnum with invalid numbers`() { assertFails { diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesDetailsTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesDetailsTest.kt new file mode 100644 index 00000000000..9f98cfa618d --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesDetailsTest.kt @@ -0,0 +1,64 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.db.computation + +import com.google.common.truth.extensions.proto.ProtoTruth +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.wfanet.measurement.duchy.db.computation.ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.Details +import org.wfanet.measurement.internal.duchy.ComputationStageDetails +import org.wfanet.measurement.internal.duchy.computationStageDetails +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.computationParticipant +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.stageDetails +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.waitSetupPhaseInputsDetails + +@RunWith(JUnit4::class) +class ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesDetailsTest { + + @Test + fun `stage defaults and conversions`() { + for (stage in ReachOnlyLiquidLegionsSketchAggregationV2.Stage.values()) { + val expected = + when (stage) { + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS -> + computationStageDetails { + reachOnlyLiquidLegionsV2 = stageDetails { + waitSetupPhaseInputsDetails = waitSetupPhaseInputsDetails { + externalDuchyLocalBlobId["A"] = 0L + externalDuchyLocalBlobId["B"] = 1L + externalDuchyLocalBlobId["C"] = 2L + } + } + } + else -> ComputationStageDetails.getDefaultInstance() + } + val stageProto = + Details.detailsFor( + stage, + computationDetails { + participant += computationParticipant { duchyId = "A" } + participant += computationParticipant { duchyId = "B" } + participant += computationParticipant { duchyId = "C" } + participant += computationParticipant { duchyId = "D" } + } + ) + ProtoTruth.assertThat(stageProto).isEqualTo(expected) + ProtoTruth.assertThat(Details.parseDetails(stageProto.toByteArray())).isEqualTo(stageProto) + } + } +} diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesTest.kt new file mode 100644 index 00000000000..f93a6705671 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesTest.kt @@ -0,0 +1,103 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.db.computation + +import kotlin.test.assertEquals +import kotlin.test.assertFails +import kotlin.test.assertFalse +import kotlin.test.assertTrue +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 + +@RunWith(JUnit4::class) +class ReachOnlyLiquidLegionsSketchAggregationV2ProtocolEnumStagesTest { + @Test + fun `verify initial stage`() { + assertTrue { + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.validInitialStage( + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE + ) + } + assertFalse { + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.validInitialStage( + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS + ) + } + } + + @Test + fun `enumToLong then longToEnum results in same enum value`() { + for (stage in ReachOnlyLiquidLegionsSketchAggregationV2.Stage.values()) { + if (stage == ReachOnlyLiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED) { + assertFails { + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.enumToLong(stage) + } + } else { + assertEquals( + stage, + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.longToEnum( + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.enumToLong(stage) + ), + "enumToLong and longToEnum were not inverses for $stage" + ) + } + } + } + + @Test + fun `longToEnum with invalid numbers`() { + assertEquals( + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED, + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.longToEnum(-1) + ) + assertEquals( + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED, + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.longToEnum(1000) + ) + } + + @Test + fun `verify transistions`() { + assertTrue { + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.validTransition( + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE + ) + } + + assertFalse { + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.validTransition( + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE + ) + } + + assertFalse { + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.validTransition( + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE + ) + } + + assertFalse { + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.validTransition( + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE + ) + } + } +} diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/BUILD.bazel index 313a211e312..013cd9338f9 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/BUILD.bazel @@ -33,3 +33,19 @@ kt_jvm_test( "@wfa_common_jvm//imports/kotlin/org/mockito/kotlin", ], ) + +kt_jvm_test( + name = "ReachOnlyLiquidLegionsV2StagesTest", + srcs = ["ReachOnlyLiquidLegionsV2StagesTest.kt"], + test_class = "org.wfanet.measurement.duchy.service.internal.computationcontrol.ReachOnlyLiquidLegionsV2StagesTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing", + "//src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol:async_computation_control_service", + "@wfa_common_jvm//imports/java/com/google/common/truth", + "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", + "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//imports/kotlin/kotlin/test", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_common_jvm//imports/kotlin/org/mockito/kotlin", + ], +) diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ReachOnlyLiquidLegionsV2StagesTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ReachOnlyLiquidLegionsV2StagesTest.kt new file mode 100644 index 00000000000..d1c0da60d20 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ReachOnlyLiquidLegionsV2StagesTest.kt @@ -0,0 +1,124 @@ +// Copyright 2020 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.service.internal.computationcontrol + +import com.google.common.truth.Truth.assertThat +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.wfanet.measurement.duchy.db.computation.ReachOnlyLiquidLegionsSketchAggregationV2Protocol +import org.wfanet.measurement.duchy.service.internal.computations.newEmptyOutputBlobMetadata +import org.wfanet.measurement.duchy.toProtocolStage +import org.wfanet.measurement.internal.duchy.computationStage +import org.wfanet.measurement.internal.duchy.computationStageDetails +import org.wfanet.measurement.internal.duchy.computationToken +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.stageDetails +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.waitSetupPhaseInputsDetails + +@RunWith(JUnit4::class) +class ReachOnlyLiquidLegionsV2StagesTest { + private val stages = ReachOnlyLiquidLegionsV2Stages() + + @Test + fun `next stages are valid`() { + fun assertContextThrowsErrorWhenCallingNextStage(stage: Stage) { + if (stage == Stage.UNRECOGNIZED) { + return + } + val token = computationToken { + computationStage = computationStage { reachOnlyLiquidLegionsSketchAggregationV2 = stage } + blobs += newEmptyOutputBlobMetadata(1L) + } + val ex = assertFailsWith { stages.outputBlob(token, "Buck") } + } + + for (stage in Stage.values()) { + when (stage) { + Stage.WAIT_SETUP_PHASE_INPUTS, + Stage.WAIT_EXECUTION_PHASE_INPUTS, -> { + val next = + stages.nextStage(stage.toProtocolStage()).reachOnlyLiquidLegionsSketchAggregationV2 + assertTrue("$next is not a valid successor of $stage") { + ReachOnlyLiquidLegionsSketchAggregationV2Protocol.EnumStages.validTransition( + stage, + next + ) + } + } + else -> assertContextThrowsErrorWhenCallingNextStage(stage) + } + } + } + + @Test + fun `output blob for wait sketches`() { + val token = computationToken { + computationStage = Stage.WAIT_SETUP_PHASE_INPUTS.toProtocolStage() + blobs += newEmptyOutputBlobMetadata(1L) + blobs += newEmptyOutputBlobMetadata(21L) + + stageSpecificDetails = computationStageDetails { + reachOnlyLiquidLegionsV2 = stageDetails { + waitSetupPhaseInputsDetails = waitSetupPhaseInputsDetails { + externalDuchyLocalBlobId["alice"] = 21L + externalDuchyLocalBlobId["bob"] = 1L + } + } + } + } + + assertThat(stages.outputBlob(token, "alice")).isEqualTo(newEmptyOutputBlobMetadata(21L)) + assertThat(stages.outputBlob(token, "bob")).isEqualTo(newEmptyOutputBlobMetadata(1L)) + assertFailsWith { stages.outputBlob(token, "unknown-sender") } + } + + @Test + fun `output blob for execution phase inputs`() { + val token = computationToken { + computationStage = Stage.WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + blobs += newEmptyOutputBlobMetadata(1L) + } + + assertThat(stages.outputBlob(token, "Buck")).isEqualTo(newEmptyOutputBlobMetadata(1L)) + } + + @Test + fun `output blob for unsupported stages throws`() { + fun assertContextThrowsErrorWhenGettingBlob(stage: Stage) { + if (stage == Stage.UNRECOGNIZED) { + return + } + + val token = computationToken { + computationStage = computationStage { reachOnlyLiquidLegionsSketchAggregationV2 = stage } + blobs += newEmptyOutputBlobMetadata(1L) + } + + assertFailsWith { stages.outputBlob(token, "Buck") } + } + + for (stage in Stage.values()) { + when (stage) { + // Skip all the supported stages, they are tested elsewhere. + Stage.WAIT_SETUP_PHASE_INPUTS, + Stage.WAIT_EXECUTION_PHASE_INPUTS -> {} + else -> assertContextThrowsErrorWhenGettingBlob(stage) + } + } + } +} diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/ComputationControlServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/ComputationControlServiceTest.kt index 5d146674e04..2856c900066 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/ComputationControlServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/ComputationControlServiceTest.kt @@ -46,13 +46,16 @@ import org.wfanet.measurement.internal.duchy.AdvanceComputationResponse as Async import org.wfanet.measurement.internal.duchy.AsyncComputationControlGrpcKt.AsyncComputationControlCoroutineImplBase import org.wfanet.measurement.internal.duchy.AsyncComputationControlGrpcKt.AsyncComputationControlCoroutineStub import org.wfanet.measurement.internal.duchy.ComputationBlobDependency +import org.wfanet.measurement.internal.duchy.advanceComputationRequest as asyncAdvanceComputationRequest import org.wfanet.measurement.internal.duchy.computationStageBlobMetadata import org.wfanet.measurement.internal.duchy.getOutputBlobMetadataRequest import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 import org.wfanet.measurement.storage.filesystem.FileSystemStorageClient import org.wfanet.measurement.storage.testing.BlobSubject.Companion.assertThat import org.wfanet.measurement.system.v1alpha.AdvanceComputationRequest import org.wfanet.measurement.system.v1alpha.LiquidLegionsV2 +import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2 private const val RUNNING_DUCHY_NAME = "Alsace" private const val BAVARIA = "Bavaria" @@ -315,6 +318,72 @@ class ComputationControlServiceTest { } } } + + @Test + fun `reach only liquid legions v2 send setup inputs`() = runBlocking { + val id = "311311" + val blobKey = "$id/WAIT_SETUP_PHASE_INPUTS/$BLOB_ID" + val carinthiaHeader = + advanceComputationHeader(ReachOnlyLiquidLegionsV2.Description.SETUP_PHASE_INPUT, id) + withSender(carinthia) { advanceComputation(carinthiaHeader.withContent("contents")) } + + verifyProtoArgument( + mockAsyncControlService, + AsyncComputationControlCoroutineImplBase::getOutputBlobMetadata + ) + .isEqualTo( + getOutputBlobMetadataRequest { + globalComputationId = id + dataOrigin = CARINTHIA + } + ) + assertThat(advanceAsyncComputationRequests) + .containsExactly( + asyncAdvanceComputationRequest { + globalComputationId = id + computationStage = + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS + .toProtocolStage() + blobId = BLOB_ID + blobPath = blobKey + } + ) + val data = assertNotNull(computationStore.get(blobKey)) + assertThat(data).contentEqualTo(ByteString.copyFromUtf8("contents")) + } + + @Test + fun `reach only liquid legions v2 send execution phase inputs`() = runBlocking { + val id = "444444" + val blobKey = "$id/WAIT_EXECUTION_PHASE_INPUTS/$BLOB_ID" + val carinthiaHeader = + advanceComputationHeader(ReachOnlyLiquidLegionsV2.Description.EXECUTION_PHASE_INPUT, id) + withSender(carinthia) { advanceComputation(carinthiaHeader.withContent("contents")) } + + verifyProtoArgument( + mockAsyncControlService, + AsyncComputationControlCoroutineImplBase::getOutputBlobMetadata + ) + .isEqualTo( + getOutputBlobMetadataRequest { + globalComputationId = id + dataOrigin = CARINTHIA + } + ) + assertThat(advanceAsyncComputationRequests) + .containsExactly( + asyncAdvanceComputationRequest { + globalComputationId = id + computationStage = + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_INPUTS + .toProtocolStage() + blobId = BLOB_ID + blobPath = blobKey + } + ) + val data = assertNotNull(computationStore.get(blobKey)) + assertThat(data).contentEqualTo(ByteString.copyFromUtf8("contents")) + } } private fun AdvanceComputationRequest.Header.withContent( diff --git a/src/test/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationParticipantsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationParticipantsServiceTest.kt index 96699d8952d..b655bdf1316 100644 --- a/src/test/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationParticipantsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationParticipantsServiceTest.kt @@ -49,6 +49,7 @@ import org.wfanet.measurement.internal.kingdom.certificate as internalCertificat import org.wfanet.measurement.internal.kingdom.copy import org.wfanet.measurement.internal.kingdom.duchyMeasurementLogEntry import org.wfanet.measurement.internal.kingdom.measurementLogEntry +import org.wfanet.measurement.internal.kingdom.setParticipantRequisitionParamsRequest as internalSetParticipantRequisitionParamsRequest import org.wfanet.measurement.system.v1alpha.ComputationParticipant import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.RequisitionParamsKt.liquidLegionsV2 import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.requisitionParams @@ -56,6 +57,7 @@ import org.wfanet.measurement.system.v1alpha.ConfirmComputationParticipantReques import org.wfanet.measurement.system.v1alpha.FailComputationParticipantRequest import org.wfanet.measurement.system.v1alpha.SetParticipantRequisitionParamsRequest import org.wfanet.measurement.system.v1alpha.computationParticipant +import org.wfanet.measurement.system.v1alpha.setParticipantRequisitionParamsRequest private const val DUCHY_ID: String = "some-duchy-id" private const val MILL_ID: String = "some-mill-id" @@ -89,6 +91,14 @@ private val INTERNAL_COMPUTATION_PARTICIPANT = nanos = 456 } apiVersion = PUBLIC_API_VERSION + details = + InternalComputationParticipantKt.details { + liquidLegionsV2 = + InternalComputationParticipantKt.liquidLegionsV2Details { + elGamalPublicKey = DUCHY_ELGAMAL_KEY + elGamalPublicKeySignature = DUCHY_ELGAMAL_KEY_SIGNATURE + } + } } .build() @@ -168,7 +178,7 @@ class ComputationParticipantsServiceTest { ) @Test - fun `SetParticipantRequisitionParams successfully`() = runBlocking { + fun `SetParticipantRequisitionParams for llv2 successfully`() = runBlocking { whenever(internalComputationParticipantsServiceMock.setParticipantRequisitionParams(any())) .thenReturn(INTERNAL_COMPUTATION_PARTICIPANT_WITH_PARAMS) @@ -219,6 +229,66 @@ class ComputationParticipantsServiceTest { ) } + @Test + fun `SetParticipantRequisitionParams for rollv2 successfully`() = runBlocking { + val internalComputationParticipantWithRoLlv2Params = + INTERNAL_COMPUTATION_PARTICIPANT_WITH_PARAMS.copy { + details = + InternalComputationParticipantKt.details { + reachOnlyLiquidLegionsV2 = + InternalComputationParticipantKt.liquidLegionsV2Details { + elGamalPublicKey = DUCHY_ELGAMAL_KEY + elGamalPublicKeySignature = DUCHY_ELGAMAL_KEY_SIGNATURE + } + } + } + + whenever(internalComputationParticipantsServiceMock.setParticipantRequisitionParams(any())) + .thenReturn(internalComputationParticipantWithRoLlv2Params) + + val request = setParticipantRequisitionParamsRequest { + name = SYSTEM_COMPUTATION_PARTICIPANT_NAME + requisitionParams = requisitionParams { + duchyCertificate = DUCHY_CERTIFICATE_PUBLIC_API_NAME + reachOnlyLiquidLegionsV2 = liquidLegionsV2 { + elGamalPublicKey = DUCHY_ELGAMAL_KEY + elGamalPublicKeySignature = DUCHY_ELGAMAL_KEY_SIGNATURE + } + } + } + val response: ComputationParticipant = service.setParticipantRequisitionParams(request) + + assertThat(response) + .isEqualTo( + computationParticipant { + name = SYSTEM_COMPUTATION_PARTICIPANT_NAME + state = ComputationParticipant.State.REQUISITION_PARAMS_SET + updateTime = INTERNAL_COMPUTATION_PARTICIPANT.updateTime + requisitionParams = requisitionParams { + duchyCertificate = DUCHY_CERTIFICATE_PUBLIC_API_NAME + duchyCertificateDer = DUCHY_CERTIFICATE_DER + reachOnlyLiquidLegionsV2 = liquidLegionsV2 { + elGamalPublicKey = DUCHY_ELGAMAL_KEY + elGamalPublicKeySignature = DUCHY_ELGAMAL_KEY_SIGNATURE + } + } + } + ) + verifyProtoArgument( + internalComputationParticipantsServiceMock, + InternalComputationParticipantsCoroutineService::setParticipantRequisitionParams + ) + .isEqualTo( + internalSetParticipantRequisitionParamsRequest { + externalComputationId = EXTERNAL_COMPUTATION_ID + externalDuchyId = DUCHY_ID + externalDuchyCertificateId = EXTERNAL_DUCHY_CERTIFICATE_ID + reachOnlyLiquidLegionsV2 = + internalComputationParticipantWithRoLlv2Params.details.reachOnlyLiquidLegionsV2 + } + ) + } + @Test fun `FailComputationParticipant successfully`() = runBlocking { whenever(internalComputationParticipantsServiceMock.failComputationParticipant(any())) diff --git a/src/test/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationsServiceTest.kt index 45b0c81a3ec..ad4271dd421 100644 --- a/src/test/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ComputationsServiceTest.kt @@ -16,6 +16,7 @@ package org.wfanet.measurement.kingdom.service.system.v1alpha import com.google.common.truth.extensions.proto.ProtoTruth.assertThat import com.google.protobuf.ByteString +import com.google.protobuf.timestamp import com.google.protobuf.util.Timestamps import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.flow.take @@ -38,6 +39,7 @@ import org.wfanet.measurement.common.identity.externalIdToApiId import org.wfanet.measurement.common.identity.testing.DuchyIdSetter import org.wfanet.measurement.common.testing.verifyProtoArgument import org.wfanet.measurement.internal.kingdom.ComputationParticipant as InternalComputationParticipant +import org.wfanet.measurement.internal.kingdom.ComputationParticipantKt as InternalComputationParticipantKt import org.wfanet.measurement.internal.kingdom.DuchyProtocolConfigKt import org.wfanet.measurement.internal.kingdom.DuchyProtocolConfigKt.LiquidLegionsV2Kt.mpcNoise import org.wfanet.measurement.internal.kingdom.GetMeasurementByComputationIdRequest @@ -54,8 +56,10 @@ import org.wfanet.measurement.internal.kingdom.StreamMeasurementsRequest import org.wfanet.measurement.internal.kingdom.StreamMeasurementsRequestKt import org.wfanet.measurement.internal.kingdom.StreamMeasurementsRequestKt.filter import org.wfanet.measurement.internal.kingdom.computationKey -import org.wfanet.measurement.internal.kingdom.differentialPrivacyParams +import org.wfanet.measurement.internal.kingdom.copy +import org.wfanet.measurement.internal.kingdom.differentialPrivacyParams as internalDifferentialPrivacyParams import org.wfanet.measurement.internal.kingdom.duchyProtocolConfig +import org.wfanet.measurement.internal.kingdom.getMeasurementByComputationIdRequest import org.wfanet.measurement.internal.kingdom.liquidLegionsSketchParams import org.wfanet.measurement.internal.kingdom.measurement as internalMeasurement import org.wfanet.measurement.internal.kingdom.protocolConfig @@ -63,12 +67,23 @@ import org.wfanet.measurement.internal.kingdom.streamMeasurementsRequest import org.wfanet.measurement.system.v1alpha.Computation import org.wfanet.measurement.system.v1alpha.Computation.MpcProtocolConfig.NoiseMechanism import org.wfanet.measurement.system.v1alpha.ComputationKey +import org.wfanet.measurement.system.v1alpha.ComputationKt.MpcProtocolConfigKt +import org.wfanet.measurement.system.v1alpha.ComputationKt.mpcProtocolConfig import org.wfanet.measurement.system.v1alpha.ComputationParticipant +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.RequisitionParamsKt +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.requisitionParams import org.wfanet.measurement.system.v1alpha.GetComputationRequest import org.wfanet.measurement.system.v1alpha.Requisition import org.wfanet.measurement.system.v1alpha.SetComputationResultRequest import org.wfanet.measurement.system.v1alpha.StreamActiveComputationsRequest import org.wfanet.measurement.system.v1alpha.StreamActiveComputationsResponse +import org.wfanet.measurement.system.v1alpha.computation +import org.wfanet.measurement.system.v1alpha.computationParticipant +import org.wfanet.measurement.system.v1alpha.differentialPrivacyParams +import org.wfanet.measurement.system.v1alpha.getComputationRequest +import org.wfanet.measurement.system.v1alpha.requisition +import org.wfanet.measurement.system.v1alpha.stageAttempt private const val DUCHY_ID: String = "some-duchy-id" private const val MILL_ID: String = "some-mill-id" @@ -169,6 +184,18 @@ private val INTERNAL_COMPUTATION_PARTICIPANT = } .build() +private val INTERNAL_RO_LLV2_COMPUTATION_PARTICIPANT = + INTERNAL_COMPUTATION_PARTICIPANT.copy { + details = + InternalComputationParticipantKt.details { + reachOnlyLiquidLegionsV2 = + InternalComputationParticipantKt.liquidLegionsV2Details { + elGamalPublicKey = DUCHY_ELGAMAL_KEY + elGamalPublicKeySignature = DUCHY_ELGAMAL_KEY_SIGNATURE + } + } + } + private val INTERNAL_MEASUREMENT = internalMeasurement { externalComputationId = EXTERNAL_COMPUTATION_ID state = InternalMeasurement.State.FAILED @@ -179,11 +206,11 @@ private val INTERNAL_MEASUREMENT = internalMeasurement { liquidLegionsV2 = DuchyProtocolConfigKt.liquidLegionsV2 { mpcNoise = mpcNoise { - blindedHistogramNoise = differentialPrivacyParams { + blindedHistogramNoise = internalDifferentialPrivacyParams { epsilon = 1.1 delta = 2.1 } - noiseForPublisherNoise = differentialPrivacyParams { + noiseForPublisherNoise = internalDifferentialPrivacyParams { epsilon = 3.1 delta = 4.1 } @@ -213,6 +240,43 @@ private val INTERNAL_MEASUREMENT = internalMeasurement { requisitions += INTERNAL_REQUISITION } +private val INTERNAL_RO_LLV2_MEASUREMENT = + INTERNAL_MEASUREMENT.copy { + details = details { + apiVersion = PUBLIC_API_VERSION + measurementSpec = MEASUREMENT_SPEC + duchyProtocolConfig = duchyProtocolConfig { + reachOnlyLiquidLegionsV2 = + DuchyProtocolConfigKt.liquidLegionsV2 { + mpcNoise = mpcNoise { + blindedHistogramNoise = internalDifferentialPrivacyParams { + epsilon = 1.1 + delta = 2.1 + } + noiseForPublisherNoise = internalDifferentialPrivacyParams { + epsilon = 3.1 + delta = 4.1 + } + } + } + } + protocolConfig = protocolConfig { + reachOnlyLiquidLegionsV2 = + ProtocolConfigKt.liquidLegionsV2 { + sketchParams = liquidLegionsSketchParams { + decayRate = 10.0 + maxSize = 100 + samplingIndicatorSize = 1000 + } + ellipticCurveId = 123 + noiseMechanism = InternalNoiseMechanism.GEOMETRIC + } + } + } + computationParticipants.clear() + computationParticipants += INTERNAL_RO_LLV2_COMPUTATION_PARTICIPANT + } + @RunWith(JUnit4::class) class ComputationsServiceTest { @get:Rule val duchyIdSetter = DuchyIdSetter(DUCHY_ID) @@ -231,7 +295,7 @@ class ComputationsServiceTest { ) @Test - fun `get computation successfully`() = runBlocking { + fun `get llv2 computation successfully`() = runBlocking { whenever(internalMeasurementsServiceMock.getMeasurementByComputationId(any())) .thenReturn(INTERNAL_MEASUREMENT) @@ -328,6 +392,100 @@ class ComputationsServiceTest { ) } + @Test + fun `get rollv2 computation successfully`() = runBlocking { + whenever(internalMeasurementsServiceMock.getMeasurementByComputationId(any())) + .thenReturn(INTERNAL_RO_LLV2_MEASUREMENT) + + val request = getComputationRequest { name = SYSTEM_COMPUTATION_NAME } + + val response = service.getComputation(request) + + assertThat(response) + .isEqualTo( + computation { + name = SYSTEM_COMPUTATION_NAME + publicApiVersion = PUBLIC_API_VERSION + measurementSpec = MEASUREMENT_SPEC + state = Computation.State.FAILED + aggregatorCertificate = DUCHY_CERTIFICATE_PUBLIC_API_NAME + encryptedResult = ENCRYPTED_RESULT + mpcProtocolConfig = mpcProtocolConfig { + reachOnlyLiquidLegionsV2 = + MpcProtocolConfigKt.liquidLegionsV2 { + sketchParams = + MpcProtocolConfigKt.LiquidLegionsV2Kt.liquidLegionsSketchParams { + decayRate = 10.0 + maxSize = 100 + } + mpcNoise = + MpcProtocolConfigKt.LiquidLegionsV2Kt.mpcNoise { + blindedHistogramNoise = differentialPrivacyParams { + epsilon = 1.1 + delta = 2.1 + } + publisherNoise = differentialPrivacyParams { + epsilon = 3.1 + delta = 4.1 + } + } + ellipticCurveId = 123 + noiseMechanism = NoiseMechanism.GEOMETRIC + } + } + requisitions += requisition { + name = SYSTEM_REQUISITION_NAME + state = Requisition.State.FULFILLED + requisitionSpecHash = ENCRYPTED_REQUISITION_SPEC_HASH.bytes + nonceHash = NONCE_HASH.bytes + fulfillingComputationParticipant = SYSTEM_COMPUTATION_PARTICIPATE_NAME + nonce = NONCE + } + computationParticipants += computationParticipant { + name = SYSTEM_COMPUTATION_PARTICIPATE_NAME + state = ComputationParticipant.State.FAILED + updateTime = timestamp { + seconds = 123 + nanos = 456 + } + requisitionParams = requisitionParams { + duchyCertificate = DUCHY_CERTIFICATE_PUBLIC_API_NAME + reachOnlyLiquidLegionsV2 = + RequisitionParamsKt.liquidLegionsV2 { + elGamalPublicKey = DUCHY_ELGAMAL_KEY + elGamalPublicKeySignature = DUCHY_ELGAMAL_KEY_SIGNATURE + } + } + failure = + ComputationParticipantKt.failure { + participantChildReferenceId = MILL_ID + errorMessage = DUCHY_ERROR_MESSAGE + errorTime = timestamp { + seconds = 1001 + nanos = 2002 + } + stageAttempt = stageAttempt { + stage = STAGE_ATTEMPT_STAGE + stageName = STAGE_ATTEMPT_STAGE_NAME + attemptNumber = STAGE_ATTEMPT_ATTEMPT_NUMBER + stageStartTime = timestamp { + seconds = 100 + nanos = 200 + } + } + } + } + } + ) + verifyProtoArgument( + internalMeasurementsServiceMock, + InternalMeasurementsCoroutineService::getMeasurementByComputationId + ) + .isEqualTo( + getMeasurementByComputationIdRequest { externalComputationId = EXTERNAL_COMPUTATION_ID } + ) + } + @Test fun `stream active computations successfully`() = runBlocking { var calls = 0L From b0ba12a9fe26135bdafabd37256ce75a8dc1a0f5 Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Sat, 5 Aug 2023 01:54:37 -0400 Subject: [PATCH 11/15] Fix build and test issues. --- .../daemon/herald/LiquidLegionsV2Starter.kt | 2 +- .../liquidlegionsv2/LiquidLegionsV2Mill.kt | 4 +- .../ReachOnlyLiquidLegionsV2Mill.kt | 16 +- ...liquid_legions_sketch_aggregation_v2.proto | 3 +- .../duchy/daemon/herald/HeraldTest.kt | 4 +- .../LiquidLegionsV2MillTest.kt | 2 +- .../ReachOnlyLiquidLegionsV2MillTest.kt | 278 ++++++++++++------ 7 files changed, 212 insertions(+), 97 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt index 00ee53c174b..aef0566817d 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt @@ -298,7 +298,7 @@ object LiquidLegionsV2Starter { return parameters { maximumFrequency = llv2Config.maximumFrequency - liquidLegionsSketch = liquidLegionsSketchParameters { + sketchParameters = liquidLegionsSketchParameters { decayRate = llv2Config.sketchParams.decayRate size = llv2Config.sketchParams.maxSize } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2Mill.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2Mill.kt index fc51824e5a7..488301d7e4c 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2Mill.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2Mill.kt @@ -673,8 +673,8 @@ class LiquidLegionsV2Mill( flagCountTuples = readAndCombineAllInputBlobs(token, 1) maximumFrequency = maximumRequestedFrequency liquidLegionsParameters = liquidLegionsSketchParameters { - decayRate = llv2Parameters.liquidLegionsSketch.decayRate - size = llv2Parameters.liquidLegionsSketch.size + decayRate = llv2Parameters.sketchParameters.decayRate + size = llv2Parameters.sketchParameters.size } vidSamplingIntervalWidth = measurementSpec.vidSamplingInterval.width if (llv2Parameters.noise.hasReachNoiseConfig()) { diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt index 63e16ab2be5..c130f179660 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt @@ -464,7 +464,8 @@ class ReachOnlyLiquidLegionsV2Mill( .readAllRequisitionBlobs(token, duchyId) .concat(readAndCombineAllInputBlobsSetupPhaseAtAggregator(token, workerStubs.size)) .toCompleteReachOnlySetupPhaseAtAggregatorRequest(rollv2Details, token.requisitionsCount) - val cryptoResult: CompleteReachOnlySetupPhaseResponse = cryptoWorker.completeReachOnlySetupPhaseAtAggregator(request) + val cryptoResult: CompleteReachOnlySetupPhaseResponse = + cryptoWorker.completeReachOnlySetupPhaseAtAggregator(request) logStageDurationMetric( token, CRYPTO_LIB_CPU_DURATION, @@ -539,7 +540,8 @@ class ReachOnlyLiquidLegionsV2Mill( val measurementSpec = MeasurementSpec.parseFrom(token.computationDetails.kingdomComputation.measurementSpec) val inputBlob = readAndCombineAllInputBlobs(token, 1) - require(inputBlob.size() < kBytesPerCipherText) { "Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size ${inputBlob.size()} which is less than ($kBytesPerCipherText)." } + require(inputBlob.size() >= kBytesPerCipherText) { "Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size ${inputBlob.size()} which is less than ($kBytesPerCipherText)." } + var reach = 0L val (bytes, nextToken) = existingOutputOr(token) { val request = completeReachOnlyExecutionPhaseAtAggregatorRequest { @@ -554,8 +556,8 @@ class ReachOnlyLiquidLegionsV2Mill( } } liquidLegionsParameters = liquidLegionsSketchParameters { - decayRate = rollv2Parameters.reachOnlyLiquidLegionsSketch.decayRate - size = rollv2Parameters.reachOnlyLiquidLegionsSketch.size + decayRate = rollv2Parameters.sketchParameters.decayRate + size = rollv2Parameters.sketchParameters.size } vidSamplingIntervalWidth = measurementSpec.vidSamplingInterval.width if (noiseConfig.hasReachNoiseConfig()) { @@ -582,10 +584,10 @@ class ReachOnlyLiquidLegionsV2Mill( cryptoResult.elapsedCpuTimeMillis, executionPhaseCryptoCpuTimeDurationHistogram ) + reach = cryptoResult.reach cryptoResult.toByteString() } - val reach = CompleteReachOnlyExecutionPhaseAtAggregatorResponse.parseFrom(bytes.flatten()).reach sendResultToKingdom(token, ReachResult(reach)) return completeComputation(nextToken, CompletedReason.SUCCEEDED) } @@ -596,7 +598,7 @@ class ReachOnlyLiquidLegionsV2Mill( val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 require(NON_AGGREGATOR == rollv2Details.role) { "invalid role for this function." } val inputBlob = readAndCombineAllInputBlobs(token, 1) - require(inputBlob.size() < kBytesPerCipherText) { "Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size ${inputBlob.size()} which is less than ($kBytesPerCipherText)." } + require(inputBlob.size() >= kBytesPerCipherText) { "Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size ${inputBlob.size()} which is less than ($kBytesPerCipherText)." } val (bytes, nextToken) = existingOutputOr(token) { val cryptoResult: CompleteReachOnlyExecutionPhaseResponse = @@ -750,7 +752,7 @@ class ReachOnlyLiquidLegionsV2Mill( var combinedRegisterVector = ByteString.EMPTY var combinedNoiseCiphertext = ByteString.EMPTY for (str in blobMap.values) { - require(str.size() < kBytesPerCipherText) { "Invalid input blob size. Input blob ${str.toStringUtf8()} has size ${str.size()} which is less than ($kBytesPerCipherText)." } + require(str.size() >= kBytesPerCipherText) { "Invalid input blob size. Input blob ${str.toStringUtf8()} has size ${str.size()} which is less than ($kBytesPerCipherText)." } combinedRegisterVector = combinedRegisterVector.concat(str.substring(0, str.size() - kBytesPerCipherText)) combinedNoiseCiphertext = combinedNoiseCiphertext.concat(str.substring(str.size() - kBytesPerCipherText, str.size())) } diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto index bc49b00bd17..d99d9c43a38 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto @@ -119,8 +119,7 @@ message LiquidLegionsSketchAggregationV2 { // The maximum frequency to reveal in the histogram. int32 maximum_frequency = 1; // Parameters used for liquidLegions sketch creation and estimation. - // TODO(@ple): rename to sketch_parameters. - LiquidLegionsSketchParameters liquid_legions_sketch = 2; + LiquidLegionsSketchParameters sketch_parameters = 2; // Noise parameters selected for the LiquidLegionV2 MPC protocol. LiquidLegionsV2NoiseConfig noise = 3; // ID of the OpenSSL built-in elliptic curve. For example, 415 for the diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt index 055b6c572e6..9ff4d00a892 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt @@ -465,7 +465,7 @@ class HeraldTest { parameters = LiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters { maximumFrequency = 10 - liquidLegionsSketch = liquidLegionsSketchParameters { + sketchParameters = liquidLegionsSketchParameters { decayRate = 12.0 size = 100_000L } @@ -567,7 +567,7 @@ class HeraldTest { parameters = LiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters { maximumFrequency = 10 - liquidLegionsSketch = liquidLegionsSketchParameters { + sketchParameters = liquidLegionsSketchParameters { decayRate = 12.0 size = 100_000L } diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt index 5d2deba3548..a2e0b097296 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt @@ -256,7 +256,7 @@ private val LLV2_PARAMETERS = Parameters.newBuilder() .apply { maximumFrequency = MAX_FREQUENCY - liquidLegionsSketchBuilder.apply { + sketchParametersBuilder.apply { decayRate = DECAY_RATE size = SKETCH_SIZE } diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt index c8065eec7a1..066f0f4ebc0 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt @@ -232,7 +232,7 @@ private val TEST_NOISE_CONFIG = private val ROLLV2_PARAMETERS = Parameters.newBuilder() .apply { - reachOnlyLiquidLegionsSketchBuilder.apply { + sketchParametersBuilder.apply { decayRate = DECAY_RATE size = SKETCH_SIZE } @@ -1066,6 +1066,64 @@ class ReachOnlyLiquidLegionsV2MillTest { ) } + @Test + fun `setup phase at non-aggregator using cached result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = SETUP_PHASE.toProtocolStage() + ) + .build() + val requisitionBlobContext = + RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) + requisitionStore.writeString(requisitionBlobContext, "local_requisition") + val cachedBlobContext = + ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) + computationStore.writeString(cachedBlobContext, "cached result") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3), + blobs = listOf(cachedBlobContext.toMetadata(ComputationBlobDependency.OUTPUT)) + ) + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + addBlobsBuilder().apply { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0L + path = cachedBlobContext.blobKey + } + addBlobsBuilder().apply { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1L + } + version = 2 // claimTask + transitionStage + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS + addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + .build() + ) + + assertThat(computationControlRequests) + .containsExactlyElementsIn( + buildAdvanceComputationRequests(GLOBAL_ID, SETUP_PHASE_INPUT, "cached result") + ) + .inOrder() + } + @Test fun `setup phase at non-aggregator using calculated result`() = runBlocking { // Stage 0. preparing the storage and set up mock @@ -1165,6 +1223,64 @@ class ReachOnlyLiquidLegionsV2MillTest { ) } + @Test + fun `setup phase at aggregator using cached result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = SETUP_PHASE.toProtocolStage() + ) + .build() + val requisitionBlobContext = + RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) + requisitionStore.writeString(requisitionBlobContext, "local_requisition") + val cachedBlobContext = + ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) + computationStore.writeString(cachedBlobContext, "cached result") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3), + blobs = listOf(cachedBlobContext.toMetadata(ComputationBlobDependency.OUTPUT)) + ) + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + addBlobsBuilder().apply { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0L + path = cachedBlobContext.blobKey + } + addBlobsBuilder().apply { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1L + } + version = 2 // claimTask + transitionStage + computationDetails = AGGREGATOR_COMPUTATION_DETAILS + addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + .build() + ) + + assertThat(computationControlRequests) + .containsExactlyElementsIn( + buildAdvanceComputationRequests(GLOBAL_ID, EXECUTION_PHASE_INPUT, "cached result") + ) + .inOrder() + } + @Test fun `setup phase at aggregator using calculated result`() = runBlocking { // Stage 0. preparing the storage and set up mock @@ -1343,7 +1459,7 @@ class ReachOnlyLiquidLegionsV2MillTest { failureBuilder.apply { participantChildReferenceId = MILL_ID errorMessage = - "PERMANENT error: java.lang.IllegalStateException: Invalid input blob size. Input" + + "PERMANENT error: java.lang.IllegalArgumentException: Invalid input blob size. Input" + " blob duchy_2_sketch_ has size 15 which is less than (66)." stageAttemptBuilder.apply { stage = SETUP_PHASE.number @@ -1356,83 +1472,6 @@ class ReachOnlyLiquidLegionsV2MillTest { ) } - - - @Test - fun `execution phase at non-aggregator, failed due to invalid input blob size`() = runBlocking { - // Stage 0. preparing the storage and set up mock - val partialToken = - FakeComputationsDatabase.newPartialToken( - localId = LOCAL_ID, - stage = EXECUTION_PHASE.toProtocolStage() - ) - .build() - val inputBlobContext = - ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) - val calculatedBlobContext = - ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) - computationStore.writeString(inputBlobContext, "data") - fakeComputationDb.addComputation( - partialToken.localComputationId, - partialToken.computationStage, - computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, - blobs = - listOf( - inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), - newEmptyOutputBlobMetadata(calculatedBlobContext.blobId) - ), - requisitions = REQUISITIONS - ) - - // Stage 1. Process the above computation - nonAggregatorMill.pollAndProcessNextComputation() - - // Stage 2. Check the status of the computation - assertThat(fakeComputationDb[LOCAL_ID]!!) - .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = COMPLETE.toProtocolStage() - version = 2 // claimTask + transitionStage - computationDetails = - NON_AGGREGATOR_COMPUTATION_DETAILS - .toBuilder() - .apply { - endingState = CompletedReason.FAILED - } - .build() - addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) - } - .build() - ) - - verifyProtoArgument( - mockComputationParticipants, - SystemComputationParticipantsCoroutineImplBase::failComputationParticipant - ) - .comparingExpectedFieldsOnly() - .isEqualTo( - FailComputationParticipantRequest.newBuilder() - .apply { - name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() - failureBuilder.apply { - participantChildReferenceId = MILL_ID - errorMessage = - "PERMANENT error: Invalid input blob size. Input blob data has size 4 which is less than (66)." - stageAttemptBuilder.apply { - stage = EXECUTION_PHASE.number - stageName = EXECUTION_PHASE.name - attemptNumber = 1 - } - } - } - .build() - ) - } - @Test fun `execution phase at non-aggregator using cached result`() = runBlocking { // Stage 0. preparing the storage and set up mock @@ -1566,7 +1605,7 @@ class ReachOnlyLiquidLegionsV2MillTest { } @Test - fun `execution phase at aggregator, failed due to invalid input blob size`() = runBlocking { + fun `execution phase at non-aggregator, failed due to invalid input blob size`() = runBlocking { // Stage 0. preparing the storage and set up mock val partialToken = FakeComputationsDatabase.newPartialToken( @@ -1582,7 +1621,7 @@ class ReachOnlyLiquidLegionsV2MillTest { fakeComputationDb.addComputation( partialToken.localComputationId, partialToken.computationStage, - computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf( inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), @@ -1605,7 +1644,7 @@ class ReachOnlyLiquidLegionsV2MillTest { computationStage = COMPLETE.toProtocolStage() version = 2 // claimTask + transitionStage computationDetails = - AGGREGATOR_COMPUTATION_DETAILS + NON_AGGREGATOR_COMPUTATION_DETAILS .toBuilder() .apply { endingState = CompletedReason.FAILED @@ -1805,6 +1844,81 @@ class ReachOnlyLiquidLegionsV2MillTest { } ) } + + @Test + fun `execution phase at aggregator, failed due to invalid input blob size`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = EXECUTION_PHASE.toProtocolStage() + ) + .build() + val inputBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val calculatedBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + computationStore.writeString(inputBlobContext, "data") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), + newEmptyOutputBlobMetadata(calculatedBlobContext.blobId) + ), + requisitions = REQUISITIONS + ) + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = + AGGREGATOR_COMPUTATION_DETAILS + .toBuilder() + .apply { + endingState = CompletedReason.FAILED + } + .build() + addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + .build() + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::failComputationParticipant + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + FailComputationParticipantRequest.newBuilder() + .apply { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failureBuilder.apply { + participantChildReferenceId = MILL_ID + errorMessage = + "PERMANENT error: Invalid input blob size. Input blob data has size 4 which is less than (66)." + stageAttemptBuilder.apply { + stage = EXECUTION_PHASE.number + stageName = EXECUTION_PHASE.name + attemptNumber = 1 + } + } + } + .build() + ) + } } private fun ComputationBlobContext.toMetadata(dependencyType: ComputationBlobDependency) = From 8adb7fe18a4274d8a7efa259c09e3e01e310d2d8 Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Sun, 6 Aug 2023 00:33:18 +0000 Subject: [PATCH 12/15] Format files. --- .../daemon/mill/liquidlegionsv2/BUILD.bazel | 2 +- .../ReachOnlyLiquidLegionsV2Mill.kt | 80 +++++++++----- .../mill/liquidlegionsv2/crypto/BUILD.bazel | 4 +- .../JniReachOnlyLiquidLegionsV2Encryption.kt | 30 ++++-- .../reachonlyliquidlegionsv2/README.md | 2 +- .../ReachOnlyLiquidLegionsV2MillTest.kt | 100 ++++++++---------- ...nlyLiquidLegionsV2EncryptionUtilityTest.kt | 40 ++++--- 7 files changed, 146 insertions(+), 112 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel index 62fb65e625f..442c31b904d 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel @@ -75,9 +75,9 @@ kt_jvm_library( "//src/main/proto/wfa/measurement/internal/duchy:differential_privacy_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/duchy/config:protocols_setup_config_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_sketch_parameter_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_kt_jvm_proto", - "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_kt_jvm_proto", "//src/main/proto/wfa/measurement/system/v1alpha:computation_control_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/system/v1alpha:computation_participants_service_kt_jvm_grpc_proto", "//src/main/swig/protocol/reachonlyliquidlegionsv2:reach_only_liquid_legions_v2_encryption_utility", diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt index c130f179660..e3b3ff7ea43 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt @@ -18,22 +18,18 @@ import com.google.protobuf.ByteString import io.opentelemetry.api.OpenTelemetry import io.opentelemetry.api.metrics.LongHistogram import io.opentelemetry.api.metrics.Meter -import java.nio.file.Paths import java.security.SignatureException import java.security.cert.CertPathValidatorException import java.security.cert.X509Certificate import java.time.Clock import java.time.Duration import java.util.logging.Logger -import kotlin.math.min import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest import org.wfanet.measurement.api.Version import org.wfanet.measurement.api.v2alpha.MeasurementSpec import org.wfanet.measurement.common.crypto.SigningKeyHandle import org.wfanet.measurement.common.crypto.readCertificate -import org.wfanet.measurement.common.flatten import org.wfanet.measurement.common.identity.DuchyInfo -import org.wfanet.measurement.common.loadLibrary import org.wfanet.measurement.common.throttler.MinimumIntervalThrottler import org.wfanet.measurement.consent.client.duchy.encryptResult import org.wfanet.measurement.consent.client.duchy.signElgamalPublicKey @@ -51,11 +47,11 @@ import org.wfanet.measurement.duchy.daemon.utils.toAnySketchElGamalPublicKey import org.wfanet.measurement.duchy.daemon.utils.toCmmsElGamalPublicKey import org.wfanet.measurement.duchy.daemon.utils.toV2AlphaElGamalPublicKey import org.wfanet.measurement.duchy.daemon.utils.toV2AlphaEncryptionPublicKey +import org.wfanet.measurement.duchy.db.computation.BlobRef import org.wfanet.measurement.duchy.db.computation.ComputationDataClients import org.wfanet.measurement.duchy.service.internal.computations.outputPathList import org.wfanet.measurement.duchy.service.system.v1alpha.advanceComputationHeader import org.wfanet.measurement.duchy.toProtocolStage -import org.wfanet.measurement.duchy.db.computation.BlobRef import org.wfanet.measurement.internal.duchy.ComputationDetails.CompletedReason import org.wfanet.measurement.internal.duchy.ComputationDetails.KingdomComputationDetails import org.wfanet.measurement.internal.duchy.ComputationStage @@ -77,13 +73,11 @@ import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhas import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant -import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.Parameters import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorRequest import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlySetupPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.globalReachDpNoiseBaseline import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters -import org.wfanet.measurement.internal.duchy.protocol.perBucketFrequencyDpNoiseBaseline import org.wfanet.measurement.internal.duchy.protocol.reachNoiseDifferentialPrivacyParams import org.wfanet.measurement.internal.duchy.protocol.registerNoiseGenerationParameters import org.wfanet.measurement.system.v1alpha.ComputationControlGrpcKt.ComputationControlCoroutineStub @@ -185,10 +179,11 @@ class ReachOnlyLiquidLegionsV2Mill( Pair(Stage.SETUP_PHASE, AGGREGATOR) to ::completeReachOnlySetupPhaseAtAggregator, Pair(Stage.SETUP_PHASE, NON_AGGREGATOR) to ::completeReachOnlySetupPhaseAtNonAggregator, Pair(Stage.EXECUTION_PHASE, AGGREGATOR) to ::completeReachOnlyExecutionPhaseAtAggregator, - Pair(Stage.EXECUTION_PHASE, NON_AGGREGATOR) to ::completeReachOnlyExecutionPhaseAtNonAggregator, + Pair(Stage.EXECUTION_PHASE, NON_AGGREGATOR) to + ::completeReachOnlyExecutionPhaseAtNonAggregator, ) - private val kBytesPerCipherText = 66; + private val kBytesPerCipherText = 66 override suspend fun processComputationImpl(token: ComputationToken) { require(token.computationDetails.hasReachOnlyLiquidLegionsV2()) { @@ -267,7 +262,9 @@ class ReachOnlyLiquidLegionsV2Mill( it.details = token.computationDetails .toBuilder() - .apply { reachOnlyLiquidLegionsV2Builder.localElgamalKey = cryptoResult.elGamalKeyPair } + .apply { + reachOnlyLiquidLegionsV2Builder.localElgamalKey = cryptoResult.elGamalKeyPair + } .build() } .build() @@ -431,7 +428,8 @@ class ReachOnlyLiquidLegionsV2Mill( when (checkNotNull(token.computationDetails.reachOnlyLiquidLegionsV2.role)) { AGGREGATOR -> Stage.WAIT_SETUP_PHASE_INPUTS.toProtocolStage() NON_AGGREGATOR -> Stage.WAIT_TO_START.toProtocolStage() - else -> error("Unknown role: ${latestToken.computationDetails.reachOnlyLiquidLegionsV2.role}") + else -> + error("Unknown role: ${latestToken.computationDetails.reachOnlyLiquidLegionsV2.role}") } ) } @@ -454,7 +452,9 @@ class ReachOnlyLiquidLegionsV2Mill( } } - private suspend fun completeReachOnlySetupPhaseAtAggregator(token: ComputationToken): ComputationToken { + private suspend fun completeReachOnlySetupPhaseAtAggregator( + token: ComputationToken + ): ComputationToken { val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 require(AGGREGATOR == rollv2Details.role) { "invalid role for this function." } val (bytes, nextToken) = @@ -463,7 +463,10 @@ class ReachOnlyLiquidLegionsV2Mill( dataClients .readAllRequisitionBlobs(token, duchyId) .concat(readAndCombineAllInputBlobsSetupPhaseAtAggregator(token, workerStubs.size)) - .toCompleteReachOnlySetupPhaseAtAggregatorRequest(rollv2Details, token.requisitionsCount) + .toCompleteReachOnlySetupPhaseAtAggregatorRequest( + rollv2Details, + token.requisitionsCount + ) val cryptoResult: CompleteReachOnlySetupPhaseResponse = cryptoWorker.completeReachOnlySetupPhaseAtAggregator(request) logStageDurationMetric( @@ -493,7 +496,9 @@ class ReachOnlyLiquidLegionsV2Mill( ) } - private suspend fun completeReachOnlySetupPhaseAtNonAggregator(token: ComputationToken): ComputationToken { + private suspend fun completeReachOnlySetupPhaseAtNonAggregator( + token: ComputationToken + ): ComputationToken { val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 require(NON_AGGREGATOR == rollv2Details.role) { "invalid role for this function." } val (bytes, nextToken) = @@ -502,7 +507,8 @@ class ReachOnlyLiquidLegionsV2Mill( dataClients .readAllRequisitionBlobs(token, duchyId) .toCompleteReachOnlySetupPhaseRequest(rollv2Details, token.requisitionsCount) - val cryptoResult: CompleteReachOnlySetupPhaseResponse = cryptoWorker.completeReachOnlySetupPhase(request) + val cryptoResult: CompleteReachOnlySetupPhaseResponse = + cryptoWorker.completeReachOnlySetupPhase(request) logStageDurationMetric( token, CRYPTO_LIB_CPU_DURATION, @@ -510,7 +516,7 @@ class ReachOnlyLiquidLegionsV2Mill( setupPhaseCryptoCpuTimeDurationHistogram ) // The nextToken consists of the CRV and the noise ciphertext. - cryptoResult.combinedRegisterVector.concat(cryptoResult.serializedExcessiveNoiseCiphertext); + cryptoResult.combinedRegisterVector.concat(cryptoResult.serializedExcessiveNoiseCiphertext) } sendAdvanceComputationRequest( @@ -540,7 +546,10 @@ class ReachOnlyLiquidLegionsV2Mill( val measurementSpec = MeasurementSpec.parseFrom(token.computationDetails.kingdomComputation.measurementSpec) val inputBlob = readAndCombineAllInputBlobs(token, 1) - require(inputBlob.size() >= kBytesPerCipherText) { "Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size ${inputBlob.size()} which is less than ($kBytesPerCipherText)." } + require(inputBlob.size() >= kBytesPerCipherText) { + ("Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size " + + "${inputBlob.size()} which is less than ($kBytesPerCipherText).") + } var reach = 0L val (bytes, nextToken) = existingOutputOr(token) { @@ -548,7 +557,8 @@ class ReachOnlyLiquidLegionsV2Mill( combinedRegisterVector = inputBlob.substring(0, inputBlob.size() - kBytesPerCipherText) localElGamalKeyPair = rollv2Details.localElgamalKey curveId = rollv2Details.parameters.ellipticCurveId.toLong() - serializedExcessiveNoiseCiphertext = inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size()) + serializedExcessiveNoiseCiphertext = + inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size()) if (rollv2Parameters.noise.hasReachNoiseConfig()) { reachDpNoiseBaseline = globalReachDpNoiseBaseline { contributorsCount = workerStubs.size + 1 @@ -598,17 +608,22 @@ class ReachOnlyLiquidLegionsV2Mill( val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 require(NON_AGGREGATOR == rollv2Details.role) { "invalid role for this function." } val inputBlob = readAndCombineAllInputBlobs(token, 1) - require(inputBlob.size() >= kBytesPerCipherText) { "Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size ${inputBlob.size()} which is less than ($kBytesPerCipherText)." } + require(inputBlob.size() >= kBytesPerCipherText) { + ("Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size " + + "${inputBlob.size()} which is less than ($kBytesPerCipherText).") + } val (bytes, nextToken) = existingOutputOr(token) { val cryptoResult: CompleteReachOnlyExecutionPhaseResponse = cryptoWorker.completeReachOnlyExecutionPhase( CompleteReachOnlyExecutionPhaseRequest.newBuilder() .apply { - combinedRegisterVector = inputBlob.substring(0, inputBlob.size() - kBytesPerCipherText) + combinedRegisterVector = + inputBlob.substring(0, inputBlob.size() - kBytesPerCipherText) localElGamalKeyPair = rollv2Details.localElgamalKey curveId = rollv2Details.parameters.ellipticCurveId.toLong() - serializedExcessiveNoiseCiphertext = inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size()) + serializedExcessiveNoiseCiphertext = + inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size()) parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism } .build() @@ -716,7 +731,11 @@ class ReachOnlyLiquidLegionsV2Mill( val noiseConfig = rollv2Details.parameters.noise val combinedInputBlobs = this@toCompleteReachOnlySetupPhaseAtAggregatorRequest return completeReachOnlySetupPhaseRequest { - combinedRegisterVector = combinedInputBlobs.substring(0, combinedInputBlobs.size() - workerStubs.size*kBytesPerCipherText) + combinedRegisterVector = + combinedInputBlobs.substring( + 0, + combinedInputBlobs.size() - workerStubs.size * kBytesPerCipherText + ) curveId = rollv2Details.parameters.ellipticCurveId.toLong() if (noiseConfig.hasReachNoiseConfig()) { noiseParameters = registerNoiseGenerationParameters { @@ -733,7 +752,11 @@ class ReachOnlyLiquidLegionsV2Mill( noiseMechanism = rollv2Details.parameters.noise.noiseMechanism } compositeElGamalPublicKey = rollv2Details.combinedPublicKey - serializedExcessiveNoiseCiphertext = combinedInputBlobs.substring(combinedInputBlobs.size() - workerStubs.size*kBytesPerCipherText, combinedInputBlobs.size()) + serializedExcessiveNoiseCiphertext = + combinedInputBlobs.substring( + combinedInputBlobs.size() - workerStubs.size * kBytesPerCipherText, + combinedInputBlobs.size() + ) parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism } } @@ -752,9 +775,14 @@ class ReachOnlyLiquidLegionsV2Mill( var combinedRegisterVector = ByteString.EMPTY var combinedNoiseCiphertext = ByteString.EMPTY for (str in blobMap.values) { - require(str.size() >= kBytesPerCipherText) { "Invalid input blob size. Input blob ${str.toStringUtf8()} has size ${str.size()} which is less than ($kBytesPerCipherText)." } - combinedRegisterVector = combinedRegisterVector.concat(str.substring(0, str.size() - kBytesPerCipherText)) - combinedNoiseCiphertext = combinedNoiseCiphertext.concat(str.substring(str.size() - kBytesPerCipherText, str.size())) + require(str.size() >= kBytesPerCipherText) { + ("Invalid input blob size. Input blob ${str.toStringUtf8()} has size " + + "${str.size()} which is less than ($kBytesPerCipherText).") + } + combinedRegisterVector = + combinedRegisterVector.concat(str.substring(0, str.size() - kBytesPerCipherText)) + combinedNoiseCiphertext = + combinedNoiseCiphertext.concat(str.substring(str.size() - kBytesPerCipherText, str.size())) } return combinedRegisterVector.concat(combinedNoiseCiphertext) } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel index 312e9d83bec..12de420eaac 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel @@ -5,8 +5,8 @@ package(default_visibility = ["//visibility:public"]) kt_jvm_library( name = "liquidlegionsv2encryption", srcs = [ - "LiquidLegionsV2Encryption.kt", "JniLiquidLegionsV2Encryption.kt", + "LiquidLegionsV2Encryption.kt", ], deps = [ "//src/main/proto/wfa/any_sketch/crypto:sketch_encryption_methods_kt_jvm_proto", @@ -21,8 +21,8 @@ kt_jvm_library( kt_jvm_library( name = "reachonlyliquidlegionsv2encryption", srcs = [ - "ReachOnlyLiquidLegionsV2Encryption.kt", "JniReachOnlyLiquidLegionsV2Encryption.kt", + "ReachOnlyLiquidLegionsV2Encryption.kt", ], deps = [ "//src/main/proto/wfa/any_sketch/crypto:sketch_encryption_methods_kt_jvm_proto", diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2Encryption.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2Encryption.kt index 49193561296..58464adc9d6 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2Encryption.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2Encryption.kt @@ -30,7 +30,8 @@ import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhas import org.wfanet.measurement.internal.duchy.protocol.reachonlyliquidlegionsv2.ReachOnlyLiquidLegionsV2EncryptionUtility /** - * A [ReachOnlyLiquidLegionsV2Encryption] implementation using the JNI [ReachOnlyLiquidLegionsV2EncryptionUtility]. + * A [ReachOnlyLiquidLegionsV2Encryption] implementation using the JNI + * [ReachOnlyLiquidLegionsV2EncryptionUtility]. */ class JniReachOnlyLiquidLegionsV2Encryption : ReachOnlyLiquidLegionsV2Encryption { @@ -38,19 +39,27 @@ class JniReachOnlyLiquidLegionsV2Encryption : ReachOnlyLiquidLegionsV2Encryption request: CompleteReachOnlyInitializationPhaseRequest ): CompleteReachOnlyInitializationPhaseResponse { return CompleteReachOnlyInitializationPhaseResponse.parseFrom( - ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyInitializationPhase(request.toByteArray()) + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyInitializationPhase( + request.toByteArray() + ) ) } - override fun completeReachOnlySetupPhase(request: CompleteReachOnlySetupPhaseRequest): CompleteReachOnlySetupPhaseResponse { + override fun completeReachOnlySetupPhase( + request: CompleteReachOnlySetupPhaseRequest + ): CompleteReachOnlySetupPhaseResponse { return CompleteReachOnlySetupPhaseResponse.parseFrom( ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase(request.toByteArray()) ) } - override fun completeReachOnlySetupPhaseAtAggregator(request: CompleteReachOnlySetupPhaseRequest): CompleteReachOnlySetupPhaseResponse { + override fun completeReachOnlySetupPhaseAtAggregator( + request: CompleteReachOnlySetupPhaseRequest + ): CompleteReachOnlySetupPhaseResponse { return CompleteReachOnlySetupPhaseResponse.parseFrom( - ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhaseAtAggregator(request.toByteArray()) + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhaseAtAggregator( + request.toByteArray() + ) ) } @@ -58,7 +67,9 @@ class JniReachOnlyLiquidLegionsV2Encryption : ReachOnlyLiquidLegionsV2Encryption request: CompleteReachOnlyExecutionPhaseRequest ): CompleteReachOnlyExecutionPhaseResponse { return CompleteReachOnlyExecutionPhaseResponse.parseFrom( - ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhase(request.toByteArray()) + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhase( + request.toByteArray() + ) ) } @@ -66,7 +77,9 @@ class JniReachOnlyLiquidLegionsV2Encryption : ReachOnlyLiquidLegionsV2Encryption request: CompleteReachOnlyExecutionPhaseAtAggregatorRequest ): CompleteReachOnlyExecutionPhaseAtAggregatorResponse { return CompleteReachOnlyExecutionPhaseAtAggregatorResponse.parseFrom( - ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhaseAtAggregator(request.toByteArray()) + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhaseAtAggregator( + request.toByteArray() + ) ) } @@ -82,7 +95,8 @@ class JniReachOnlyLiquidLegionsV2Encryption : ReachOnlyLiquidLegionsV2Encryption init { loadLibrary( name = "reach_only_liquid_legions_v2_encryption_utility", - directoryPath = Paths.get("wfa_measurement_system/src/main/swig/protocol/reachonlyliquidlegionsv2") + directoryPath = + Paths.get("wfa_measurement_system/src/main/swig/protocol/reachonlyliquidlegionsv2") ) loadLibrary( name = "sketch_encrypter_adapter", diff --git a/src/main/swig/protocol/reachonlyliquidlegionsv2/README.md b/src/main/swig/protocol/reachonlyliquidlegionsv2/README.md index c3bfa55b45c..3f3f9cc9b28 100644 --- a/src/main/swig/protocol/reachonlyliquidlegionsv2/README.md +++ b/src/main/swig/protocol/reachonlyliquidlegionsv2/README.md @@ -1,4 +1,4 @@ -# Liquid Legions V2 Encryption Utility Java Library +# Reach Only Liquid Legions V2 Encryption Utility Java Library The ReachOnlyLiquidLegionsV2EncryptionUtility java class is auto-generated from the reach_only_liquid_legions_v2_encryption_utility.swig definition. The diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt index 066f0f4ebc0..acecf238a7e 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt @@ -103,6 +103,7 @@ import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializ import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseResponse import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfig import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.Parameters import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE @@ -115,9 +116,7 @@ import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSket import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfig import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorRequest -import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorResponse import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlySetupPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.copy @@ -167,7 +166,8 @@ private const val LOCAL_ID = 1234L private const val GLOBAL_ID = LOCAL_ID.toString() private const val CIPHERTEXT_SIZE = 66 -private const val NOISE_CIPHERTEXT = "abcdefghijklmnopqrstuvwxyz0123456abcdefghijklmnopqrstuvwxyz0123456" +private const val NOISE_CIPHERTEXT = + "abcdefghijklmnopqrstuvwxyz0123456abcdefghijklmnopqrstuvwxyz0123456" private val SERIALIZED_NOISE_CIPHERTEXT = ByteString.copyFromUtf8(NOISE_CIPHERTEXT) private val DUCHY_ONE_KEY_PAIR = @@ -800,7 +800,9 @@ class ReachOnlyLiquidLegionsV2MillTest { assertThat(cryptoRequest) .isEqualTo( - CompleteReachOnlyInitializationPhaseRequest.newBuilder().apply { curveId = CURVE_ID }.build() + CompleteReachOnlyInitializationPhaseRequest.newBuilder() + .apply { curveId = CURVE_ID } + .build() ) } @@ -813,7 +815,9 @@ class ReachOnlyLiquidLegionsV2MillTest { REQUISITION_2.copy { details = details.copy { externalFulfillingDuchyId = DUCHY_ONE_NAME } } val computationDetailsWithoutPublicKey = AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { reachOnlyLiquidLegionsV2Builder.clearCombinedPublicKey().clearPartiallyCombinedPublicKey() } + .apply { + reachOnlyLiquidLegionsV2Builder.clearCombinedPublicKey().clearPartiallyCombinedPublicKey() + } .build() fakeComputationDb.addComputation( globalId = GLOBAL_ID, @@ -880,7 +884,9 @@ class ReachOnlyLiquidLegionsV2MillTest { // Stage 0. preparing the storage and set up mock val computationDetailsWithoutPublicKey = NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { reachOnlyLiquidLegionsV2Builder.clearCombinedPublicKey().clearPartiallyCombinedPublicKey() } + .apply { + reachOnlyLiquidLegionsV2Builder.clearCombinedPublicKey().clearPartiallyCombinedPublicKey() + } .build() fakeComputationDb.addComputation( globalId = GLOBAL_ID, @@ -935,7 +941,9 @@ class ReachOnlyLiquidLegionsV2MillTest { // Stage 0. preparing the storage and set up mock val computationDetailsWithoutPublicKey = AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { reachOnlyLiquidLegionsV2Builder.clearCombinedPublicKey().clearPartiallyCombinedPublicKey() } + .apply { + reachOnlyLiquidLegionsV2Builder.clearCombinedPublicKey().clearPartiallyCombinedPublicKey() + } .build() fakeComputationDb.addComputation( globalId = GLOBAL_ID, @@ -1078,8 +1086,7 @@ class ReachOnlyLiquidLegionsV2MillTest { val requisitionBlobContext = RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) requisitionStore.writeString(requisitionBlobContext, "local_requisition") - val cachedBlobContext = - ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) + val cachedBlobContext = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) computationStore.writeString(cachedBlobContext, "cached result") fakeComputationDb.addComputation( partialToken.localComputationId, @@ -1235,8 +1242,7 @@ class ReachOnlyLiquidLegionsV2MillTest { val requisitionBlobContext = RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) requisitionStore.writeString(requisitionBlobContext, "local_requisition") - val cachedBlobContext = - ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) + val cachedBlobContext = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) computationStore.writeString(cachedBlobContext, "cached result") fakeComputationDb.addComputation( partialToken.localComputationId, @@ -1352,7 +1358,9 @@ class ReachOnlyLiquidLegionsV2MillTest { ) assertThat(computationStore.get(blobKey)?.readToString()) - .isEqualTo("local_requisition_duchy_2_sketch_duchy_3_sketch_-completeReachOnlySetupPhase-encryptedNoise") + .isEqualTo( + "local_requisition_duchy_2_sketch_duchy_3_sketch_-completeReachOnlySetupPhase-encryptedNoise" + ) assertThat(computationControlRequests) .containsExactlyElementsIn( @@ -1386,7 +1394,8 @@ class ReachOnlyLiquidLegionsV2MillTest { } } compositeElGamalPublicKey = COMBINED_PUBLIC_KEY - serializedExcessiveNoiseCiphertext = SERIALIZED_NOISE_CIPHERTEXT.concat(SERIALIZED_NOISE_CIPHERTEXT) + serializedExcessiveNoiseCiphertext = + SERIALIZED_NOISE_CIPHERTEXT.concat(SERIALIZED_NOISE_CIPHERTEXT) parallelism = PARALLELISM } ) @@ -1411,8 +1420,7 @@ class ReachOnlyLiquidLegionsV2MillTest { fakeComputationDb.addComputation( partialToken.localComputationId, partialToken.computationStage, - computationDetails = - AGGREGATOR_COMPUTATION_DETAILS, + computationDetails = AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf( newInputBlobMetadata(0L, inputBlob0Context.blobKey), @@ -1436,11 +1444,8 @@ class ReachOnlyLiquidLegionsV2MillTest { computationStage = COMPLETE.toProtocolStage() version = 2 // claimTask + transitionStage computationDetails = - AGGREGATOR_COMPUTATION_DETAILS - .toBuilder() - .apply { - endingState = CompletedReason.FAILED - } + AGGREGATOR_COMPUTATION_DETAILS.toBuilder() + .apply { endingState = CompletedReason.FAILED } .build() addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) } @@ -1460,7 +1465,7 @@ class ReachOnlyLiquidLegionsV2MillTest { participantChildReferenceId = MILL_ID errorMessage = "PERMANENT error: java.lang.IllegalArgumentException: Invalid input blob size. Input" + - " blob duchy_2_sketch_ has size 15 which is less than (66)." + " blob duchy_2_sketch_ has size 15 which is less than (66)." stageAttemptBuilder.apply { stage = SETUP_PHASE.number stageName = SETUP_PHASE.name @@ -1481,17 +1486,14 @@ class ReachOnlyLiquidLegionsV2MillTest { stage = EXECUTION_PHASE.toProtocolStage() ) .build() - val inputBlobContext = - ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val inputBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) computationStore.writeString(inputBlobContext, "sketch" + NOISE_CIPHERTEXT) - val cachedBlobContext = - ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + val cachedBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) computationStore.writeString(cachedBlobContext, "cached result") fakeComputationDb.addComputation( partialToken.localComputationId, partialToken.computationStage, - computationDetails = - NON_AGGREGATOR_COMPUTATION_DETAILS, + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf( inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), @@ -1513,8 +1515,7 @@ class ReachOnlyLiquidLegionsV2MillTest { computationStage = COMPLETE.toProtocolStage() version = 2 // claimTask + transitionStage computationDetails = - NON_AGGREGATOR_COMPUTATION_DETAILS - .toBuilder() + NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() .apply { endingState = CompletedReason.SUCCEEDED } .build() requisitions += REQUISITIONS @@ -1537,8 +1538,7 @@ class ReachOnlyLiquidLegionsV2MillTest { stage = EXECUTION_PHASE.toProtocolStage() ) .build() - val inputBlobContext = - ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val inputBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) val calculatedBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) computationStore.writeString(inputBlobContext, "data" + NOISE_CIPHERTEXT) @@ -1581,8 +1581,7 @@ class ReachOnlyLiquidLegionsV2MillTest { computationStage = COMPLETE.toProtocolStage() version = 3 // claimTask + writeOutputBlob + transitionStage computationDetails = - NON_AGGREGATOR_COMPUTATION_DETAILS - .toBuilder() + NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() .apply { endingState = CompletedReason.SUCCEEDED } .build() addAllRequisitions(REQUISITIONS) @@ -1613,8 +1612,7 @@ class ReachOnlyLiquidLegionsV2MillTest { stage = EXECUTION_PHASE.toProtocolStage() ) .build() - val inputBlobContext = - ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val inputBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) val calculatedBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) computationStore.writeString(inputBlobContext, "data") @@ -1644,11 +1642,8 @@ class ReachOnlyLiquidLegionsV2MillTest { computationStage = COMPLETE.toProtocolStage() version = 2 // claimTask + transitionStage computationDetails = - NON_AGGREGATOR_COMPUTATION_DETAILS - .toBuilder() - .apply { - endingState = CompletedReason.FAILED - } + NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() + .apply { endingState = CompletedReason.FAILED } .build() addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) } @@ -1688,17 +1683,14 @@ class ReachOnlyLiquidLegionsV2MillTest { stage = EXECUTION_PHASE.toProtocolStage() ) .build() - val inputBlobContext = - ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val inputBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) computationStore.writeString(inputBlobContext, "sketch" + NOISE_CIPHERTEXT) - val cachedBlobContext = - ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + val cachedBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) computationStore.writeString(cachedBlobContext, "cached result") fakeComputationDb.addComputation( partialToken.localComputationId, partialToken.computationStage, - computationDetails = - AGGREGATOR_COMPUTATION_DETAILS, + computationDetails = AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf( inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), @@ -1720,8 +1712,7 @@ class ReachOnlyLiquidLegionsV2MillTest { computationStage = COMPLETE.toProtocolStage() version = 2 // claimTask + transitionStage computationDetails = - AGGREGATOR_COMPUTATION_DETAILS - .toBuilder() + AGGREGATOR_COMPUTATION_DETAILS.toBuilder() .apply { endingState = CompletedReason.SUCCEEDED } .build() requisitions += REQUISITIONS @@ -1745,8 +1736,7 @@ class ReachOnlyLiquidLegionsV2MillTest { measurementSpec = SERIALIZED_MEASUREMENT_SPEC_WITH_VID_SAMPLING_WIDTH } } - val inputBlobContext = - ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val inputBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) val calculatedBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) computationStore.writeString(inputBlobContext, "data" + NOISE_CIPHERTEXT) @@ -1854,8 +1844,7 @@ class ReachOnlyLiquidLegionsV2MillTest { stage = EXECUTION_PHASE.toProtocolStage() ) .build() - val inputBlobContext = - ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val inputBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) val calculatedBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) computationStore.writeString(inputBlobContext, "data") @@ -1885,11 +1874,8 @@ class ReachOnlyLiquidLegionsV2MillTest { computationStage = COMPLETE.toProtocolStage() version = 2 // claimTask + transitionStage computationDetails = - AGGREGATOR_COMPUTATION_DETAILS - .toBuilder() - .apply { - endingState = CompletedReason.FAILED - } + AGGREGATOR_COMPUTATION_DETAILS.toBuilder() + .apply { endingState = CompletedReason.FAILED } .build() addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) } diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt index 7eb5c1768be..8fa395c88f5 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt @@ -16,13 +16,12 @@ package org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto import com.google.common.truth.Truth.assertThat import com.google.protobuf.ByteString import java.nio.file.Paths -import kotlin.test.assertFailsWith import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.wfanet.anysketch.Sketch -import org.wfanet.anysketch.SketchConfig.ValueSpec.Aggregator import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysResponse import org.wfanet.anysketch.crypto.EncryptSketchRequest @@ -52,9 +51,7 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { } private fun Sketch.Builder.addRegister(index: Long) { - addRegistersBuilder().also { - it.index = index - } + addRegistersBuilder().also { it.index = index } } // Helper function to go through the entire Liquid Legions V2 protocol using the input data. @@ -72,7 +69,8 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { val completeReachOnlySetupPhaseResponse1 = CompleteReachOnlySetupPhaseResponse.parseFrom( ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase( - completeReachOnlySetupPhaseRequest1.toByteArray()) + completeReachOnlySetupPhaseRequest1.toByteArray() + ) ) // Setup phase at Duchy 2 (NON_AGGREGATOR). Duchy 2 does not receive any sketche. val completeReachOnlySetupPhaseRequest2 = completeReachOnlySetupPhaseRequest { @@ -83,25 +81,29 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { val completeReachOnlySetupPhaseResponse2 = CompleteReachOnlySetupPhaseResponse.parseFrom( ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase( - completeReachOnlySetupPhaseRequest2.toByteArray()) + completeReachOnlySetupPhaseRequest2.toByteArray() + ) ) // Setup phase at Duchy 3 (AGGREGATOR). Aggregator receives the combined register vector and // the concatenated excessive noise ciphertexts. val completeReachOnlySetupPhaseRequest3 = completeReachOnlySetupPhaseRequest { combinedRegisterVector = completeReachOnlySetupPhaseResponse1.combinedRegisterVector.concat( - completeReachOnlySetupPhaseResponse2.combinedRegisterVector) + completeReachOnlySetupPhaseResponse2.combinedRegisterVector + ) curveId = CURVE_ID compositeElGamalPublicKey = CLIENT_EL_GAMAL_KEYS serializedExcessiveNoiseCiphertext = completeReachOnlySetupPhaseResponse1.serializedExcessiveNoiseCiphertext.concat( - completeReachOnlySetupPhaseResponse2.serializedExcessiveNoiseCiphertext) + completeReachOnlySetupPhaseResponse2.serializedExcessiveNoiseCiphertext + ) parallelism = PARALLELISM } val completeReachOnlySetupPhaseResponse3 = CompleteReachOnlySetupPhaseResponse.parseFrom( ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase( - completeReachOnlySetupPhaseRequest3.toByteArray()) + completeReachOnlySetupPhaseRequest3.toByteArray() + ) ) // Execution phase at duchy 1 (non-aggregator). @@ -185,8 +187,13 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { EncryptSketchResponse.parseFrom(SketchEncrypterAdapter.EncryptSketch(request.toByteArray())) val encryptedSketch = response.encryptedSketch val result = goThroughEntireMpcProtocol(encryptedSketch).reach - val expectedResult = Estimators.EstimateCardinalityLiquidLegions( - DECAY_RATE, LIQUID_LEGIONS_SIZE, 4, VID_SAMPLING_INTERVAL_WIDTH.toDouble()) + val expectedResult = + Estimators.EstimateCardinalityLiquidLegions( + DECAY_RATE, + LIQUID_LEGIONS_SIZE, + 4, + VID_SAMPLING_INTERVAL_WIDTH.toDouble() + ) assertEquals(expectedResult, result) } @@ -194,7 +201,9 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { fun `completeReachOnlySetupPhase fails with invalid request message`() { val exception = assertFailsWith(RuntimeException::class) { - ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase("something not a proto".toByteArray()) + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase( + "something not a proto".toByteArray() + ) } assertThat(exception).hasMessageThat().contains("Failed to parse") } @@ -231,10 +240,7 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { "sketch_encrypter_adapter", Paths.get("any_sketch_java/src/main/java/org/wfanet/anysketch/crypto") ) - loadLibrary( - "estimators", - Paths.get("any_sketch_java/src/main/java/org/wfanet/estimation") - ) + loadLibrary("estimators", Paths.get("any_sketch_java/src/main/java/org/wfanet/estimation")) } private const val DECAY_RATE = 12.0 From e0e11a3e29ae00e32c1358f5dff50d44c211f16c Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Wed, 9 Aug 2023 21:15:18 +0000 Subject: [PATCH 13/15] Replace Java style with Kotlin DSL. --- .../ReachOnlyLiquidLegionsV2Mill.kt | 173 ++- .../LiquidLegionsV2MillTest.kt | 23 +- .../ReachOnlyLiquidLegionsV2MillTest.kt | 1181 +++++++++-------- ...nlyLiquidLegionsV2EncryptionUtilityTest.kt | 74 +- 4 files changed, 716 insertions(+), 735 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt index e3b3ff7ea43..c47ec1c8433 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt @@ -24,7 +24,7 @@ import java.security.cert.X509Certificate import java.time.Clock import java.time.Duration import java.util.logging.Logger -import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest +import org.wfanet.anysketch.crypto.combineElGamalPublicKeysRequest import org.wfanet.measurement.api.Version import org.wfanet.measurement.api.v2alpha.MeasurementSpec import org.wfanet.measurement.common.crypto.SigningKeyHandle @@ -54,40 +54,44 @@ import org.wfanet.measurement.duchy.service.system.v1alpha.advanceComputationHea import org.wfanet.measurement.duchy.toProtocolStage import org.wfanet.measurement.internal.duchy.ComputationDetails.CompletedReason import org.wfanet.measurement.internal.duchy.ComputationDetails.KingdomComputationDetails -import org.wfanet.measurement.internal.duchy.ComputationStage import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt.ComputationStatsCoroutineStub import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType import org.wfanet.measurement.internal.duchy.ElGamalPublicKey import org.wfanet.measurement.internal.duchy.RequisitionMetadata -import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsRequest +import org.wfanet.measurement.internal.duchy.computationDetails import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation.AGGREGATOR import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation.NON_AGGREGATOR import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation.UNKNOWN import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation.UNRECOGNIZED import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorResponse -import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseResponse -import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyInitializationPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlySetupPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.globalReachDpNoiseBaseline import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters import org.wfanet.measurement.internal.duchy.protocol.reachNoiseDifferentialPrivacyParams import org.wfanet.measurement.internal.duchy.protocol.registerNoiseGenerationParameters +import org.wfanet.measurement.internal.duchy.updateComputationDetailsRequest import org.wfanet.measurement.system.v1alpha.ComputationControlGrpcKt.ComputationControlCoroutineStub import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.RequisitionParamsKt import org.wfanet.measurement.system.v1alpha.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineStub import org.wfanet.measurement.system.v1alpha.ComputationsGrpcKt -import org.wfanet.measurement.system.v1alpha.ConfirmComputationParticipantRequest import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2 -import org.wfanet.measurement.system.v1alpha.SetParticipantRequisitionParamsRequest +import org.wfanet.measurement.system.v1alpha.confirmComputationParticipantRequest +import org.wfanet.measurement.system.v1alpha.reachOnlyLiquidLegionsV2 +import org.wfanet.measurement.system.v1alpha.setParticipantRequisitionParamsRequest /** * Mill works on computations using the ReachOnlyLiquidLegionSketchAggregationProtocol. @@ -165,10 +169,8 @@ class ReachOnlyLiquidLegionsV2Mill( private val executionPhaseCryptoCpuTimeDurationHistogram: LongHistogram = meter.histogramBuilder("execution_phase_crypto_cpu_time_duration_millis").ofLongs().build() - override val endingStage: ComputationStage = - ComputationStage.newBuilder() - .apply { reachOnlyLiquidLegionsSketchAggregationV2 = Stage.COMPLETE } - .build() + override val endingStage = + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE.toProtocolStage() private val actions = mapOf( @@ -176,11 +178,10 @@ class ReachOnlyLiquidLegionsV2Mill( Pair(Stage.INITIALIZATION_PHASE, NON_AGGREGATOR) to ::initializationPhase, Pair(Stage.CONFIRMATION_PHASE, AGGREGATOR) to ::confirmationPhase, Pair(Stage.CONFIRMATION_PHASE, NON_AGGREGATOR) to ::confirmationPhase, - Pair(Stage.SETUP_PHASE, AGGREGATOR) to ::completeReachOnlySetupPhaseAtAggregator, - Pair(Stage.SETUP_PHASE, NON_AGGREGATOR) to ::completeReachOnlySetupPhaseAtNonAggregator, - Pair(Stage.EXECUTION_PHASE, AGGREGATOR) to ::completeReachOnlyExecutionPhaseAtAggregator, - Pair(Stage.EXECUTION_PHASE, NON_AGGREGATOR) to - ::completeReachOnlyExecutionPhaseAtNonAggregator, + Pair(Stage.SETUP_PHASE, AGGREGATOR) to ::completeSetupPhaseAtAggregator, + Pair(Stage.SETUP_PHASE, NON_AGGREGATOR) to ::completeSetupPhaseAtNonAggregator, + Pair(Stage.EXECUTION_PHASE, AGGREGATOR) to ::completeExecutionPhaseAtAggregator, + Pair(Stage.EXECUTION_PHASE, NON_AGGREGATOR) to ::completeExecutionPhaseAtNonAggregator, ) private val kBytesPerCipherText = 66 @@ -213,19 +214,18 @@ class ReachOnlyLiquidLegionsV2Mill( Version.VERSION_UNSPECIFIED -> error("Public api version is invalid or unspecified.") } - val request = - SetParticipantRequisitionParamsRequest.newBuilder() - .apply { - name = ComputationParticipantKey(token.globalComputationId, duchyId).toName() - requisitionParamsBuilder.apply { - duchyCertificate = consentSignalCert.name - reachOnlyLiquidLegionsV2Builder.apply { + val request = setParticipantRequisitionParamsRequest { + name = ComputationParticipantKey(token.globalComputationId, duchyId).toName() + requisitionParams = + ComputationParticipantKt.requisitionParams { + duchyCertificate = consentSignalCert.name + reachOnlyLiquidLegionsV2 = + ComputationParticipantKt.RequisitionParamsKt.liquidLegionsV2 { elGamalPublicKey = signedElgamalPublicKey.data elGamalPublicKeySignature = signedElgamalPublicKey.signature } - } } - .build() + } systemComputationParticipantsClient.setParticipantRequisitionParams(request) } @@ -241,10 +241,9 @@ class ReachOnlyLiquidLegionsV2Mill( token } else { // Generates a new set of ElGamalKeyPair. - val request = - CompleteReachOnlyInitializationPhaseRequest.newBuilder() - .apply { curveId = ellipticCurveId.toLong() } - .build() + val request = completeReachOnlyInitializationPhaseRequest { + curveId = ellipticCurveId.toLong() + } val cryptoResult = cryptoWorker.completeReachOnlyInitializationPhase(request) logStageDurationMetric( token, @@ -256,18 +255,23 @@ class ReachOnlyLiquidLegionsV2Mill( // Updates the newly generated localElgamalKey to the ComputationDetails. dataClients.computationsClient .updateComputationDetails( - UpdateComputationDetailsRequest.newBuilder() - .also { - it.token = token - it.details = - token.computationDetails - .toBuilder() - .apply { - reachOnlyLiquidLegionsV2Builder.localElgamalKey = cryptoResult.elGamalKeyPair - } - .build() + updateComputationDetailsRequest { + this.token = token + this.details = computationDetails { + this.blobsStoragePrefix = token.computationDetails.blobsStoragePrefix + this.endingState = token.computationDetails.endingState + this.kingdomComputation = token.computationDetails.kingdomComputation + this.reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + this.role = token.computationDetails.reachOnlyLiquidLegionsV2.role + this.parameters = token.computationDetails.reachOnlyLiquidLegionsV2.parameters + participant.addAll( + token.computationDetails.reachOnlyLiquidLegionsV2.participantList + ) + this.localElgamalKey = cryptoResult.elGamalKeyPair + } } - .build() + } ) .token } @@ -354,13 +358,10 @@ class ReachOnlyLiquidLegionsV2Mill( } private fun List.toCombinedPublicKey(curveId: Int): ElGamalPublicKey { - val request = - CombineElGamalPublicKeysRequest.newBuilder() - .also { - it.curveId = curveId.toLong() - it.addAllElGamalKeys(this.map { key -> key.toAnySketchElGamalPublicKey() }) - } - .build() + val request = combineElGamalPublicKeysRequest { + this.curveId = curveId.toLong() + this.elGamalKeys += map { it.toAnySketchElGamalPublicKey() } + } return cryptoWorker.combineElGamalPublicKeys(request).elGamalKeys.toCmmsElGamalPublicKey() } @@ -395,21 +396,26 @@ class ReachOnlyLiquidLegionsV2Mill( return dataClients.computationsClient .updateComputationDetails( - UpdateComputationDetailsRequest.newBuilder() - .apply { - this.token = token - details = - token.computationDetails - .toBuilder() - .apply { - reachOnlyLiquidLegionsV2Builder.also { - it.combinedPublicKey = combinedPublicKey - it.partiallyCombinedPublicKey = partiallyCombinedPublicKey - } + updateComputationDetailsRequest { + this.token = token + details = computationDetails { + this.blobsStoragePrefix = token.computationDetails.blobsStoragePrefix + this.endingState = token.computationDetails.endingState + this.kingdomComputation = token.computationDetails.kingdomComputation + this.reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + this.role = token.computationDetails.reachOnlyLiquidLegionsV2.role + this.parameters = token.computationDetails.reachOnlyLiquidLegionsV2.parameters + token.computationDetails.reachOnlyLiquidLegionsV2.participantList.forEach { + this.participant += it } - .build() + this.localElgamalKey = + token.computationDetails.reachOnlyLiquidLegionsV2.localElgamalKey + this.combinedPublicKey = combinedPublicKey + this.partiallyCombinedPublicKey = partiallyCombinedPublicKey + } } - .build() + } ) .token } @@ -417,9 +423,9 @@ class ReachOnlyLiquidLegionsV2Mill( /** Sends confirmation to the kingdom and transits the local computation to the next stage. */ private suspend fun passConfirmationPhase(token: ComputationToken): ComputationToken { systemComputationParticipantsClient.confirmComputationParticipant( - ConfirmComputationParticipantRequest.newBuilder() - .apply { name = ComputationParticipantKey(token.globalComputationId, duchyId).toName() } - .build() + confirmComputationParticipantRequest { + name = ComputationParticipantKey(token.globalComputationId, duchyId).toName() + } ) val latestToken = updatePublicElgamalKey(token) return dataClients.transitionComputationToStage( @@ -452,9 +458,7 @@ class ReachOnlyLiquidLegionsV2Mill( } } - private suspend fun completeReachOnlySetupPhaseAtAggregator( - token: ComputationToken - ): ComputationToken { + private suspend fun completeSetupPhaseAtAggregator(token: ComputationToken): ComputationToken { val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 require(AGGREGATOR == rollv2Details.role) { "invalid role for this function." } val (bytes, nextToken) = @@ -463,10 +467,7 @@ class ReachOnlyLiquidLegionsV2Mill( dataClients .readAllRequisitionBlobs(token, duchyId) .concat(readAndCombineAllInputBlobsSetupPhaseAtAggregator(token, workerStubs.size)) - .toCompleteReachOnlySetupPhaseAtAggregatorRequest( - rollv2Details, - token.requisitionsCount - ) + .toCompleteSetupPhaseAtAggregatorRequest(rollv2Details, token.requisitionsCount) val cryptoResult: CompleteReachOnlySetupPhaseResponse = cryptoWorker.completeReachOnlySetupPhaseAtAggregator(request) logStageDurationMetric( @@ -496,9 +497,7 @@ class ReachOnlyLiquidLegionsV2Mill( ) } - private suspend fun completeReachOnlySetupPhaseAtNonAggregator( - token: ComputationToken - ): ComputationToken { + private suspend fun completeSetupPhaseAtNonAggregator(token: ComputationToken): ComputationToken { val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 require(NON_AGGREGATOR == rollv2Details.role) { "invalid role for this function." } val (bytes, nextToken) = @@ -536,7 +535,7 @@ class ReachOnlyLiquidLegionsV2Mill( ) } - private suspend fun completeReachOnlyExecutionPhaseAtAggregator( + private suspend fun completeExecutionPhaseAtAggregator( token: ComputationToken ): ComputationToken { val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 @@ -602,7 +601,7 @@ class ReachOnlyLiquidLegionsV2Mill( return completeComputation(nextToken, CompletedReason.SUCCEEDED) } - private suspend fun completeReachOnlyExecutionPhaseAtNonAggregator( + private suspend fun completeExecutionPhaseAtNonAggregator( token: ComputationToken ): ComputationToken { val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 @@ -616,17 +615,15 @@ class ReachOnlyLiquidLegionsV2Mill( existingOutputOr(token) { val cryptoResult: CompleteReachOnlyExecutionPhaseResponse = cryptoWorker.completeReachOnlyExecutionPhase( - CompleteReachOnlyExecutionPhaseRequest.newBuilder() - .apply { - combinedRegisterVector = - inputBlob.substring(0, inputBlob.size() - kBytesPerCipherText) - localElGamalKeyPair = rollv2Details.localElgamalKey - curveId = rollv2Details.parameters.ellipticCurveId.toLong() - serializedExcessiveNoiseCiphertext = - inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size()) - parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism - } - .build() + completeReachOnlyExecutionPhaseRequest { + combinedRegisterVector = + inputBlob.substring(0, inputBlob.size() - kBytesPerCipherText) + localElGamalKeyPair = rollv2Details.localElgamalKey + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + serializedExcessiveNoiseCiphertext = + inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size()) + parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism + } ) logStageDurationMetric( token, @@ -724,12 +721,12 @@ class ReachOnlyLiquidLegionsV2Mill( } } - private fun ByteString.toCompleteReachOnlySetupPhaseAtAggregatorRequest( + private fun ByteString.toCompleteSetupPhaseAtAggregatorRequest( rollv2Details: ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails, totalRequisitionsCount: Int ): CompleteReachOnlySetupPhaseRequest { val noiseConfig = rollv2Details.parameters.noise - val combinedInputBlobs = this@toCompleteReachOnlySetupPhaseAtAggregatorRequest + val combinedInputBlobs = this@toCompleteSetupPhaseAtAggregatorRequest return completeReachOnlySetupPhaseRequest { combinedRegisterVector = combinedInputBlobs.substring( diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt index a2e0b097296..e5e5590b648 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt @@ -113,7 +113,6 @@ import org.wfanet.measurement.internal.duchy.protocol.CompleteInitializationPhas import org.wfanet.measurement.internal.duchy.protocol.CompleteSetupPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.CompleteSetupPhaseResponse import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.ComputationDetails.Parameters import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.COMPLETE import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE_ONE @@ -128,6 +127,7 @@ import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggrega import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2Kt +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfig import org.wfanet.measurement.internal.duchy.protocol.completeExecutionPhaseOneAtAggregatorRequest import org.wfanet.measurement.internal.duchy.protocol.completeExecutionPhaseOneAtAggregatorResponse @@ -252,18 +252,15 @@ private val TEST_NOISE_CONFIG = } .build() -private val LLV2_PARAMETERS = - Parameters.newBuilder() - .apply { - maximumFrequency = MAX_FREQUENCY - sketchParametersBuilder.apply { - decayRate = DECAY_RATE - size = SKETCH_SIZE - } - noise = TEST_NOISE_CONFIG - ellipticCurveId = CURVE_ID.toInt() - } - .build() +private val LLV2_PARAMETERS = parameters { + maximumFrequency = MAX_FREQUENCY + sketchParameters = liquidLegionsSketchParameters { + decayRate = DECAY_RATE + size = SKETCH_SIZE + } + noise = TEST_NOISE_CONFIG + ellipticCurveId = CURVE_ID.toInt() +} // In the test, use the same set of cert and encryption key for all parties. private const val CONSENT_SIGNALING_CERT_NAME = "Just a name" diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt index acecf238a7e..ba1d895881d 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt @@ -49,10 +49,11 @@ import org.mockito.kotlin.any import org.mockito.kotlin.mock import org.mockito.kotlin.whenever import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest -import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysResponse -import org.wfanet.measurement.api.v2alpha.ElGamalPublicKey as V2AlphaElGamalPublicKey +import org.wfanet.anysketch.crypto.combineElGamalPublicKeysResponse +import org.wfanet.anysketch.crypto.elGamalPublicKey as AnySketchElGamalPublicKey import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reach import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.vidSamplingInterval +import org.wfanet.measurement.api.v2alpha.elGamalPublicKey as V2AlphaElGamalPublicKey import org.wfanet.measurement.api.v2alpha.measurementSpec import org.wfanet.measurement.common.crypto.SigningKeyHandle import org.wfanet.measurement.common.crypto.readCertificate @@ -86,26 +87,21 @@ import org.wfanet.measurement.internal.duchy.ComputationDetails.CompletedReason import org.wfanet.measurement.internal.duchy.ComputationDetailsKt.kingdomComputationDetails import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt.ComputationStatsCoroutineImplBase import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt.ComputationStatsCoroutineStub -import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoroutineStub -import org.wfanet.measurement.internal.duchy.ElGamalKeyPair -import org.wfanet.measurement.internal.duchy.ElGamalPublicKey import org.wfanet.measurement.internal.duchy.computationDetails import org.wfanet.measurement.internal.duchy.computationStageBlobMetadata +import org.wfanet.measurement.internal.duchy.computationStageDetails import org.wfanet.measurement.internal.duchy.computationToken import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation import org.wfanet.measurement.internal.duchy.copy +import org.wfanet.measurement.internal.duchy.differentialPrivacyParams +import org.wfanet.measurement.internal.duchy.elGamalKeyPair +import org.wfanet.measurement.internal.duchy.elGamalPublicKey import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorRequest -import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorResponse import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseRequest -import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseResponse import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseRequest -import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseResponse import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseRequest -import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfig -import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant -import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.Parameters +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfigKt import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE @@ -116,17 +112,28 @@ import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSket import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.computationParticipant +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.stageDetails +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.waitSetupPhaseInputsDetails import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorResponse import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyInitializationPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyInitializationPhaseResponse import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlySetupPhaseResponse import org.wfanet.measurement.internal.duchy.protocol.copy import org.wfanet.measurement.internal.duchy.protocol.globalReachDpNoiseBaseline import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters +import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsV2NoiseConfig import org.wfanet.measurement.internal.duchy.protocol.reachNoiseDifferentialPrivacyParams import org.wfanet.measurement.internal.duchy.protocol.registerNoiseGenerationParameters import org.wfanet.measurement.storage.Store.Blob import org.wfanet.measurement.storage.filesystem.FileSystemStorageClient import org.wfanet.measurement.system.v1alpha.AdvanceComputationRequest +import org.wfanet.measurement.system.v1alpha.AdvanceComputationRequestKt import org.wfanet.measurement.system.v1alpha.AdvanceComputationResponse import org.wfanet.measurement.system.v1alpha.Computation import org.wfanet.measurement.system.v1alpha.ComputationControlGrpcKt.ComputationControlCoroutineImplBase @@ -136,19 +143,25 @@ import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.Computa import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub as SystemComputationLogEntriesCoroutineStub import org.wfanet.measurement.system.v1alpha.ComputationLogEntry import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.RequisitionParamsKt +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.requisitionParams import org.wfanet.measurement.system.v1alpha.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineImplBase as SystemComputationParticipantsCoroutineImplBase import org.wfanet.measurement.system.v1alpha.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineStub as SystemComputationParticipantsCoroutineStub import org.wfanet.measurement.system.v1alpha.ComputationsGrpcKt.ComputationsCoroutineImplBase as SystemComputationsCoroutineImplBase import org.wfanet.measurement.system.v1alpha.ComputationsGrpcKt.ComputationsCoroutineStub as SystemComputationsCoroutineStub -import org.wfanet.measurement.system.v1alpha.ConfirmComputationParticipantRequest -import org.wfanet.measurement.system.v1alpha.FailComputationParticipantRequest import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2 import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2.Description.EXECUTION_PHASE_INPUT import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2.Description.SETUP_PHASE_INPUT import org.wfanet.measurement.system.v1alpha.Requisition import org.wfanet.measurement.system.v1alpha.SetComputationResultRequest -import org.wfanet.measurement.system.v1alpha.SetParticipantRequisitionParamsRequest +import org.wfanet.measurement.system.v1alpha.advanceComputationRequest +import org.wfanet.measurement.system.v1alpha.confirmComputationParticipantRequest +import org.wfanet.measurement.system.v1alpha.failComputationParticipantRequest +import org.wfanet.measurement.system.v1alpha.reachOnlyLiquidLegionsV2 import org.wfanet.measurement.system.v1alpha.setComputationResultRequest +import org.wfanet.measurement.system.v1alpha.setParticipantRequisitionParamsRequest +import org.wfanet.measurement.system.v1alpha.stageAttempt private const val PUBLIC_API_VERSION = "v2alpha" @@ -170,76 +183,56 @@ private const val NOISE_CIPHERTEXT = "abcdefghijklmnopqrstuvwxyz0123456abcdefghijklmnopqrstuvwxyz0123456" private val SERIALIZED_NOISE_CIPHERTEXT = ByteString.copyFromUtf8(NOISE_CIPHERTEXT) -private val DUCHY_ONE_KEY_PAIR = - ElGamalKeyPair.newBuilder() - .apply { - publicKeyBuilder.apply { - generator = ByteString.copyFromUtf8("generator_1") - element = ByteString.copyFromUtf8("element_1") +private val DUCHY_ONE_KEY_PAIR = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_1") + element = ByteString.copyFromUtf8("element_1") + } + secretKey = ByteString.copyFromUtf8("secret_key_1") +} +private val DUCHY_TWO_PUBLIC_KEY = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_2") + element = ByteString.copyFromUtf8("element_2") +} +private val DUCHY_THREE_PUBLIC_KEY = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_3") + element = ByteString.copyFromUtf8("element_3") +} +private val COMBINED_PUBLIC_KEY = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_1_generator_2_generator_3") + element = ByteString.copyFromUtf8("element_1_element_2_element_3") +} +private val PARTIALLY_COMBINED_PUBLIC_KEY = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_2_generator_3") + element = ByteString.copyFromUtf8("element_2_element_3") +} + +private val TEST_NOISE_CONFIG = liquidLegionsV2NoiseConfig { + reachNoiseConfig = + LiquidLegionsV2NoiseConfigKt.reachNoiseConfig { + blindHistogramNoise = differentialPrivacyParams { + epsilon = 1.0 + delta = 2.0 } - secretKey = ByteString.copyFromUtf8("secret_key_1") - } - .build() -private val DUCHY_TWO_PUBLIC_KEY = - ElGamalPublicKey.newBuilder() - .apply { - generator = ByteString.copyFromUtf8("generator_2") - element = ByteString.copyFromUtf8("element_2") - } - .build() -private val DUCHY_THREE_PUBLIC_KEY = - ElGamalPublicKey.newBuilder() - .apply { - generator = ByteString.copyFromUtf8("generator_3") - element = ByteString.copyFromUtf8("element_3") - } - .build() -private val COMBINED_PUBLIC_KEY = - ElGamalPublicKey.newBuilder() - .apply { - generator = ByteString.copyFromUtf8("generator_1_generator_2_generator_3") - element = ByteString.copyFromUtf8("element_1_element_2_element_3") - } - .build() -private val PARTIALLY_COMBINED_PUBLIC_KEY = - ElGamalPublicKey.newBuilder() - .apply { - generator = ByteString.copyFromUtf8("generator_2_generator_3") - element = ByteString.copyFromUtf8("element_2_element_3") - } - .build() - -private val TEST_NOISE_CONFIG = - LiquidLegionsV2NoiseConfig.newBuilder() - .apply { - reachNoiseConfigBuilder.apply { - blindHistogramNoiseBuilder.apply { - epsilon = 1.0 - delta = 2.0 - } - noiseForPublisherNoiseBuilder.apply { - epsilon = 3.0 - delta = 4.0 - } - globalReachDpNoiseBuilder.apply { - epsilon = 5.0 - delta = 6.0 - } + noiseForPublisherNoise = differentialPrivacyParams { + epsilon = 3.0 + delta = 4.0 } - } - .build() - -private val ROLLV2_PARAMETERS = - Parameters.newBuilder() - .apply { - sketchParametersBuilder.apply { - decayRate = DECAY_RATE - size = SKETCH_SIZE + globalReachDpNoise = differentialPrivacyParams { + epsilon = 5.0 + delta = 6.0 } - noise = TEST_NOISE_CONFIG - ellipticCurveId = CURVE_ID.toInt() } - .build() +} + +private val ROLLV2_PARAMETERS = parameters { + sketchParameters = liquidLegionsSketchParameters { + decayRate = DECAY_RATE + size = SKETCH_SIZE + } + noise = TEST_NOISE_CONFIG + ellipticCurveId = CURVE_ID.toInt() +} // In the test, use the same set of cert and encryption key for all parties. private const val CONSENT_SIGNALING_CERT_NAME = "Just a name" @@ -254,13 +247,10 @@ private val ENCRYPTION_PUBLIC_KEY: org.wfanet.measurement.api.v2alpha.Encryption ENCRYPTION_PRIVATE_KEY.publicKey.toEncryptionPublicKey() /** A public Key used for consent signaling check. */ -private val CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY = - V2AlphaElGamalPublicKey.newBuilder() - .apply { - generator = ByteString.copyFromUtf8("generator-foo") - element = ByteString.copyFromUtf8("element-foo") - } - .build() +private val CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY = V2AlphaElGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") +} /** A pre-computed signature of the CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY. */ private val CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE = ByteString.copyFrom( @@ -271,36 +261,27 @@ private val CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE = ) ) -private val COMPUTATION_PARTICIPANT_1 = - ComputationParticipant.newBuilder() - .apply { - duchyId = DUCHY_ONE_NAME - publicKey = DUCHY_ONE_KEY_PAIR.publicKey - elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() - elGamalPublicKeySignature = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE - duchyCertificateDer = CONSENT_SIGNALING_CERT_DER - } - .build() -private val COMPUTATION_PARTICIPANT_2 = - ComputationParticipant.newBuilder() - .apply { - duchyId = DUCHY_TWO_NAME - publicKey = DUCHY_TWO_PUBLIC_KEY - elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() - elGamalPublicKeySignature = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE - duchyCertificateDer = CONSENT_SIGNALING_CERT_DER - } - .build() -private val COMPUTATION_PARTICIPANT_3 = - ComputationParticipant.newBuilder() - .apply { - duchyId = DUCHY_THREE_NAME - publicKey = DUCHY_THREE_PUBLIC_KEY - elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() - elGamalPublicKeySignature = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE - duchyCertificateDer = CONSENT_SIGNALING_CERT_DER - } - .build() +private val COMPUTATION_PARTICIPANT_1 = computationParticipant { + duchyId = DUCHY_ONE_NAME + publicKey = DUCHY_ONE_KEY_PAIR.publicKey + elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() + elGamalPublicKeySignature = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE + duchyCertificateDer = CONSENT_SIGNALING_CERT_DER +} +private val COMPUTATION_PARTICIPANT_2 = computationParticipant { + duchyId = DUCHY_TWO_NAME + publicKey = DUCHY_TWO_PUBLIC_KEY + elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() + elGamalPublicKeySignature = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE + duchyCertificateDer = CONSENT_SIGNALING_CERT_DER +} +private val COMPUTATION_PARTICIPANT_3 = computationParticipant { + duchyId = DUCHY_THREE_NAME + publicKey = DUCHY_THREE_PUBLIC_KEY + elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() + elGamalPublicKeySignature = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE + duchyCertificateDer = CONSENT_SIGNALING_CERT_DER +} private val TEST_REQUISITION_1 = TestRequisition("111") { SERIALIZED_MEASUREMENT_SPEC } private val TEST_REQUISITION_2 = TestRequisition("222") { SERIALIZED_MEASUREMENT_SPEC } @@ -394,24 +375,22 @@ class ReachOnlyLiquidLegionsV2MillTest { on { combineElGamalPublicKeys(any()) } .thenAnswer { val cryptoRequest: CombineElGamalPublicKeysRequest = it.getArgument(0) - CombineElGamalPublicKeysResponse.newBuilder() - .apply { - elGamalKeysBuilder.apply { - generator = - ByteString.copyFromUtf8( - cryptoRequest.elGamalKeysList - .sortedBy { key -> key.generator.toStringUtf8() } - .joinToString(separator = "_") { key -> key.generator.toStringUtf8() } - ) - element = - ByteString.copyFromUtf8( - cryptoRequest.elGamalKeysList - .sortedBy { key -> key.element.toStringUtf8() } - .joinToString(separator = "_") { key -> key.element.toStringUtf8() } - ) - } + combineElGamalPublicKeysResponse { + elGamalKeys = AnySketchElGamalPublicKey { + generator = + ByteString.copyFromUtf8( + cryptoRequest.elGamalKeysList + .sortedBy { key -> key.generator.toStringUtf8() } + .joinToString(separator = "_") { key -> key.generator.toStringUtf8() } + ) + element = + ByteString.copyFromUtf8( + cryptoRequest.elGamalKeysList + .sortedBy { key -> key.element.toStringUtf8() } + .joinToString(separator = "_") { key -> key.element.toStringUtf8() } + ) } - .build() + } } } private val fakeComputationDb = FakeComputationsDatabase() @@ -485,20 +464,21 @@ class ReachOnlyLiquidLegionsV2MillTest { description: ReachOnlyLiquidLegionsV2.Description, vararg chunkContents: String ): List { - val header = - AdvanceComputationRequest.newBuilder() - .apply { - headerBuilder.apply { - name = ComputationKey(globalComputationId).toName() - reachOnlyLiquidLegionsV2Builder.description = description + val header = advanceComputationRequest { + header = + AdvanceComputationRequestKt.header { + name = ComputationKey(globalComputationId).toName() + this.reachOnlyLiquidLegionsV2 = reachOnlyLiquidLegionsV2 { + this.description = description } } - .build() + } val body = chunkContents.asList().map { - AdvanceComputationRequest.newBuilder() - .apply { bodyChunkBuilder.apply { partialData = ByteString.copyFromUtf8(it) } } - .build() + advanceComputationRequest { + bodyChunk = + AdvanceComputationRequestKt.bodyChunk { partialData = ByteString.copyFromUtf8(it) } + } } return listOf(header) + body } @@ -574,17 +554,16 @@ class ReachOnlyLiquidLegionsV2MillTest { ) .build() - val initialComputationDetails = - NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { - reachOnlyLiquidLegionsV2Builder.apply { - parametersBuilder.ellipticCurveId = CURVE_ID.toInt() - clearPartiallyCombinedPublicKey() - clearCombinedPublicKey() - clearLocalElgamalKey() - } + val initialComputationDetails = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_1, COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3) } - .build() + } fakeComputationDb.addComputation( partialToken.localComputationId, @@ -594,17 +573,15 @@ class ReachOnlyLiquidLegionsV2MillTest { ) whenever(mockCryptoWorker.completeReachOnlyInitializationPhase(any())).thenAnswer { - CompleteReachOnlyInitializationPhaseResponse.newBuilder() - .apply { - elGamalKeyPairBuilder.apply { - publicKeyBuilder.apply { - generator = ByteString.copyFromUtf8("generator-foo") - element = ByteString.copyFromUtf8("element-foo") - } - secretKey = ByteString.copyFromUtf8("secretKey-foo") + completeReachOnlyInitializationPhaseResponse { + this.elGamalKeyPair = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") } + secretKey = ByteString.copyFromUtf8("secretKey-foo") } - .build() + } } // This will result in TRANSIENT gRPC failure. @@ -616,90 +593,107 @@ class ReachOnlyLiquidLegionsV2MillTest { assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = INITIALIZATION_PHASE.toProtocolStage() - version = 3 // claimTask + updateComputationDetails + enqueueComputation - computationDetails = - initialComputationDetails - .toBuilder() - .apply { - reachOnlyLiquidLegionsV2Builder.localElgamalKeyBuilder.apply { - publicKeyBuilder.apply { - generator = ByteString.copyFromUtf8("generator-foo") - element = ByteString.copyFromUtf8("element-foo") - } - secretKey = ByteString.copyFromUtf8("secretKey-foo") + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = INITIALIZATION_PHASE.toProtocolStage() + version = 3 // claimTask + updateComputationDetails + enqueueComputation + this.computationDetails = computationDetails { + kingdomComputation = initialComputationDetails.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf( + COMPUTATION_PARTICIPANT_1, + COMPUTATION_PARTICIPANT_2, + COMPUTATION_PARTICIPANT_3 + ) + this.localElgamalKey = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") } + secretKey = ByteString.copyFromUtf8("secretKey-foo") } - .build() - addAllRequisitions(REQUISITIONS) + } } - .build() + requisitions.addAll(REQUISITIONS) + } ) // Second attempt fails, which doesn't change the computation stage. nonAggregatorMill.pollAndProcessNextComputation() assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 2 - computationStage = INITIALIZATION_PHASE.toProtocolStage() - version = 5 // claimTask + updateComputationDetails + enqueueComputation - computationDetails = - initialComputationDetails - .toBuilder() - .apply { - reachOnlyLiquidLegionsV2Builder.localElgamalKeyBuilder.apply { - publicKeyBuilder.apply { - generator = ByteString.copyFromUtf8("generator-foo") - element = ByteString.copyFromUtf8("element-foo") - } - secretKey = ByteString.copyFromUtf8("secretKey-foo") + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 2 + computationStage = INITIALIZATION_PHASE.toProtocolStage() + version = 5 // claimTask + updateComputationDetails + enqueueComputation + this.computationDetails = computationDetails { + kingdomComputation = initialComputationDetails.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf( + COMPUTATION_PARTICIPANT_1, + COMPUTATION_PARTICIPANT_2, + COMPUTATION_PARTICIPANT_3 + ) + this.localElgamalKey = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") } + secretKey = ByteString.copyFromUtf8("secretKey-foo") } - .build() - addAllRequisitions(REQUISITIONS) + } } - .build() + requisitions.addAll(REQUISITIONS) + } ) - // Third attempt fails, which will fail the computation. nonAggregatorMill.pollAndProcessNextComputation() assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 3 - computationStage = COMPLETE.toProtocolStage() - version = 8 // claimTask + updateComputationDetails + enqueueComputation + claimTask + - // EndComputation - computationDetails = - initialComputationDetails - .toBuilder() - .apply { - endingState = CompletedReason.FAILED - reachOnlyLiquidLegionsV2Builder.localElgamalKeyBuilder.apply { - publicKeyBuilder.apply { - generator = ByteString.copyFromUtf8("generator-foo") - element = ByteString.copyFromUtf8("element-foo") - } - secretKey = ByteString.copyFromUtf8("secretKey-foo") + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 3 + computationStage = COMPLETE.toProtocolStage() + version = 8 // claimTask + updateComputationDetails + enqueueComputation + claimTask + + // EndComputation + this.computationDetails = computationDetails { + kingdomComputation = initialComputationDetails.kingdomComputation + endingState = CompletedReason.FAILED + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf( + COMPUTATION_PARTICIPANT_1, + COMPUTATION_PARTICIPANT_2, + COMPUTATION_PARTICIPANT_3 + ) + this.localElgamalKey = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") } + secretKey = ByteString.copyFromUtf8("secretKey-foo") } - .build() - addAllRequisitions(REQUISITIONS) + } } - .build() + requisitions.addAll(REQUISITIONS) + } ) } @@ -713,17 +707,16 @@ class ReachOnlyLiquidLegionsV2MillTest { ) .build() - val initialComputationDetails = - NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { - reachOnlyLiquidLegionsV2Builder.apply { - parametersBuilder.ellipticCurveId = CURVE_ID.toInt() - clearPartiallyCombinedPublicKey() - clearCombinedPublicKey() - clearLocalElgamalKey() - } + val initialComputationDetails = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_1, COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3) } - .build() + } fakeComputationDb.addComputation( partialToken.localComputationId, @@ -735,17 +728,15 @@ class ReachOnlyLiquidLegionsV2MillTest { var cryptoRequest = CompleteReachOnlyInitializationPhaseRequest.getDefaultInstance() whenever(mockCryptoWorker.completeReachOnlyInitializationPhase(any())).thenAnswer { cryptoRequest = it.getArgument(0) - CompleteReachOnlyInitializationPhaseResponse.newBuilder() - .apply { - elGamalKeyPairBuilder.apply { - publicKeyBuilder.apply { - generator = ByteString.copyFromUtf8("generator-foo") - element = ByteString.copyFromUtf8("element-foo") - } - secretKey = ByteString.copyFromUtf8("secretKey-foo") + completeReachOnlyInitializationPhaseResponse { + this.elGamalKeyPair = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") } + secretKey = ByteString.copyFromUtf8("secretKey-foo") } - .build() + } } // Stage 1. Process the above computation @@ -754,29 +745,35 @@ class ReachOnlyLiquidLegionsV2MillTest { // Stage 2. Check the status of the computation assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage() - version = 3 // claimTask + updateComputationDetails + transitionStage - computationDetails = - initialComputationDetails - .toBuilder() - .apply { - reachOnlyLiquidLegionsV2Builder.localElgamalKeyBuilder.apply { - publicKeyBuilder.apply { - generator = ByteString.copyFromUtf8("generator-foo") - element = ByteString.copyFromUtf8("element-foo") - } - secretKey = ByteString.copyFromUtf8("secretKey-foo") + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage() + version = 3 // claimTask + updateComputationDetails + enqueueComputation + this.computationDetails = computationDetails { + kingdomComputation = initialComputationDetails.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf( + COMPUTATION_PARTICIPANT_1, + COMPUTATION_PARTICIPANT_2, + COMPUTATION_PARTICIPANT_3 + ) + this.localElgamalKey = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") } + secretKey = ByteString.copyFromUtf8("secretKey-foo") } - .build() - addAllRequisitions(REQUISITIONS) + } } - .build() + requisitions.addAll(REQUISITIONS) + } ) verifyProtoArgument( @@ -785,25 +782,20 @@ class ReachOnlyLiquidLegionsV2MillTest { ) .comparingExpectedFieldsOnly() .isEqualTo( - SetParticipantRequisitionParamsRequest.newBuilder() - .apply { - name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() - requisitionParamsBuilder.apply { - duchyCertificate = CONSENT_SIGNALING_CERT_NAME - reachOnlyLiquidLegionsV2Builder.apply { + setParticipantRequisitionParamsRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + this.requisitionParams = requisitionParams { + duchyCertificate = CONSENT_SIGNALING_CERT_NAME + reachOnlyLiquidLegionsV2 = + RequisitionParamsKt.liquidLegionsV2 { elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() } - } } - .build() + } ) assertThat(cryptoRequest) - .isEqualTo( - CompleteReachOnlyInitializationPhaseRequest.newBuilder() - .apply { curveId = CURVE_ID } - .build() - ) + .isEqualTo(completeReachOnlyInitializationPhaseRequest { curveId = CURVE_ID }) } @Test @@ -813,12 +805,19 @@ class ReachOnlyLiquidLegionsV2MillTest { // requisition2 is fulfilled at Duchy One, but doesn't have path set. val requisition2 = REQUISITION_2.copy { details = details.copy { externalFulfillingDuchyId = DUCHY_ONE_NAME } } - val computationDetailsWithoutPublicKey = - AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { - reachOnlyLiquidLegionsV2Builder.clearCombinedPublicKey().clearPartiallyCombinedPublicKey() + val computationDetailsWithoutPublicKey = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3, COMPUTATION_PARTICIPANT_1) + // partiallyCombinedPublicKey and combinedPublicKey are the same at the aggregator. + partiallyCombinedPublicKey = COMBINED_PUBLIC_KEY + localElgamalKey = DUCHY_ONE_KEY_PAIR } - .build() + } fakeComputationDb.addComputation( globalId = GLOBAL_ID, stage = CONFIRMATION_PHASE.toProtocolStage(), @@ -835,21 +834,19 @@ class ReachOnlyLiquidLegionsV2MillTest { // Stage 2. Check the status of the computation assertThat(fakeComputationDb[LOCAL_ID]!!) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = COMPLETE.toProtocolStage() - version = 2 // claimTask + transitionStage - computationDetails = - computationDetailsWithoutPublicKey - .toBuilder() - .apply { endingState = CompletedReason.FAILED } - .build() - addAllRequisitions(listOf(requisition1, requisition2)) + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + this.computationDetails = computationDetails { + kingdomComputation = computationDetailsWithoutPublicKey.kingdomComputation + reachOnlyLiquidLegionsV2 = computationDetailsWithoutPublicKey.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.FAILED } - .build() + requisitions.addAll(listOf(requisition1, requisition2)) + } ) verifyProtoArgument( @@ -858,36 +855,41 @@ class ReachOnlyLiquidLegionsV2MillTest { ) .comparingExpectedFieldsOnly() .isEqualTo( - FailComputationParticipantRequest.newBuilder() - .apply { - name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() - failureBuilder.apply { + failComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failure = + ComputationParticipantKt.failure { participantChildReferenceId = MILL_ID errorMessage = "PERMANENT error: java.lang.Exception: @Mill a nice mill, Computation 1234 " + "failed due to:\n" + "Cannot verify participation of all DataProviders.\n" + "Missing expected data for requisition 222." - stageAttemptBuilder.apply { + this.stageAttempt = stageAttempt { stage = CONFIRMATION_PHASE.number stageName = CONFIRMATION_PHASE.name attemptNumber = 1 } } - } - .build() + } ) } @Test fun `confirmation phase, passed at non-aggregator`() = runBlocking { // Stage 0. preparing the storage and set up mock - val computationDetailsWithoutPublicKey = - NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { - reachOnlyLiquidLegionsV2Builder.clearCombinedPublicKey().clearPartiallyCombinedPublicKey() + val computationDetailsWithoutPublicKey = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_1, COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3) + partiallyCombinedPublicKey = PARTIALLY_COMBINED_PUBLIC_KEY + localElgamalKey = DUCHY_ONE_KEY_PAIR } - .build() + } fakeComputationDb.addComputation( globalId = GLOBAL_ID, stage = CONFIRMATION_PHASE.toProtocolStage(), @@ -904,25 +906,18 @@ class ReachOnlyLiquidLegionsV2MillTest { // Stage 2. Check the status of the computation assertThat(fakeComputationDb[LOCAL_ID]!!) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = WAIT_TO_START.toProtocolStage() - version = 3 // claimTask + updateComputationDetail + transitionStage - computationDetails = - NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { - reachOnlyLiquidLegionsV2Builder.apply { - combinedPublicKey = COMBINED_PUBLIC_KEY - partiallyCombinedPublicKey = PARTIALLY_COMBINED_PUBLIC_KEY - } - } - .build() - addAllRequisitions(REQUISITIONS) + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_TO_START.toProtocolStage() + version = 3 // claimTask + updateComputationDetail + transitionStage + computationDetails = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = NON_AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 } - .build() + requisitions.addAll(REQUISITIONS) + } ) verifyProtoArgument( @@ -930,21 +925,26 @@ class ReachOnlyLiquidLegionsV2MillTest { SystemComputationParticipantsCoroutineImplBase::confirmComputationParticipant ) .isEqualTo( - ConfirmComputationParticipantRequest.newBuilder() - .apply { name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() } - .build() + confirmComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + } ) } @Test fun `confirmation phase, passed at aggregator`() = runBlocking { // Stage 0. preparing the storage and set up mock - val computationDetailsWithoutPublicKey = - AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { - reachOnlyLiquidLegionsV2Builder.clearCombinedPublicKey().clearPartiallyCombinedPublicKey() + val computationDetailsWithoutPublicKey = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3, COMPUTATION_PARTICIPANT_1) + localElgamalKey = DUCHY_ONE_KEY_PAIR } - .build() + } fakeComputationDb.addComputation( globalId = GLOBAL_ID, stage = CONFIRMATION_PHASE.toProtocolStage(), @@ -961,32 +961,28 @@ class ReachOnlyLiquidLegionsV2MillTest { // Stage 2. Check the status of the computation assertThat(fakeComputationDb[LOCAL_ID]!!) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = WAIT_SETUP_PHASE_INPUTS.toProtocolStage() - version = 3 // claimTask + updateComputationDetails + transitionStage - addAllBlobs(listOf(newEmptyOutputBlobMetadata(0), newEmptyOutputBlobMetadata(1))) - stageSpecificDetailsBuilder.apply { - reachOnlyLiquidLegionsV2Builder.waitSetupPhaseInputsDetailsBuilder.apply { - putExternalDuchyLocalBlobId("DUCHY_TWO", 0L) - putExternalDuchyLocalBlobId("DUCHY_THREE", 1L) - } - } - computationDetails = - AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { - reachOnlyLiquidLegionsV2Builder.apply { - combinedPublicKey = COMBINED_PUBLIC_KEY - partiallyCombinedPublicKey = COMBINED_PUBLIC_KEY - } + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_SETUP_PHASE_INPUTS.toProtocolStage() + version = 3 // claimTask + updateComputationDetails + transitionStage + blobs.addAll(listOf(newEmptyOutputBlobMetadata(0), newEmptyOutputBlobMetadata(1))) + stageSpecificDetails = computationStageDetails { + reachOnlyLiquidLegionsV2 = stageDetails { + waitSetupPhaseInputsDetails = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.waitSetupPhaseInputsDetails { + externalDuchyLocalBlobId.put("DUCHY_TWO", 0L) + externalDuchyLocalBlobId.put("DUCHY_THREE", 1L) } - .build() - addAllRequisitions(REQUISITIONS) + } } - .build() + computationDetails = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + } + requisitions.addAll(REQUISITIONS) + } ) verifyProtoArgument( @@ -994,25 +990,40 @@ class ReachOnlyLiquidLegionsV2MillTest { SystemComputationParticipantsCoroutineImplBase::confirmComputationParticipant ) .isEqualTo( - ConfirmComputationParticipantRequest.newBuilder() - .apply { name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() } - .build() + confirmComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + } ) } @Test fun `confirmation phase, failed due to invalid nonce and ElGamal key signature`() = runBlocking { + val COMPUTATION_PARTICIPANT_2_WITH_INVALID_SIGNATURE = computationParticipant { + duchyId = DUCHY_TWO_NAME + publicKey = DUCHY_TWO_PUBLIC_KEY + elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("An invalid signature") + duchyCertificateDer = CONSENT_SIGNALING_CERT_DER + } // Stage 0. preparing the storage and set up mock - val computationDetailsWithoutInvalidDuchySignature = - AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { - reachOnlyLiquidLegionsV2Builder.apply { - participantBuilderList[0].apply { - elGamalPublicKeySignature = ByteString.copyFromUtf8("An invalid signature") - } - } + val computationDetailsWithoutInvalidDuchySignature = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf( + COMPUTATION_PARTICIPANT_2_WITH_INVALID_SIGNATURE, + COMPUTATION_PARTICIPANT_3, + COMPUTATION_PARTICIPANT_1 + ) + combinedPublicKey = COMBINED_PUBLIC_KEY + // partiallyCombinedPublicKey and combinedPublicKey are the same at the aggregator. + partiallyCombinedPublicKey = COMBINED_PUBLIC_KEY + localElgamalKey = DUCHY_ONE_KEY_PAIR } - .build() + } val requisitionWithInvalidNonce = REQUISITION_1.copy { details = details.copy { nonce = 404L } } fakeComputationDb.addComputation( globalId = GLOBAL_ID, @@ -1030,21 +1041,20 @@ class ReachOnlyLiquidLegionsV2MillTest { // Stage 2. Check the status of the computation assertThat(fakeComputationDb[LOCAL_ID]!!) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = COMPLETE.toProtocolStage() - version = 2 // claimTask + transitionStage - computationDetails = - computationDetailsWithoutInvalidDuchySignature - .toBuilder() - .apply { endingState = CompletedReason.FAILED } - .build() - addAllRequisitions(listOf(requisitionWithInvalidNonce)) + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = computationDetails { + kingdomComputation = computationDetailsWithoutInvalidDuchySignature.kingdomComputation + reachOnlyLiquidLegionsV2 = + computationDetailsWithoutInvalidDuchySignature.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.FAILED } - .build() + requisitions.addAll(listOf(requisitionWithInvalidNonce)) + } ) verifyProtoArgument( @@ -1053,24 +1063,23 @@ class ReachOnlyLiquidLegionsV2MillTest { ) .comparingExpectedFieldsOnly() .isEqualTo( - FailComputationParticipantRequest.newBuilder() - .apply { - name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() - failureBuilder.apply { + failComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failure = + ComputationParticipantKt.failure { participantChildReferenceId = MILL_ID errorMessage = "PERMANENT error: java.lang.Exception: @Mill a nice mill, Computation 1234 " + "failed due to:\n" + "Cannot verify participation of all DataProviders.\n" + "Invalid ElGamal public key signature for Duchy $DUCHY_TWO_NAME" - stageAttemptBuilder.apply { + this.stageAttempt = stageAttempt { stage = CONFIRMATION_PHASE.number stageName = CONFIRMATION_PHASE.name attemptNumber = 1 } } - } - .build() + } ) } @@ -1102,26 +1111,28 @@ class ReachOnlyLiquidLegionsV2MillTest { // Stage 2. Check the status of the computation assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() - addBlobsBuilder().apply { - dependencyType = ComputationBlobDependency.INPUT - blobId = 0L - path = cachedBlobContext.blobKey - } - addBlobsBuilder().apply { - dependencyType = ComputationBlobDependency.OUTPUT - blobId = 1L - } - version = 2 // claimTask + transitionStage - computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS - addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) - } - .build() + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + blobs.addAll( + listOf( + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0L + path = cachedBlobContext.blobKey + }, + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1L + } + ) + ) + version = 2 // claimTask + transitionStage + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } ) assertThat(computationControlRequests) @@ -1156,12 +1167,10 @@ class ReachOnlyLiquidLegionsV2MillTest { whenever(mockCryptoWorker.completeReachOnlySetupPhase(any())).thenAnswer { cryptoRequest = it.getArgument(0) val postFix = ByteString.copyFromUtf8("-completeReachOnlySetupPhase") - CompleteReachOnlySetupPhaseResponse.newBuilder() - .apply { - combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) - serializedExcessiveNoiseCiphertext = ByteString.copyFromUtf8("-encryptedNoise") - } - .build() + completeReachOnlySetupPhaseResponse { + combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) + serializedExcessiveNoiseCiphertext = ByteString.copyFromUtf8("-encryptedNoise") + } } // Stage 1. Process the above computation @@ -1171,26 +1180,28 @@ class ReachOnlyLiquidLegionsV2MillTest { val blobKey = calculatedBlobContext.blobKey assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() - addBlobsBuilder().apply { - dependencyType = ComputationBlobDependency.INPUT - blobId = 0L - path = blobKey - } - addBlobsBuilder().apply { - dependencyType = ComputationBlobDependency.OUTPUT - blobId = 1L - } - version = 3 // claimTask + writeOutputBlob + transitionStage - computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS - addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) - } - .build() + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + blobs.addAll( + listOf( + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0L + path = blobKey + }, + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1L + } + ) + ) + version = 3 // claimTask + writeOutputBlob + transitionStage + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } ) assertThat(computationStore.get(blobKey)?.readToString()) @@ -1258,26 +1269,28 @@ class ReachOnlyLiquidLegionsV2MillTest { // Stage 2. Check the status of the computation assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() - addBlobsBuilder().apply { - dependencyType = ComputationBlobDependency.INPUT - blobId = 0L - path = cachedBlobContext.blobKey - } - addBlobsBuilder().apply { - dependencyType = ComputationBlobDependency.OUTPUT - blobId = 1L - } - version = 2 // claimTask + transitionStage - computationDetails = AGGREGATOR_COMPUTATION_DETAILS - addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) - } - .build() + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + blobs.addAll( + listOf( + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0L + path = cachedBlobContext.blobKey + }, + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1L + } + ) + ) + version = 2 // claimTask + transitionStage + computationDetails = AGGREGATOR_COMPUTATION_DETAILS + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } ) assertThat(computationControlRequests) @@ -1320,12 +1333,10 @@ class ReachOnlyLiquidLegionsV2MillTest { whenever(mockCryptoWorker.completeReachOnlySetupPhaseAtAggregator(any())).thenAnswer { cryptoRequest = it.getArgument(0) val postFix = ByteString.copyFromUtf8("-completeReachOnlySetupPhase") - CompleteReachOnlySetupPhaseResponse.newBuilder() - .apply { - combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) - serializedExcessiveNoiseCiphertext = ByteString.copyFromUtf8("-encryptedNoise") - } - .build() + completeReachOnlySetupPhaseResponse { + combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) + serializedExcessiveNoiseCiphertext = ByteString.copyFromUtf8("-encryptedNoise") + } } // Stage 1. Process the above computation @@ -1335,26 +1346,28 @@ class ReachOnlyLiquidLegionsV2MillTest { val blobKey = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 3L).blobKey assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() - addBlobsBuilder().apply { - dependencyType = ComputationBlobDependency.INPUT - blobId = 0 - path = blobKey - } - addBlobsBuilder().apply { - dependencyType = ComputationBlobDependency.OUTPUT - blobId = 1 - } - version = 3 // claimTask + writeOutputBlob + transitionStage - computationDetails = AGGREGATOR_COMPUTATION_DETAILS - addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) - } - .build() + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + blobs.addAll( + listOf( + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0 + path = blobKey + }, + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1 + } + ) + ) + version = 3 // claimTask + writeOutputBlob + transitionStage + computationDetails = AGGREGATOR_COMPUTATION_DETAILS + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } ) assertThat(computationStore.get(blobKey)?.readToString()) @@ -1436,20 +1449,19 @@ class ReachOnlyLiquidLegionsV2MillTest { // Stage 2. Check the status of the computation assertThat(fakeComputationDb[LOCAL_ID]!!) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = COMPLETE.toProtocolStage() - version = 2 // claimTask + transitionStage - computationDetails = - AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { endingState = CompletedReason.FAILED } - .build() - addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.FAILED } - .build() + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } ) verifyProtoArgument( @@ -1458,22 +1470,21 @@ class ReachOnlyLiquidLegionsV2MillTest { ) .comparingExpectedFieldsOnly() .isEqualTo( - FailComputationParticipantRequest.newBuilder() - .apply { - name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() - failureBuilder.apply { + failComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failure = + ComputationParticipantKt.failure { participantChildReferenceId = MILL_ID errorMessage = "PERMANENT error: java.lang.IllegalArgumentException: Invalid input blob size. Input" + " blob duchy_2_sketch_ has size 15 which is less than (66)." - stageAttemptBuilder.apply { + this.stageAttempt = stageAttempt { stage = SETUP_PHASE.number stageName = SETUP_PHASE.name attemptNumber = 1 } } - } - .build() + } ) } @@ -1514,10 +1525,11 @@ class ReachOnlyLiquidLegionsV2MillTest { attempt = 1 computationStage = COMPLETE.toProtocolStage() version = 2 // claimTask + transitionStage - computationDetails = - NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { endingState = CompletedReason.SUCCEEDED } - .build() + computationDetails = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = NON_AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.SUCCEEDED + } requisitions += REQUISITIONS } ) @@ -1558,12 +1570,10 @@ class ReachOnlyLiquidLegionsV2MillTest { whenever(mockCryptoWorker.completeReachOnlyExecutionPhase(any())).thenAnswer { cryptoRequest = it.getArgument(0) val postFix = ByteString.copyFromUtf8("-completeReachOnlyExecutionPhase") - CompleteReachOnlyExecutionPhaseResponse.newBuilder() - .apply { - combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) - serializedExcessiveNoiseCiphertext = ByteString.copyFromUtf8("-partiallyDecryptedNoise") - } - .build() + completeReachOnlyExecutionPhaseResponse { + combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) + serializedExcessiveNoiseCiphertext = ByteString.copyFromUtf8("-partiallyDecryptedNoise") + } } // Stage 1. Process the above computation @@ -1573,20 +1583,19 @@ class ReachOnlyLiquidLegionsV2MillTest { val blobKey = calculatedBlobContext.blobKey assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = COMPLETE.toProtocolStage() - version = 3 // claimTask + writeOutputBlob + transitionStage - computationDetails = - NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { endingState = CompletedReason.SUCCEEDED } - .build() - addAllRequisitions(REQUISITIONS) + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 3 // claimTask + writeOutputBlob + transitionStage + computationDetails = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = NON_AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.SUCCEEDED } - .build() + requisitions.addAll(REQUISITIONS) + } ) assertThat(computationStore.get(blobKey)?.readToString()) .isEqualTo("data-completeReachOnlyExecutionPhase-partiallyDecryptedNoise") @@ -1634,20 +1643,19 @@ class ReachOnlyLiquidLegionsV2MillTest { // Stage 2. Check the status of the computation assertThat(fakeComputationDb[LOCAL_ID]!!) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = COMPLETE.toProtocolStage() - version = 2 // claimTask + transitionStage - computationDetails = - NON_AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { endingState = CompletedReason.FAILED } - .build() - addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = NON_AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.FAILED } - .build() + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } ) verifyProtoArgument( @@ -1656,21 +1664,20 @@ class ReachOnlyLiquidLegionsV2MillTest { ) .comparingExpectedFieldsOnly() .isEqualTo( - FailComputationParticipantRequest.newBuilder() - .apply { - name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() - failureBuilder.apply { + failComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failure = + ComputationParticipantKt.failure { participantChildReferenceId = MILL_ID errorMessage = "PERMANENT error: Invalid input blob size. Input blob data has size 4 which is less than (66)." - stageAttemptBuilder.apply { + this.stageAttempt = stageAttempt { stage = EXECUTION_PHASE.number stageName = EXECUTION_PHASE.name attemptNumber = 1 } } - } - .build() + } ) } @@ -1711,10 +1718,11 @@ class ReachOnlyLiquidLegionsV2MillTest { attempt = 1 computationStage = COMPLETE.toProtocolStage() version = 2 // claimTask + transitionStage - computationDetails = - AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { endingState = CompletedReason.SUCCEEDED } - .build() + computationDetails = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.SUCCEEDED + } requisitions += REQUISITIONS } ) @@ -1755,9 +1763,7 @@ class ReachOnlyLiquidLegionsV2MillTest { var cryptoRequest = CompleteReachOnlyExecutionPhaseAtAggregatorRequest.getDefaultInstance() whenever(mockCryptoWorker.completeReachOnlyExecutionPhaseAtAggregator(any())).thenAnswer { cryptoRequest = it.getArgument(0) - CompleteReachOnlyExecutionPhaseAtAggregatorResponse.newBuilder() - .apply { reach = testReach } - .build() + completeReachOnlyExecutionPhaseAtAggregatorResponse { reach = testReach } } var systemComputationResult = SetComputationResultRequest.getDefaultInstance() whenever(mockSystemComputations.setComputationResult(any())).thenAnswer { @@ -1772,21 +1778,20 @@ class ReachOnlyLiquidLegionsV2MillTest { val blobKey = calculatedBlobContext.blobKey assertThat(fakeComputationDb[LOCAL_ID]) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = COMPLETE.toProtocolStage() - version = 3 // claimTask + writeOutputBlob + transitionStage - computationDetails = - computationDetailsWithVidSamplingWidth - .toBuilder() - .apply { endingState = CompletedReason.SUCCEEDED } - .build() - addAllRequisitions(REQUISITIONS) + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 3 // claimTask + writeOutputBlob + transitionStage + computationDetails = computationDetails { + kingdomComputation = computationDetailsWithVidSamplingWidth.kingdomComputation + reachOnlyLiquidLegionsV2 = + computationDetailsWithVidSamplingWidth.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.SUCCEEDED } - .build() + requisitions.addAll(REQUISITIONS) + } ) assertThat(computationStore.get(blobKey)?.readToString()).isNotEmpty() @@ -1866,20 +1871,19 @@ class ReachOnlyLiquidLegionsV2MillTest { // Stage 2. Check the status of the computation assertThat(fakeComputationDb[LOCAL_ID]!!) .isEqualTo( - ComputationToken.newBuilder() - .apply { - globalComputationId = GLOBAL_ID - localComputationId = LOCAL_ID - attempt = 1 - computationStage = COMPLETE.toProtocolStage() - version = 2 // claimTask + transitionStage - computationDetails = - AGGREGATOR_COMPUTATION_DETAILS.toBuilder() - .apply { endingState = CompletedReason.FAILED } - .build() - addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.FAILED } - .build() + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } ) verifyProtoArgument( @@ -1888,21 +1892,20 @@ class ReachOnlyLiquidLegionsV2MillTest { ) .comparingExpectedFieldsOnly() .isEqualTo( - FailComputationParticipantRequest.newBuilder() - .apply { - name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() - failureBuilder.apply { + failComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failure = + ComputationParticipantKt.failure { participantChildReferenceId = MILL_ID errorMessage = "PERMANENT error: Invalid input blob size. Input blob data has size 4 which is less than (66)." - stageAttemptBuilder.apply { + this.stageAttempt = stageAttempt { stage = EXECUTION_PHASE.number stageName = EXECUTION_PHASE.name attemptNumber = 1 } } - } - .build() + } ) } } diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt index 8fa395c88f5..5f0efed089c 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt @@ -21,24 +21,25 @@ import kotlin.test.assertFailsWith import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 -import org.wfanet.anysketch.Sketch -import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest +import org.wfanet.anysketch.SketchKt import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysResponse -import org.wfanet.anysketch.crypto.EncryptSketchRequest import org.wfanet.anysketch.crypto.EncryptSketchRequest.DestroyedRegisterStrategy.FLAGGED_KEY import org.wfanet.anysketch.crypto.EncryptSketchResponse import org.wfanet.anysketch.crypto.SketchEncrypterAdapter +import org.wfanet.anysketch.crypto.combineElGamalPublicKeysRequest +import org.wfanet.anysketch.crypto.encryptSketchRequest +import org.wfanet.anysketch.sketch import org.wfanet.estimation.Estimators import org.wfanet.measurement.common.loadLibrary import org.wfanet.measurement.duchy.daemon.utils.toAnySketchElGamalPublicKey import org.wfanet.measurement.duchy.daemon.utils.toCmmsElGamalPublicKey import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorResponse import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseResponse -import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseResponse import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorRequest import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyInitializationPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlySetupPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters import org.wfanet.measurement.internal.duchy.protocol.reachonlyliquidlegionsv2.ReachOnlyLiquidLegionsV2EncryptionUtility @@ -46,14 +47,6 @@ import org.wfanet.measurement.internal.duchy.protocol.reachonlyliquidlegionsv2.R @RunWith(JUnit4::class) class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { - private fun createEmptyReachOnlyLiquidLegionsSketch(): Sketch.Builder { - return Sketch.newBuilder() - } - - private fun Sketch.Builder.addRegister(index: Long) { - addRegistersBuilder().also { it.index = index } - } - // Helper function to go through the entire Liquid Legions V2 protocol using the input data. // The final relative_frequency_distribution map are returned. private fun goThroughEntireMpcProtocol( @@ -163,26 +156,20 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { @Test fun endToEnd_basicBehavior() { - val rawSketch = - createEmptyReachOnlyLiquidLegionsSketch() - .apply { - addRegister(index = 1L) - addRegister(index = 2L) - addRegister(index = 2L) - addRegister(index = 4L) - addRegister(index = 5L) - } - .build() - val request = - EncryptSketchRequest.newBuilder() - .apply { - sketch = rawSketch - curveId = CURVE_ID - maximumValue = MAX_COUNTER_VALUE - elGamalKeys = CLIENT_EL_GAMAL_KEYS.toAnySketchElGamalPublicKey() - destroyedRegisterStrategy = FLAGGED_KEY - } - .build() + val rawSketch = sketch { + registers.add(SketchKt.register { index = 1L }) + registers.add(SketchKt.register { index = 2L }) + registers.add(SketchKt.register { index = 2L }) + registers.add(SketchKt.register { index = 3L }) + registers.add(SketchKt.register { index = 4L }) + } + val request = encryptSketchRequest { + sketch = rawSketch + curveId = CURVE_ID + maximumValue = MAX_COUNTER_VALUE + elGamalKeys = CLIENT_EL_GAMAL_KEYS.toAnySketchElGamalPublicKey() + destroyedRegisterStrategy = FLAGGED_KEY + } val response = EncryptSketchResponse.parseFrom(SketchEncrypterAdapter.EncryptSketch(request.toByteArray())) val encryptedSketch = response.encryptedSketch @@ -252,8 +239,9 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { private const val PARALLELISM = 3 private const val MAX_COUNTER_VALUE = 10 - private val COMPLETE_INITIALIZATION_REQUEST = - CompleteReachOnlyInitializationPhaseRequest.newBuilder().apply { curveId = CURVE_ID }.build() + private val COMPLETE_INITIALIZATION_REQUEST = completeReachOnlyInitializationPhaseRequest { + curveId = CURVE_ID + } private val DUCHY_1_EL_GAMAL_KEYS = CompleteReachOnlyInitializationPhaseResponse.parseFrom( ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyInitializationPhase( @@ -279,14 +267,12 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { private val CLIENT_EL_GAMAL_KEYS = CombineElGamalPublicKeysResponse.parseFrom( SketchEncrypterAdapter.CombineElGamalPublicKeys( - CombineElGamalPublicKeysRequest.newBuilder() - .apply { + combineElGamalPublicKeysRequest { curveId = CURVE_ID - addElGamalKeys(DUCHY_1_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) - addElGamalKeys(DUCHY_2_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) - addElGamalKeys(DUCHY_3_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) + elGamalKeys.add(DUCHY_1_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) + elGamalKeys.add(DUCHY_2_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) + elGamalKeys.add(DUCHY_3_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) } - .build() .toByteArray() ) ) @@ -295,13 +281,11 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { private val DUCHY_2_3_COMBINED_EL_GAMAL_KEYS = CombineElGamalPublicKeysResponse.parseFrom( SketchEncrypterAdapter.CombineElGamalPublicKeys( - CombineElGamalPublicKeysRequest.newBuilder() - .apply { + combineElGamalPublicKeysRequest { curveId = CURVE_ID - addElGamalKeys(DUCHY_2_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) - addElGamalKeys(DUCHY_3_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) + elGamalKeys.add(DUCHY_2_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) + elGamalKeys.add(DUCHY_3_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) } - .build() .toByteArray() ) ) From 3ec6f0a0fd2383aae74f8927ed46b9fefe83a85b Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Mon, 14 Aug 2023 18:20:35 +0000 Subject: [PATCH 14/15] Minor formatting the code. --- ...nlyLiquidLegionsV2EncryptionUtilityTest.kt | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt index 5f0efed089c..82a81134d04 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt @@ -157,11 +157,11 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { @Test fun endToEnd_basicBehavior() { val rawSketch = sketch { - registers.add(SketchKt.register { index = 1L }) - registers.add(SketchKt.register { index = 2L }) - registers.add(SketchKt.register { index = 2L }) - registers.add(SketchKt.register { index = 3L }) - registers.add(SketchKt.register { index = 4L }) + registers += SketchKt.register { index = 1L } + registers += SketchKt.register { index = 2L }) + registers += SketchKt.register { index = 2L }) + registers += SketchKt.register { index = 3L }) + registers += SketchKt.register { index = 4L }) } val request = encryptSketchRequest { sketch = rawSketch @@ -269,9 +269,9 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { SketchEncrypterAdapter.CombineElGamalPublicKeys( combineElGamalPublicKeysRequest { curveId = CURVE_ID - elGamalKeys.add(DUCHY_1_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) - elGamalKeys.add(DUCHY_2_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) - elGamalKeys.add(DUCHY_3_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) + elGamalKeys += DUCHY_1_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey() + elGamalKeys += DUCHY_2_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey() + elGamalKeys += DUCHY_3_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey() } .toByteArray() ) @@ -283,8 +283,8 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { SketchEncrypterAdapter.CombineElGamalPublicKeys( combineElGamalPublicKeysRequest { curveId = CURVE_ID - elGamalKeys.add(DUCHY_2_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) - elGamalKeys.add(DUCHY_3_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey()) + elGamalKeys += DUCHY_2_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey() + elGamalKeys += DUCHY_3_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey() } .toByteArray() ) From 93e40d5b7ddb4d1e70fcc098c37d0a762941a206 Mon Sep 17 00:00:00 2001 From: Phi Hung Le Date: Mon, 14 Aug 2023 18:24:59 +0000 Subject: [PATCH 15/15] Fix some typos. --- .../ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt index 82a81134d04..8e122c56429 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt @@ -158,10 +158,10 @@ class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { fun endToEnd_basicBehavior() { val rawSketch = sketch { registers += SketchKt.register { index = 1L } - registers += SketchKt.register { index = 2L }) - registers += SketchKt.register { index = 2L }) - registers += SketchKt.register { index = 3L }) - registers += SketchKt.register { index = 4L }) + registers += SketchKt.register { index = 2L } + registers += SketchKt.register { index = 2L } + registers += SketchKt.register { index = 3L } + registers += SketchKt.register { index = 4L } } val request = encryptSketchRequest { sketch = rawSketch