From 113a778daea2252df8c0df94d8960e002d75a534 Mon Sep 17 00:00:00 2001 From: Phi Date: Thu, 17 Aug 2023 00:47:23 +0000 Subject: [PATCH] Adding the class ReachOnlyLiquidLegionsV2Mill on the Duchy mill. (#1155) --- .../protocol/liquid_legions_v2/BUILD.bazel | 17 + ...d_legions_v2_encryption_utility_wrapper.cc | 62 + ...id_legions_v2_encryption_utility_wrapper.h | 45 + .../daemon/herald/LiquidLegionsV2Starter.kt | 2 +- .../daemon/mill/liquidlegionsv2/BUILD.bazel | 38 + .../liquidlegionsv2/LiquidLegionsV2Mill.kt | 4 +- .../ReachOnlyLiquidLegionsV2Mill.kt | 790 +++++++ .../mill/liquidlegionsv2/crypto/BUILD.bazel | 21 +- .../JniReachOnlyLiquidLegionsV2Encryption.kt | 107 + .../ReachOnlyLiquidLegionsV2Encryption.kt | 58 + .../testing/FakeComputationsDatabase.kt | 2 + ...liquid_legions_sketch_aggregation_v2.proto | 3 +- .../reachonlyliquidlegionsv2/BUILD.bazel | 16 + .../reachonlyliquidlegionsv2/README.md | 22 + ..._liquid_legions_v2_encryption_utility.swig | 67 + .../duchy/daemon/herald/HeraldTest.kt | 4 +- .../daemon/mill/liquidlegionsv2/BUILD.bazel | 30 + .../LiquidLegionsV2MillTest.kt | 23 +- .../ReachOnlyLiquidLegionsV2MillTest.kt | 1930 +++++++++++++++++ .../mill/liquidlegionsv2/crypto/BUILD.bazel | 32 + ...iReachOnlyLiquidLegionsV2EncryptionTest.kt | 43 + ...nlyLiquidLegionsV2EncryptionUtilityTest.kt | 295 +++ 22 files changed, 3590 insertions(+), 21 deletions(-) create mode 100644 src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.cc create mode 100644 src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h create mode 100644 src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt create mode 100644 src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2Encryption.kt create mode 100644 src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2Encryption.kt create mode 100644 src/main/swig/protocol/reachonlyliquidlegionsv2/BUILD.bazel create mode 100644 src/main/swig/protocol/reachonlyliquidlegionsv2/README.md create mode 100644 src/main/swig/protocol/reachonlyliquidlegionsv2/reach_only_liquid_legions_v2_encryption_utility.swig create mode 100644 src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt create mode 100644 src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2EncryptionTest.kt create mode 100644 src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel index 79de5b8dfca..0dda2d3882e 100644 --- a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/BUILD.bazel @@ -81,6 +81,23 @@ cc_library( ], ) +cc_library( + name = "reach_only_liquid_legions_v2_encryption_utility_wrapper", + srcs = [ + "reach_only_liquid_legions_v2_encryption_utility_wrapper.cc", + ], + hdrs = [ + "reach_only_liquid_legions_v2_encryption_utility_wrapper.h", + ], + strip_include_prefix = _INCLUDE_PREFIX, + deps = [ + ":reach_only_liquid_legions_v2_encryption_utility", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_cc_proto", + "@wfa_common_cpp//src/main/cc/common_cpp/jni:jni_wrap", + "@wfa_common_cpp//src/main/cc/common_cpp/macros", + ], +) + cc_library( name = "noise_parameters_computation", srcs = [ diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.cc b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.cc new file mode 100644 index 00000000000..8a1d995ba58 --- /dev/null +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.cc @@ -0,0 +1,62 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h" + +#include + +#include "absl/status/statusor.h" +#include "common_cpp/jni/jni_wrap.h" +#include "common_cpp/macros/macros.h" +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility.h" +#include "wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_v2_encryption_methods.pb.h" + +namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 { + +absl::StatusOr CompleteReachOnlyInitializationPhase( + const std::string& serialized_request) { + return JniWrap( + serialized_request, CompleteReachOnlyInitializationPhase); +} + +absl::StatusOr CompleteReachOnlySetupPhase( + const std::string& serialized_request) { + return JniWrap( + serialized_request, CompleteReachOnlySetupPhase); +} + +absl::StatusOr CompleteReachOnlySetupPhaseAtAggregator( + const std::string& serialized_request) { + return JniWrap( + serialized_request, CompleteReachOnlySetupPhaseAtAggregator); +} + +absl::StatusOr CompleteReachOnlyExecutionPhase( + const std::string& serialized_request) { + return JniWrap( + serialized_request, CompleteReachOnlyExecutionPhase); +} + +absl::StatusOr CompleteReachOnlyExecutionPhaseAtAggregator( + const std::string& serialized_request) { + return JniWrap( + serialized_request, CompleteReachOnlyExecutionPhaseAtAggregator); +} + +} // namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 diff --git a/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h new file mode 100644 index 00000000000..0b940ce936f --- /dev/null +++ b/src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h @@ -0,0 +1,45 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_REACH_ONLY_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_WRAPPER_H_ +#define SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_REACH_ONLY_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_WRAPPER_H_ + +#include + +#include "absl/status/statusor.h" + +// Wrapper methods used to generate the swig/JNI Java classes. +// The only functionality of these methods are converting between proto messages +// and their corresponding serialized strings, and then calling into the +// reach_only_liquid_legions_v2_encryption_utility methods. +namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 { + +absl::StatusOr CompleteReachOnlyInitializationPhase( + const std::string& serialized_request); + +absl::StatusOr CompleteReachOnlySetupPhase( + const std::string& serialized_request); + +absl::StatusOr CompleteReachOnlySetupPhaseAtAggregator( + const std::string& serialized_request); + +absl::StatusOr CompleteReachOnlyExecutionPhase( + const std::string& serialized_request); + +absl::StatusOr CompleteReachOnlyExecutionPhaseAtAggregator( + const std::string& serialized_request); + +} // namespace wfa::measurement::internal::duchy::protocol::liquid_legions_v2 + +#endif // SRC_MAIN_CC_WFA_MEASUREMENT_INTERNAL_DUCHY_PROTOCOL_LIQUID_LEGIONS_V2_REACH_ONLY_LIQUID_LEGIONS_V2_ENCRYPTION_UTILITY_WRAPPER_H_ diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt index 00ee53c174b..aef0566817d 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt @@ -298,7 +298,7 @@ object LiquidLegionsV2Starter { return parameters { maximumFrequency = llv2Config.maximumFrequency - liquidLegionsSketch = liquidLegionsSketchParameters { + sketchParameters = liquidLegionsSketchParameters { decayRate = llv2Config.sketchParams.decayRate size = llv2Config.sketchParams.maxSize } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel index 98a1b9bbf23..442c31b904d 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel @@ -51,3 +51,41 @@ kt_jvm_library( "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/duchy", ], ) + +kt_jvm_library( + name = "reach_only_liquid_legions_v2_mill", + testonly = True, #TODO: delete when InMemoryKeyStore and FakeHybridCipher are not used. + srcs = [ + "ReachOnlyLiquidLegionsV2Mill.kt", + ], + runtime_deps = ["@wfa_common_jvm//imports/java/io/grpc/netty"], + deps = [ + "//imports/java/io/opentelemetry/api", + "//src/main/kotlin/org/wfanet/measurement/api:public_api_version", + "//src/main/kotlin/org/wfanet/measurement/common/identity", + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill:mill_base", + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto:reachonlyliquidlegionsv2encryption", + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/utils:computation_conversions", + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/utils:duchy_order", + "//src/main/kotlin/org/wfanet/measurement/duchy/db/computation", + "//src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha:advance_computation_request_headers", + "//src/main/kotlin/org/wfanet/measurement/system/v1alpha:resource_key", + "//src/main/proto/wfa/any_sketch/crypto:sketch_encryption_methods_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy:crypto_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy:differential_privacy_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/config:protocols_setup_config_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_sketch_parameter_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_noise_config_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_kt_jvm_proto", + "//src/main/proto/wfa/measurement/system/v1alpha:computation_control_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/system/v1alpha:computation_participants_service_kt_jvm_grpc_proto", + "//src/main/swig/protocol/reachonlyliquidlegionsv2:reach_only_liquid_legions_v2_encryption_utility", + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/java/io/grpc:api", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/throttler", + "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/duchy", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2Mill.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2Mill.kt index fc51824e5a7..488301d7e4c 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2Mill.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2Mill.kt @@ -673,8 +673,8 @@ class LiquidLegionsV2Mill( flagCountTuples = readAndCombineAllInputBlobs(token, 1) maximumFrequency = maximumRequestedFrequency liquidLegionsParameters = liquidLegionsSketchParameters { - decayRate = llv2Parameters.liquidLegionsSketch.decayRate - size = llv2Parameters.liquidLegionsSketch.size + decayRate = llv2Parameters.sketchParameters.decayRate + size = llv2Parameters.sketchParameters.size } vidSamplingIntervalWidth = measurementSpec.vidSamplingInterval.width if (llv2Parameters.noise.hasReachNoiseConfig()) { diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt new file mode 100644 index 00000000000..c47ec1c8433 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt @@ -0,0 +1,790 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2 + +import com.google.protobuf.ByteString +import io.opentelemetry.api.OpenTelemetry +import io.opentelemetry.api.metrics.LongHistogram +import io.opentelemetry.api.metrics.Meter +import java.security.SignatureException +import java.security.cert.CertPathValidatorException +import java.security.cert.X509Certificate +import java.time.Clock +import java.time.Duration +import java.util.logging.Logger +import org.wfanet.anysketch.crypto.combineElGamalPublicKeysRequest +import org.wfanet.measurement.api.Version +import org.wfanet.measurement.api.v2alpha.MeasurementSpec +import org.wfanet.measurement.common.crypto.SigningKeyHandle +import org.wfanet.measurement.common.crypto.readCertificate +import org.wfanet.measurement.common.identity.DuchyInfo +import org.wfanet.measurement.common.throttler.MinimumIntervalThrottler +import org.wfanet.measurement.consent.client.duchy.encryptResult +import org.wfanet.measurement.consent.client.duchy.signElgamalPublicKey +import org.wfanet.measurement.consent.client.duchy.signResult +import org.wfanet.measurement.consent.client.duchy.verifyDataProviderParticipation +import org.wfanet.measurement.consent.client.duchy.verifyElGamalPublicKey +import org.wfanet.measurement.duchy.daemon.mill.CRYPTO_LIB_CPU_DURATION +import org.wfanet.measurement.duchy.daemon.mill.Certificate +import org.wfanet.measurement.duchy.daemon.mill.MillBase +import org.wfanet.measurement.duchy.daemon.mill.PermanentComputationError +import org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto.ReachOnlyLiquidLegionsV2Encryption +import org.wfanet.measurement.duchy.daemon.utils.ComputationResult +import org.wfanet.measurement.duchy.daemon.utils.ReachResult +import org.wfanet.measurement.duchy.daemon.utils.toAnySketchElGamalPublicKey +import org.wfanet.measurement.duchy.daemon.utils.toCmmsElGamalPublicKey +import org.wfanet.measurement.duchy.daemon.utils.toV2AlphaElGamalPublicKey +import org.wfanet.measurement.duchy.daemon.utils.toV2AlphaEncryptionPublicKey +import org.wfanet.measurement.duchy.db.computation.BlobRef +import org.wfanet.measurement.duchy.db.computation.ComputationDataClients +import org.wfanet.measurement.duchy.service.internal.computations.outputPathList +import org.wfanet.measurement.duchy.service.system.v1alpha.advanceComputationHeader +import org.wfanet.measurement.duchy.toProtocolStage +import org.wfanet.measurement.internal.duchy.ComputationDetails.CompletedReason +import org.wfanet.measurement.internal.duchy.ComputationDetails.KingdomComputationDetails +import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt.ComputationStatsCoroutineStub +import org.wfanet.measurement.internal.duchy.ComputationToken +import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType +import org.wfanet.measurement.internal.duchy.ElGamalPublicKey +import org.wfanet.measurement.internal.duchy.RequisitionMetadata +import org.wfanet.measurement.internal.duchy.computationDetails +import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation.AGGREGATOR +import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation.NON_AGGREGATOR +import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation.UNKNOWN +import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation.UNRECOGNIZED +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyInitializationPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.globalReachDpNoiseBaseline +import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters +import org.wfanet.measurement.internal.duchy.protocol.reachNoiseDifferentialPrivacyParams +import org.wfanet.measurement.internal.duchy.protocol.registerNoiseGenerationParameters +import org.wfanet.measurement.internal.duchy.updateComputationDetailsRequest +import org.wfanet.measurement.system.v1alpha.ComputationControlGrpcKt.ComputationControlCoroutineStub +import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.RequisitionParamsKt +import org.wfanet.measurement.system.v1alpha.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineStub +import org.wfanet.measurement.system.v1alpha.ComputationsGrpcKt +import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2 +import org.wfanet.measurement.system.v1alpha.confirmComputationParticipantRequest +import org.wfanet.measurement.system.v1alpha.reachOnlyLiquidLegionsV2 +import org.wfanet.measurement.system.v1alpha.setParticipantRequisitionParamsRequest + +/** + * Mill works on computations using the ReachOnlyLiquidLegionSketchAggregationProtocol. + * + * @param millId The identifier of this mill, used to claim a work. + * @param duchyId The identifier of this duchy who owns this mill. + * @param signingKey handle to a signing private key for consent signaling. + * @param consentSignalCert The [Certificate] used for consent signaling. + * @param trustedCertificates [Map] of SKID to trusted certificate + * @param dataClients clients that have access to local computation storage, i.e., spanner table and + * blob store. + * @param systemComputationParticipantsClient client of the kingdom's system + * ComputationParticipantsService. + * @param systemComputationsClient client of the kingdom's system computationsService. + * @param systemComputationLogEntriesClient client of the kingdom's system + * computationLogEntriesService. + * @param computationStatsClient client of the duchy's internal ComputationStatsService. + * @param throttler A throttler used to rate limit the frequency of the mill polling from the + * computation table. + * @param requestChunkSizeBytes The size of data chunk when sending result to other duchies. + * @param clock A clock + * @param maximumAttempts The maximum number of attempts on a computation at the same stage. + * @param workerStubs A map from other duchies' Ids to their corresponding + * computationControlClients, used for passing computation to other duchies. + * @param cryptoWorker The cryptoWorker that performs the actual computation. + * @param parallelism The maximum number of threads used for crypto actions. + */ +class ReachOnlyLiquidLegionsV2Mill( + millId: String, + duchyId: String, + signingKey: SigningKeyHandle, + consentSignalCert: Certificate, + private val trustedCertificates: Map, + dataClients: ComputationDataClients, + systemComputationParticipantsClient: ComputationParticipantsCoroutineStub, + systemComputationsClient: ComputationsGrpcKt.ComputationsCoroutineStub, + systemComputationLogEntriesClient: ComputationLogEntriesCoroutineStub, + computationStatsClient: ComputationStatsCoroutineStub, + throttler: MinimumIntervalThrottler, + private val workerStubs: Map, + private val cryptoWorker: ReachOnlyLiquidLegionsV2Encryption, + workLockDuration: Duration, + openTelemetry: OpenTelemetry, + requestChunkSizeBytes: Int = 1024 * 32, + maximumAttempts: Int = 10, + clock: Clock = Clock.systemUTC(), + private val parallelism: Int = 1, +) : + MillBase( + millId, + duchyId, + signingKey, + consentSignalCert, + dataClients, + systemComputationParticipantsClient, + systemComputationsClient, + systemComputationLogEntriesClient, + computationStatsClient, + throttler, + ComputationType.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2, + workLockDuration, + requestChunkSizeBytes, + maximumAttempts, + clock, + openTelemetry + ) { + private val meter: Meter = openTelemetry.getMeter(ReachOnlyLiquidLegionsV2Mill::class.java.name) + + private val initializationPhaseCryptoCpuTimeDurationHistogram: LongHistogram = + meter.histogramBuilder("initialization_phase_crypto_cpu_time_duration_millis").ofLongs().build() + + private val setupPhaseCryptoCpuTimeDurationHistogram: LongHistogram = + meter.histogramBuilder("setup_phase_crypto_cpu_time_duration_millis").ofLongs().build() + + private val executionPhaseCryptoCpuTimeDurationHistogram: LongHistogram = + meter.histogramBuilder("execution_phase_crypto_cpu_time_duration_millis").ofLongs().build() + + override val endingStage = + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE.toProtocolStage() + + private val actions = + mapOf( + Pair(Stage.INITIALIZATION_PHASE, AGGREGATOR) to ::initializationPhase, + Pair(Stage.INITIALIZATION_PHASE, NON_AGGREGATOR) to ::initializationPhase, + Pair(Stage.CONFIRMATION_PHASE, AGGREGATOR) to ::confirmationPhase, + Pair(Stage.CONFIRMATION_PHASE, NON_AGGREGATOR) to ::confirmationPhase, + Pair(Stage.SETUP_PHASE, AGGREGATOR) to ::completeSetupPhaseAtAggregator, + Pair(Stage.SETUP_PHASE, NON_AGGREGATOR) to ::completeSetupPhaseAtNonAggregator, + Pair(Stage.EXECUTION_PHASE, AGGREGATOR) to ::completeExecutionPhaseAtAggregator, + Pair(Stage.EXECUTION_PHASE, NON_AGGREGATOR) to ::completeExecutionPhaseAtNonAggregator, + ) + + private val kBytesPerCipherText = 66 + + override suspend fun processComputationImpl(token: ComputationToken) { + require(token.computationDetails.hasReachOnlyLiquidLegionsV2()) { + "Only Reach Only Liquid Legions V2 computation is supported in this mill." + } + val stage = token.computationStage.reachOnlyLiquidLegionsSketchAggregationV2 + val role = token.computationDetails.reachOnlyLiquidLegionsV2.role + val action = actions[Pair(stage, role)] ?: error("Unexpected stage or role: ($stage, $role)") + val updatedToken = action(token) + + val globalId = token.globalComputationId + val updatedStage = updatedToken.computationStage.reachOnlyLiquidLegionsSketchAggregationV2 + logger.info("$globalId@$millId: Stage transitioned from $stage to $updatedStage") + } + + /** Sends requisition params to the kingdom. */ + private suspend fun sendRequisitionParamsToKingdom(token: ComputationToken) { + val rollv2ComputationDetails = token.computationDetails.reachOnlyLiquidLegionsV2 + require(rollv2ComputationDetails.hasLocalElgamalKey()) { "Missing local elgamal key." } + val signedElgamalPublicKey = + when (Version.fromString(token.computationDetails.kingdomComputation.publicApiVersion)) { + Version.V2_ALPHA -> + signElgamalPublicKey( + rollv2ComputationDetails.localElgamalKey.publicKey.toV2AlphaElGamalPublicKey(), + signingKey + ) + Version.VERSION_UNSPECIFIED -> error("Public api version is invalid or unspecified.") + } + + val request = setParticipantRequisitionParamsRequest { + name = ComputationParticipantKey(token.globalComputationId, duchyId).toName() + requisitionParams = + ComputationParticipantKt.requisitionParams { + duchyCertificate = consentSignalCert.name + reachOnlyLiquidLegionsV2 = + ComputationParticipantKt.RequisitionParamsKt.liquidLegionsV2 { + elGamalPublicKey = signedElgamalPublicKey.data + elGamalPublicKeySignature = signedElgamalPublicKey.signature + } + } + } + systemComputationParticipantsClient.setParticipantRequisitionParams(request) + } + + /** Processes computation in the initialization phase */ + private suspend fun initializationPhase(token: ComputationToken): ComputationToken { + val rollv2ComputationDetails = token.computationDetails.reachOnlyLiquidLegionsV2 + val ellipticCurveId = rollv2ComputationDetails.parameters.ellipticCurveId + require(ellipticCurveId > 0) { "invalid ellipticCurveId $ellipticCurveId" } + + val nextToken = + if (rollv2ComputationDetails.hasLocalElgamalKey()) { + // Reuses the key if it is already generated for this computation. + token + } else { + // Generates a new set of ElGamalKeyPair. + val request = completeReachOnlyInitializationPhaseRequest { + curveId = ellipticCurveId.toLong() + } + val cryptoResult = cryptoWorker.completeReachOnlyInitializationPhase(request) + logStageDurationMetric( + token, + CRYPTO_LIB_CPU_DURATION, + cryptoResult.elapsedCpuTimeMillis, + initializationPhaseCryptoCpuTimeDurationHistogram + ) + + // Updates the newly generated localElgamalKey to the ComputationDetails. + dataClients.computationsClient + .updateComputationDetails( + updateComputationDetailsRequest { + this.token = token + this.details = computationDetails { + this.blobsStoragePrefix = token.computationDetails.blobsStoragePrefix + this.endingState = token.computationDetails.endingState + this.kingdomComputation = token.computationDetails.kingdomComputation + this.reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + this.role = token.computationDetails.reachOnlyLiquidLegionsV2.role + this.parameters = token.computationDetails.reachOnlyLiquidLegionsV2.parameters + participant.addAll( + token.computationDetails.reachOnlyLiquidLegionsV2.participantList + ) + this.localElgamalKey = cryptoResult.elGamalKeyPair + } + } + } + ) + .token + } + + sendRequisitionParamsToKingdom(nextToken) + + return dataClients.transitionComputationToStage( + nextToken, + stage = Stage.WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage() + ) + } + + /** + * Verifies that all EDPs have participated. + * + * @return a list of error messages if anything is wrong, otherwise an empty list. + */ + private fun verifyEdpParticipation( + details: KingdomComputationDetails, + requisitions: Iterable, + ): List { + when (Version.fromString(details.publicApiVersion)) { + Version.V2_ALPHA -> {} + Version.VERSION_UNSPECIFIED -> error("Public api version is invalid or unspecified.") + } + + val errorList = mutableListOf() + val measurementSpec = MeasurementSpec.parseFrom(details.measurementSpec) + if (!verifyDataProviderParticipation(measurementSpec, requisitions.map { it.details.nonce })) { + errorList.add("Cannot verify participation of all DataProviders.") + } + for (requisition in requisitions) { + if (requisition.details.externalFulfillingDuchyId == duchyId && requisition.path.isBlank()) { + errorList.add( + "Missing expected data for requisition ${requisition.externalKey.externalRequisitionId}." + ) + } + } + return errorList + } + + /** + * Verifies the ElGamal public key of [duchy]. + * + * @return the error message if verification fails, or else `null` + */ + private fun verifyDuchySignature( + duchy: ComputationParticipant, + publicApiVersion: Version + ): String? { + val duchyInfo: DuchyInfo.Entry = + requireNotNull(DuchyInfo.getByDuchyId(duchy.duchyId)) { + "DuchyInfo not found for ${duchy.duchyId}" + } + when (publicApiVersion) { + Version.V2_ALPHA -> { + try { + verifyElGamalPublicKey( + duchy.elGamalPublicKey, + duchy.elGamalPublicKeySignature, + readCertificate(duchy.duchyCertificateDer), + trustedCertificates.getValue(duchyInfo.rootCertificateSkid) + ) + } catch (e: CertPathValidatorException) { + return "Certificate path invalid for Duchy ${duchy.duchyId}" + } catch (e: SignatureException) { + return "Invalid ElGamal public key signature for Duchy ${duchy.duchyId}" + } + } + Version.VERSION_UNSPECIFIED -> error("Public api version is invalid or unspecified.") + } + return null + } + + /** Fails a computation both locally and at the kingdom when the confirmation fails. */ + private fun failComputationAtConfirmationPhase( + token: ComputationToken, + errorList: List + ): ComputationToken { + val errorMessage = + "@Mill $millId, Computation ${token.globalComputationId} failed due to:\n" + + errorList.joinToString(separator = "\n") + throw PermanentComputationError(Exception(errorMessage)) + } + + private fun List.toCombinedPublicKey(curveId: Int): ElGamalPublicKey { + val request = combineElGamalPublicKeysRequest { + this.curveId = curveId.toLong() + this.elGamalKeys += map { it.toAnySketchElGamalPublicKey() } + } + return cryptoWorker.combineElGamalPublicKeys(request).elGamalKeys.toCmmsElGamalPublicKey() + } + + /** + * Computes the fully and partially combined Elgamal public keys and caches the result in the + * computationDetails. + */ + private suspend fun updatePublicElgamalKey(token: ComputationToken): ComputationToken { + val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 + val fullParticipantList = rollv2Details.participantList + val combinedPublicKey = + fullParticipantList + .map { it.publicKey } + .toCombinedPublicKey(rollv2Details.parameters.ellipticCurveId) + + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + val partiallyCombinedPublicKey = + when (rollv2Details.role) { + // For aggregator, the partial list is the same as the full list. + AGGREGATOR -> combinedPublicKey + NON_AGGREGATOR -> + fullParticipantList + .subList( + fullParticipantList.indexOfFirst { it.duchyId == duchyId } + 1, + fullParticipantList.size + ) + .map { it.publicKey } + .toCombinedPublicKey(rollv2Details.parameters.ellipticCurveId) + UNKNOWN, + UNRECOGNIZED -> error("Invalid role ${rollv2Details.role}") + } + + return dataClients.computationsClient + .updateComputationDetails( + updateComputationDetailsRequest { + this.token = token + details = computationDetails { + this.blobsStoragePrefix = token.computationDetails.blobsStoragePrefix + this.endingState = token.computationDetails.endingState + this.kingdomComputation = token.computationDetails.kingdomComputation + this.reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + this.role = token.computationDetails.reachOnlyLiquidLegionsV2.role + this.parameters = token.computationDetails.reachOnlyLiquidLegionsV2.parameters + token.computationDetails.reachOnlyLiquidLegionsV2.participantList.forEach { + this.participant += it + } + this.localElgamalKey = + token.computationDetails.reachOnlyLiquidLegionsV2.localElgamalKey + this.combinedPublicKey = combinedPublicKey + this.partiallyCombinedPublicKey = partiallyCombinedPublicKey + } + } + } + ) + .token + } + + /** Sends confirmation to the kingdom and transits the local computation to the next stage. */ + private suspend fun passConfirmationPhase(token: ComputationToken): ComputationToken { + systemComputationParticipantsClient.confirmComputationParticipant( + confirmComputationParticipantRequest { + name = ComputationParticipantKey(token.globalComputationId, duchyId).toName() + } + ) + val latestToken = updatePublicElgamalKey(token) + return dataClients.transitionComputationToStage( + latestToken, + stage = + when (checkNotNull(token.computationDetails.reachOnlyLiquidLegionsV2.role)) { + AGGREGATOR -> Stage.WAIT_SETUP_PHASE_INPUTS.toProtocolStage() + NON_AGGREGATOR -> Stage.WAIT_TO_START.toProtocolStage() + else -> + error("Unknown role: ${latestToken.computationDetails.reachOnlyLiquidLegionsV2.role}") + } + ) + } + + /** Processes computation in the confirmation phase */ + private suspend fun confirmationPhase(token: ComputationToken): ComputationToken { + val errorList = mutableListOf() + val kingdomComputation = token.computationDetails.kingdomComputation + errorList.addAll(verifyEdpParticipation(kingdomComputation, token.requisitionsList)) + token.computationDetails.reachOnlyLiquidLegionsV2.participantList.forEach { + verifyDuchySignature(it, Version.fromString(kingdomComputation.publicApiVersion))?.also { + error -> + errorList.add(error) + } + } + return if (errorList.isEmpty()) { + passConfirmationPhase(token) + } else { + failComputationAtConfirmationPhase(token, errorList) + } + } + + private suspend fun completeSetupPhaseAtAggregator(token: ComputationToken): ComputationToken { + val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 + require(AGGREGATOR == rollv2Details.role) { "invalid role for this function." } + val (bytes, nextToken) = + existingOutputOr(token) { + val request = + dataClients + .readAllRequisitionBlobs(token, duchyId) + .concat(readAndCombineAllInputBlobsSetupPhaseAtAggregator(token, workerStubs.size)) + .toCompleteSetupPhaseAtAggregatorRequest(rollv2Details, token.requisitionsCount) + val cryptoResult: CompleteReachOnlySetupPhaseResponse = + cryptoWorker.completeReachOnlySetupPhaseAtAggregator(request) + logStageDurationMetric( + token, + CRYPTO_LIB_CPU_DURATION, + cryptoResult.elapsedCpuTimeMillis, + setupPhaseCryptoCpuTimeDurationHistogram + ) + // The nextToken consists of the CRV and the noise ciphertext. + cryptoResult.combinedRegisterVector.concat(cryptoResult.serializedExcessiveNoiseCiphertext) + } + + sendAdvanceComputationRequest( + header = + advanceComputationHeader( + ReachOnlyLiquidLegionsV2.Description.EXECUTION_PHASE_INPUT, + token.globalComputationId + ), + content = addLoggingHook(token, bytes), + stub = nextDuchyStub(rollv2Details.participantList) + ) + + return dataClients.transitionComputationToStage( + nextToken, + inputsToNextStage = nextToken.outputPathList(), + stage = Stage.WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + ) + } + + private suspend fun completeSetupPhaseAtNonAggregator(token: ComputationToken): ComputationToken { + val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 + require(NON_AGGREGATOR == rollv2Details.role) { "invalid role for this function." } + val (bytes, nextToken) = + existingOutputOr(token) { + val request = + dataClients + .readAllRequisitionBlobs(token, duchyId) + .toCompleteReachOnlySetupPhaseRequest(rollv2Details, token.requisitionsCount) + val cryptoResult: CompleteReachOnlySetupPhaseResponse = + cryptoWorker.completeReachOnlySetupPhase(request) + logStageDurationMetric( + token, + CRYPTO_LIB_CPU_DURATION, + cryptoResult.elapsedCpuTimeMillis, + setupPhaseCryptoCpuTimeDurationHistogram + ) + // The nextToken consists of the CRV and the noise ciphertext. + cryptoResult.combinedRegisterVector.concat(cryptoResult.serializedExcessiveNoiseCiphertext) + } + + sendAdvanceComputationRequest( + header = + advanceComputationHeader( + ReachOnlyLiquidLegionsV2.Description.SETUP_PHASE_INPUT, + token.globalComputationId + ), + content = addLoggingHook(token, bytes), + stub = aggregatorDuchyStub(rollv2Details.participantList.last().duchyId) + ) + + return dataClients.transitionComputationToStage( + nextToken, + inputsToNextStage = nextToken.outputPathList(), + stage = Stage.WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + ) + } + + private suspend fun completeExecutionPhaseAtAggregator( + token: ComputationToken + ): ComputationToken { + val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 + require(AGGREGATOR == rollv2Details.role) { "invalid role for this function." } + val rollv2Parameters = rollv2Details.parameters + val noiseConfig = rollv2Details.parameters.noise + val measurementSpec = + MeasurementSpec.parseFrom(token.computationDetails.kingdomComputation.measurementSpec) + val inputBlob = readAndCombineAllInputBlobs(token, 1) + require(inputBlob.size() >= kBytesPerCipherText) { + ("Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size " + + "${inputBlob.size()} which is less than ($kBytesPerCipherText).") + } + var reach = 0L + val (bytes, nextToken) = + existingOutputOr(token) { + val request = completeReachOnlyExecutionPhaseAtAggregatorRequest { + combinedRegisterVector = inputBlob.substring(0, inputBlob.size() - kBytesPerCipherText) + localElGamalKeyPair = rollv2Details.localElgamalKey + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + serializedExcessiveNoiseCiphertext = + inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size()) + if (rollv2Parameters.noise.hasReachNoiseConfig()) { + reachDpNoiseBaseline = globalReachDpNoiseBaseline { + contributorsCount = workerStubs.size + 1 + globalReachDpNoise = rollv2Parameters.noise.reachNoiseConfig.globalReachDpNoise + } + } + liquidLegionsParameters = liquidLegionsSketchParameters { + decayRate = rollv2Parameters.sketchParameters.decayRate + size = rollv2Parameters.sketchParameters.size + } + vidSamplingIntervalWidth = measurementSpec.vidSamplingInterval.width + if (noiseConfig.hasReachNoiseConfig()) { + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = rollv2Details.combinedPublicKey + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + contributorsCount = workerStubs.size + 1 + totalSketchesCount = token.requisitionsCount + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = noiseConfig.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = noiseConfig.reachNoiseConfig.globalReachDpNoise + } + } + noiseMechanism = rollv2Details.parameters.noise.noiseMechanism + } + parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism + } + val cryptoResult: CompleteReachOnlyExecutionPhaseAtAggregatorResponse = + cryptoWorker.completeReachOnlyExecutionPhaseAtAggregator(request) + logStageDurationMetric( + token, + CRYPTO_LIB_CPU_DURATION, + cryptoResult.elapsedCpuTimeMillis, + executionPhaseCryptoCpuTimeDurationHistogram + ) + reach = cryptoResult.reach + cryptoResult.toByteString() + } + + sendResultToKingdom(token, ReachResult(reach)) + return completeComputation(nextToken, CompletedReason.SUCCEEDED) + } + + private suspend fun completeExecutionPhaseAtNonAggregator( + token: ComputationToken + ): ComputationToken { + val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 + require(NON_AGGREGATOR == rollv2Details.role) { "invalid role for this function." } + val inputBlob = readAndCombineAllInputBlobs(token, 1) + require(inputBlob.size() >= kBytesPerCipherText) { + ("Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size " + + "${inputBlob.size()} which is less than ($kBytesPerCipherText).") + } + val (bytes, nextToken) = + existingOutputOr(token) { + val cryptoResult: CompleteReachOnlyExecutionPhaseResponse = + cryptoWorker.completeReachOnlyExecutionPhase( + completeReachOnlyExecutionPhaseRequest { + combinedRegisterVector = + inputBlob.substring(0, inputBlob.size() - kBytesPerCipherText) + localElGamalKeyPair = rollv2Details.localElgamalKey + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + serializedExcessiveNoiseCiphertext = + inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size()) + parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism + } + ) + logStageDurationMetric( + token, + CRYPTO_LIB_CPU_DURATION, + cryptoResult.elapsedCpuTimeMillis, + executionPhaseCryptoCpuTimeDurationHistogram + ) + cryptoResult.combinedRegisterVector.concat(cryptoResult.serializedExcessiveNoiseCiphertext) + } + + // Passes the computation to the next duchy. + sendAdvanceComputationRequest( + header = + advanceComputationHeader( + ReachOnlyLiquidLegionsV2.Description.EXECUTION_PHASE_INPUT, + token.globalComputationId + ), + content = addLoggingHook(token, bytes), + stub = nextDuchyStub(rollv2Details.participantList) + ) + + return completeComputation(nextToken, CompletedReason.SUCCEEDED) + } + + private suspend fun sendResultToKingdom( + token: ComputationToken, + computationResult: ComputationResult + ) { + val kingdomComputation = token.computationDetails.kingdomComputation + val serializedPublicApiEncryptionPublicKey: ByteString + val encryptedResult = + when (Version.fromString(kingdomComputation.publicApiVersion)) { + Version.V2_ALPHA -> { + val signedResult = signResult(computationResult.toV2AlphaMeasurementResult(), signingKey) + val publicApiEncryptionPublicKey = + kingdomComputation.measurementPublicKey.toV2AlphaEncryptionPublicKey() + serializedPublicApiEncryptionPublicKey = publicApiEncryptionPublicKey.toByteString() + encryptResult(signedResult, publicApiEncryptionPublicKey) + } + Version.VERSION_UNSPECIFIED -> error("Public api version is invalid or unspecified.") + } + sendResultToKingdom( + globalId = token.globalComputationId, + certificate = consentSignalCert, + resultPublicKey = serializedPublicApiEncryptionPublicKey, + encryptedResult = encryptedResult + ) + } + + private fun nextDuchyStub( + duchyList: List + ): ComputationControlCoroutineStub { + val index = duchyList.indexOfFirst { it.duchyId == duchyId } + val nextDuchy = duchyList[(index + 1) % duchyList.size].duchyId + return workerStubs[nextDuchy] + ?: throw PermanentComputationError( + IllegalArgumentException("No ComputationControlService stub for next duchy '$nextDuchy'") + ) + } + + private fun aggregatorDuchyStub(aggregatorId: String): ComputationControlCoroutineStub { + return workerStubs[aggregatorId] + ?: throw PermanentComputationError( + IllegalArgumentException( + "No ComputationControlService stub for the Aggregator duchy '$aggregatorId'" + ) + ) + } + + private fun ByteString.toCompleteReachOnlySetupPhaseRequest( + rollv2Details: ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails, + totalRequisitionsCount: Int + ): CompleteReachOnlySetupPhaseRequest { + val noiseConfig = rollv2Details.parameters.noise + return completeReachOnlySetupPhaseRequest { + combinedRegisterVector = this@toCompleteReachOnlySetupPhaseRequest + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + if (noiseConfig.hasReachNoiseConfig()) { + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = rollv2Details.combinedPublicKey + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + contributorsCount = workerStubs.size + 1 + totalSketchesCount = totalRequisitionsCount + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = noiseConfig.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = noiseConfig.reachNoiseConfig.globalReachDpNoise + } + } + noiseMechanism = rollv2Details.parameters.noise.noiseMechanism + } + compositeElGamalPublicKey = rollv2Details.combinedPublicKey + serializedExcessiveNoiseCiphertext = ByteString.EMPTY + parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism + } + } + + private fun ByteString.toCompleteSetupPhaseAtAggregatorRequest( + rollv2Details: ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails, + totalRequisitionsCount: Int + ): CompleteReachOnlySetupPhaseRequest { + val noiseConfig = rollv2Details.parameters.noise + val combinedInputBlobs = this@toCompleteSetupPhaseAtAggregatorRequest + return completeReachOnlySetupPhaseRequest { + combinedRegisterVector = + combinedInputBlobs.substring( + 0, + combinedInputBlobs.size() - workerStubs.size * kBytesPerCipherText + ) + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + if (noiseConfig.hasReachNoiseConfig()) { + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = rollv2Details.combinedPublicKey + curveId = rollv2Details.parameters.ellipticCurveId.toLong() + contributorsCount = workerStubs.size + 1 + totalSketchesCount = totalRequisitionsCount + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = noiseConfig.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = noiseConfig.reachNoiseConfig.globalReachDpNoise + } + } + noiseMechanism = rollv2Details.parameters.noise.noiseMechanism + } + compositeElGamalPublicKey = rollv2Details.combinedPublicKey + serializedExcessiveNoiseCiphertext = + combinedInputBlobs.substring( + combinedInputBlobs.size() - workerStubs.size * kBytesPerCipherText, + combinedInputBlobs.size() + ) + parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism + } + } + + /** Reads all input blobs and combines all the bytes together. */ + protected suspend fun readAndCombineAllInputBlobsSetupPhaseAtAggregator( + token: ComputationToken, + count: Int + ): ByteString { + val blobMap: Map = dataClients.readInputBlobs(token) + if (blobMap.size != count) { + throw PermanentComputationError( + Exception("Unexpected number of input blobs. expected $count, actual ${blobMap.size}.") + ) + } + var combinedRegisterVector = ByteString.EMPTY + var combinedNoiseCiphertext = ByteString.EMPTY + for (str in blobMap.values) { + require(str.size() >= kBytesPerCipherText) { + ("Invalid input blob size. Input blob ${str.toStringUtf8()} has size " + + "${str.size()} which is less than ($kBytesPerCipherText).") + } + combinedRegisterVector = + combinedRegisterVector.concat(str.substring(0, str.size() - kBytesPerCipherText)) + combinedNoiseCiphertext = + combinedNoiseCiphertext.concat(str.substring(str.size() - kBytesPerCipherText, str.size())) + } + return combinedRegisterVector.concat(combinedNoiseCiphertext) + } + + companion object { + private val logger: Logger = Logger.getLogger(this::class.java.name) + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel index 75ea8a3a617..12de420eaac 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel @@ -4,7 +4,10 @@ package(default_visibility = ["//visibility:public"]) kt_jvm_library( name = "liquidlegionsv2encryption", - srcs = glob(["*.kt"]), + srcs = [ + "JniLiquidLegionsV2Encryption.kt", + "LiquidLegionsV2Encryption.kt", + ], deps = [ "//src/main/proto/wfa/any_sketch/crypto:sketch_encryption_methods_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_encryption_methods_kt_jvm_proto", @@ -14,3 +17,19 @@ kt_jvm_library( "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", ], ) + +kt_jvm_library( + name = "reachonlyliquidlegionsv2encryption", + srcs = [ + "JniReachOnlyLiquidLegionsV2Encryption.kt", + "ReachOnlyLiquidLegionsV2Encryption.kt", + ], + deps = [ + "//src/main/proto/wfa/any_sketch/crypto:sketch_encryption_methods_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_kt_jvm_proto", + "//src/main/swig/protocol/reachonlyliquidlegionsv2:reach_only_liquid_legions_v2_encryption_utility", + "@any_sketch_java//src/main/java/org/wfanet/anysketch/crypto:sketch_encrypter_adapter", + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2Encryption.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2Encryption.kt new file mode 100644 index 00000000000..58464adc9d6 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2Encryption.kt @@ -0,0 +1,107 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto + +import java.nio.file.Paths +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysResponse +import org.wfanet.anysketch.crypto.SketchEncrypterAdapter +import org.wfanet.measurement.common.loadLibrary +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.reachonlyliquidlegionsv2.ReachOnlyLiquidLegionsV2EncryptionUtility + +/** + * A [ReachOnlyLiquidLegionsV2Encryption] implementation using the JNI + * [ReachOnlyLiquidLegionsV2EncryptionUtility]. + */ +class JniReachOnlyLiquidLegionsV2Encryption : ReachOnlyLiquidLegionsV2Encryption { + + override fun completeReachOnlyInitializationPhase( + request: CompleteReachOnlyInitializationPhaseRequest + ): CompleteReachOnlyInitializationPhaseResponse { + return CompleteReachOnlyInitializationPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyInitializationPhase( + request.toByteArray() + ) + ) + } + + override fun completeReachOnlySetupPhase( + request: CompleteReachOnlySetupPhaseRequest + ): CompleteReachOnlySetupPhaseResponse { + return CompleteReachOnlySetupPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase(request.toByteArray()) + ) + } + + override fun completeReachOnlySetupPhaseAtAggregator( + request: CompleteReachOnlySetupPhaseRequest + ): CompleteReachOnlySetupPhaseResponse { + return CompleteReachOnlySetupPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhaseAtAggregator( + request.toByteArray() + ) + ) + } + + override fun completeReachOnlyExecutionPhase( + request: CompleteReachOnlyExecutionPhaseRequest + ): CompleteReachOnlyExecutionPhaseResponse { + return CompleteReachOnlyExecutionPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhase( + request.toByteArray() + ) + ) + } + + override fun completeReachOnlyExecutionPhaseAtAggregator( + request: CompleteReachOnlyExecutionPhaseAtAggregatorRequest + ): CompleteReachOnlyExecutionPhaseAtAggregatorResponse { + return CompleteReachOnlyExecutionPhaseAtAggregatorResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhaseAtAggregator( + request.toByteArray() + ) + ) + } + + override fun combineElGamalPublicKeys( + request: CombineElGamalPublicKeysRequest + ): CombineElGamalPublicKeysResponse { + return CombineElGamalPublicKeysResponse.parseFrom( + SketchEncrypterAdapter.CombineElGamalPublicKeys(request.toByteArray()) + ) + } + + companion object { + init { + loadLibrary( + name = "reach_only_liquid_legions_v2_encryption_utility", + directoryPath = + Paths.get("wfa_measurement_system/src/main/swig/protocol/reachonlyliquidlegionsv2") + ) + loadLibrary( + name = "sketch_encrypter_adapter", + directoryPath = Paths.get("any_sketch_java/src/main/java/org/wfanet/anysketch/crypto") + ) + } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2Encryption.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2Encryption.kt new file mode 100644 index 00000000000..733cb3cba81 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2Encryption.kt @@ -0,0 +1,58 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto + +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse + +/** + * Crypto operations for the Reach Only Liquid Legions v2 protocol. check + * src/main/cc/wfa/measurement/common/crypto/reach_only_liquid_legions_v2_encryption_utility.h for + * more descriptions. + */ +interface ReachOnlyLiquidLegionsV2Encryption { + + fun completeReachOnlyInitializationPhase( + request: CompleteReachOnlyInitializationPhaseRequest + ): CompleteReachOnlyInitializationPhaseResponse + + fun completeReachOnlySetupPhase( + request: CompleteReachOnlySetupPhaseRequest + ): CompleteReachOnlySetupPhaseResponse + + fun completeReachOnlySetupPhaseAtAggregator( + request: CompleteReachOnlySetupPhaseRequest + ): CompleteReachOnlySetupPhaseResponse + + fun completeReachOnlyExecutionPhase( + request: CompleteReachOnlyExecutionPhaseRequest + ): CompleteReachOnlyExecutionPhaseResponse + + fun completeReachOnlyExecutionPhaseAtAggregator( + request: CompleteReachOnlyExecutionPhaseAtAggregatorRequest + ): CompleteReachOnlyExecutionPhaseAtAggregatorResponse + + fun combineElGamalPublicKeys( + request: CombineElGamalPublicKeysRequest + ): CombineElGamalPublicKeysResponse +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt index 6a835f839ac..c5eeaaeccf2 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt @@ -351,6 +351,8 @@ private constructor( when (it.computationStage.stageCase) { ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> ComputationType.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 + ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ComputationType.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 else -> error("Computation type for $it is unknown") }, stage = it.computationStage, diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto index bc49b00bd17..d99d9c43a38 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto @@ -119,8 +119,7 @@ message LiquidLegionsSketchAggregationV2 { // The maximum frequency to reveal in the histogram. int32 maximum_frequency = 1; // Parameters used for liquidLegions sketch creation and estimation. - // TODO(@ple): rename to sketch_parameters. - LiquidLegionsSketchParameters liquid_legions_sketch = 2; + LiquidLegionsSketchParameters sketch_parameters = 2; // Noise parameters selected for the LiquidLegionV2 MPC protocol. LiquidLegionsV2NoiseConfig noise = 3; // ID of the OpenSSL built-in elliptic curve. For example, 415 for the diff --git a/src/main/swig/protocol/reachonlyliquidlegionsv2/BUILD.bazel b/src/main/swig/protocol/reachonlyliquidlegionsv2/BUILD.bazel new file mode 100644 index 00000000000..fd7a57bb8a8 --- /dev/null +++ b/src/main/swig/protocol/reachonlyliquidlegionsv2/BUILD.bazel @@ -0,0 +1,16 @@ +load("@wfa_rules_swig//java:defs.bzl", "java_wrap_cc") + +package(default_visibility = [ + "//src/main/kotlin/org/wfanet/measurement:__subpackages__", + "//src/test/kotlin/org/wfanet/measurement:__subpackages__", +]) + +java_wrap_cc( + name = "reach_only_liquid_legions_v2_encryption_utility", + src = "reach_only_liquid_legions_v2_encryption_utility.swig", + module = "ReachOnlyLiquidLegionsV2EncryptionUtility", + package = "org.wfanet.measurement.internal.duchy.protocol.reachonlyliquidlegionsv2", + deps = [ + "//src/main/cc/wfa/measurement/internal/duchy/protocol/liquid_legions_v2:reach_only_liquid_legions_v2_encryption_utility_wrapper", + ], +) diff --git a/src/main/swig/protocol/reachonlyliquidlegionsv2/README.md b/src/main/swig/protocol/reachonlyliquidlegionsv2/README.md new file mode 100644 index 00000000000..3f3f9cc9b28 --- /dev/null +++ b/src/main/swig/protocol/reachonlyliquidlegionsv2/README.md @@ -0,0 +1,22 @@ +# Reach Only Liquid Legions V2 Encryption Utility Java Library + +The ReachOnlyLiquidLegionsV2EncryptionUtility java class is auto-generated from +the reach_only_liquid_legions_v2_encryption_utility.swig definition. The +implementation of the methods is written in c++ and the source codes are under +src/main/cc/measurement/crypto. We create a swig wrapper on the library and call +into the library via JNI in our java code. + +## Possible errors when using the JNI java library. + +### swig uninstalled + +To keep the library updated, each time when the java library is built, it would +run a swig command (provided in the BUILD.bazel rule) to re-generate all the +swig wrapper files using the latest c++ codes. As a result, the swig software is +required to build the java library. Install swig before building the package. + +For example, in a system using apt-get, run the following command to get swig: + +```shell +sudo apt-get install swig +``` diff --git a/src/main/swig/protocol/reachonlyliquidlegionsv2/reach_only_liquid_legions_v2_encryption_utility.swig b/src/main/swig/protocol/reachonlyliquidlegionsv2/reach_only_liquid_legions_v2_encryption_utility.swig new file mode 100644 index 00000000000..0f793c38ff2 --- /dev/null +++ b/src/main/swig/protocol/reachonlyliquidlegionsv2/reach_only_liquid_legions_v2_encryption_utility.swig @@ -0,0 +1,67 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +%include "exception.i" +%{ +#include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h" +%} + +// Convert C++ ::absl::StatusOr to a Java byte +// array. +%typemap(jni) ::absl::StatusOr "jbyteArray" +%typemap(jtype) ::absl::StatusOr "byte[]" +%typemap(jstype) ::absl::StatusOr "byte[]" +%typemap(out) ::absl::StatusOr { + if ($1.ok()) { + $result = JCALL1(NewByteArray, jenv, $1.value().length()); + JCALL4(SetByteArrayRegion, jenv, $result, 0, + $1.value().length(), (const jbyte*) $1.value().data()); + } else { + SWIG_exception(SWIG_RuntimeError, $1.status().ToString().c_str()); + } +} +%typemap(javaout) ::absl::StatusOr { + return $jnicall; +} + +// Convert Java byte array to C++ const std::string& for any const std::string& +// input parameter. +%typemap(jni) const std::string& "jbyteArray" +%typemap(jtype) const std::string& "byte[]" +%typemap(jstype) const std::string& "byte[]" +%typemap(javain) const std::string& "$javainput" +%typemap(in) const std::string& { + jsize temp_length = JCALL1(GetArrayLength, jenv, $input); + jbyte* temp_bytes = JCALL2(GetByteArrayElements, jenv, $input, 0); + $1 = new std::string((const char*) temp_bytes, temp_length); + JCALL3(ReleaseByteArrayElements, jenv, $input, temp_bytes, JNI_ABORT); +} +// Convert C++ std::string to a Java byte array. +%clear std::string; +%typemap(jni) std::string "jbyteArray" +%typemap(jtype) std::string "byte[]" +%typemap(jstype) std::string "byte[]" +%typemap(out) std::string { + $result = JCALL1(NewByteArray, jenv, $1.length()); + JCALL4(SetByteArrayRegion, jenv, $result, 0, + $1.length(), (const jbyte*) $1.data()); +} +%typemap(javaout) std::string { + return $jnicall; +} + +// Follow Java style for function names. +%rename("%(lowercamelcase)s",%$isfunction) ""; + +%include "wfa/measurement/internal/duchy/protocol/liquid_legions_v2/reach_only_liquid_legions_v2_encryption_utility_wrapper.h" diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt index 055b6c572e6..9ff4d00a892 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt @@ -465,7 +465,7 @@ class HeraldTest { parameters = LiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters { maximumFrequency = 10 - liquidLegionsSketch = liquidLegionsSketchParameters { + sketchParameters = liquidLegionsSketchParameters { decayRate = 12.0 size = 100_000L } @@ -567,7 +567,7 @@ class HeraldTest { parameters = LiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters { maximumFrequency = 10 - liquidLegionsSketch = liquidLegionsSketchParameters { + sketchParameters = liquidLegionsSketchParameters { decayRate = 12.0 size = 100_000L } diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel index 1ba87e7a446..9a59539dc81 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/BUILD.bazel @@ -29,3 +29,33 @@ kt_jvm_test( "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/measurementconsumer", ], ) + +kt_jvm_test( + name = "ReachOnlyLiquidLegionsV2MillTest", + srcs = ["ReachOnlyLiquidLegionsV2MillTest.kt"], + test_class = "org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto.ReachOnlyLiquidLegionsV2MillTest", + deps = [ + "//imports/java/io/opentelemetry/api", + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2:reach_only_liquid_legions_v2_mill", + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/testing", + "//src/main/kotlin/org/wfanet/measurement/duchy/db/computation", + "//src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing", + "//src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations", + "//src/main/kotlin/org/wfanet/measurement/duchy/storage:computation_store", + "//src/main/kotlin/org/wfanet/measurement/duchy/storage:requisition_store", + "//src/main/kotlin/org/wfanet/measurement/system/v1alpha:resource_key", + "//src/main/proto/wfa/measurement/api/v2alpha:crypto_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_kt_jvm_proto", + "//src/main/swig/protocol/reachonlyliquidlegionsv2:reach_only_liquid_legions_v2_encryption_utility", + "@wfa_common_jvm//imports/kotlin/com/google/protobuf/kotlin", + "@wfa_common_jvm//imports/kotlin/kotlin/test", + "@wfa_common_jvm//imports/kotlin/org/mockito/kotlin", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto:hashing", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/testing", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/storage/filesystem:client", + "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/common:key_handles", + "@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/measurementconsumer", + ], +) diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt index 5d2deba3548..e5e5590b648 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt @@ -113,7 +113,6 @@ import org.wfanet.measurement.internal.duchy.protocol.CompleteInitializationPhas import org.wfanet.measurement.internal.duchy.protocol.CompleteSetupPhaseRequest import org.wfanet.measurement.internal.duchy.protocol.CompleteSetupPhaseResponse import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.ComputationDetails.Parameters import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.COMPLETE import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE_ONE @@ -128,6 +127,7 @@ import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggrega import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2Kt +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfig import org.wfanet.measurement.internal.duchy.protocol.completeExecutionPhaseOneAtAggregatorRequest import org.wfanet.measurement.internal.duchy.protocol.completeExecutionPhaseOneAtAggregatorResponse @@ -252,18 +252,15 @@ private val TEST_NOISE_CONFIG = } .build() -private val LLV2_PARAMETERS = - Parameters.newBuilder() - .apply { - maximumFrequency = MAX_FREQUENCY - liquidLegionsSketchBuilder.apply { - decayRate = DECAY_RATE - size = SKETCH_SIZE - } - noise = TEST_NOISE_CONFIG - ellipticCurveId = CURVE_ID.toInt() - } - .build() +private val LLV2_PARAMETERS = parameters { + maximumFrequency = MAX_FREQUENCY + sketchParameters = liquidLegionsSketchParameters { + decayRate = DECAY_RATE + size = SKETCH_SIZE + } + noise = TEST_NOISE_CONFIG + ellipticCurveId = CURVE_ID.toInt() +} // In the test, use the same set of cert and encryption key for all parties. private const val CONSENT_SIGNALING_CERT_NAME = "Just a name" diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt new file mode 100644 index 00000000000..ba1d895881d --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt @@ -0,0 +1,1930 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto + +import com.google.common.truth.Truth.assertThat +import com.google.common.truth.extensions.proto.ProtoTruth.assertThat +import com.google.protobuf.ByteString +import com.google.protobuf.kotlin.toByteString +import io.grpc.Status +import io.opentelemetry.api.GlobalOpenTelemetry +import java.security.cert.X509Certificate +import java.time.Clock +import java.time.Duration +import java.util.Base64 +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.runBlocking +import org.junit.Before +import org.junit.Rule +import org.junit.Test +import org.junit.rules.TemporaryFolder +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.mockito.kotlin.UseConstructor +import org.mockito.kotlin.any +import org.mockito.kotlin.mock +import org.mockito.kotlin.whenever +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest +import org.wfanet.anysketch.crypto.combineElGamalPublicKeysResponse +import org.wfanet.anysketch.crypto.elGamalPublicKey as AnySketchElGamalPublicKey +import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reach +import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.vidSamplingInterval +import org.wfanet.measurement.api.v2alpha.elGamalPublicKey as V2AlphaElGamalPublicKey +import org.wfanet.measurement.api.v2alpha.measurementSpec +import org.wfanet.measurement.common.crypto.SigningKeyHandle +import org.wfanet.measurement.common.crypto.readCertificate +import org.wfanet.measurement.common.crypto.readPrivateKey +import org.wfanet.measurement.common.crypto.testing.TestData +import org.wfanet.measurement.common.crypto.tink.TinkPrivateKeyHandle +import org.wfanet.measurement.common.flatten +import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule +import org.wfanet.measurement.common.grpc.testing.mockService +import org.wfanet.measurement.common.identity.DuchyInfo +import org.wfanet.measurement.common.testing.chainRulesSequentially +import org.wfanet.measurement.common.testing.verifyProtoArgument +import org.wfanet.measurement.common.throttler.MinimumIntervalThrottler +import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey +import org.wfanet.measurement.duchy.daemon.mill.Certificate +import org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.ReachOnlyLiquidLegionsV2Mill +import org.wfanet.measurement.duchy.daemon.testing.TestRequisition +import org.wfanet.measurement.duchy.daemon.utils.toDuchyEncryptionPublicKey +import org.wfanet.measurement.duchy.db.computation.ComputationDataClients +import org.wfanet.measurement.duchy.db.computation.testing.FakeComputationsDatabase +import org.wfanet.measurement.duchy.service.internal.computations.ComputationsService +import org.wfanet.measurement.duchy.service.internal.computations.newEmptyOutputBlobMetadata +import org.wfanet.measurement.duchy.service.internal.computations.newInputBlobMetadata +import org.wfanet.measurement.duchy.storage.ComputationBlobContext +import org.wfanet.measurement.duchy.storage.ComputationStore +import org.wfanet.measurement.duchy.storage.RequisitionBlobContext +import org.wfanet.measurement.duchy.storage.RequisitionStore +import org.wfanet.measurement.duchy.toProtocolStage +import org.wfanet.measurement.internal.duchy.ComputationBlobDependency +import org.wfanet.measurement.internal.duchy.ComputationDetails.CompletedReason +import org.wfanet.measurement.internal.duchy.ComputationDetailsKt.kingdomComputationDetails +import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt.ComputationStatsCoroutineImplBase +import org.wfanet.measurement.internal.duchy.ComputationStatsGrpcKt.ComputationStatsCoroutineStub +import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoroutineStub +import org.wfanet.measurement.internal.duchy.computationDetails +import org.wfanet.measurement.internal.duchy.computationStageBlobMetadata +import org.wfanet.measurement.internal.duchy.computationStageDetails +import org.wfanet.measurement.internal.duchy.computationToken +import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation +import org.wfanet.measurement.internal.duchy.copy +import org.wfanet.measurement.internal.duchy.differentialPrivacyParams +import org.wfanet.measurement.internal.duchy.elGamalKeyPair +import org.wfanet.measurement.internal.duchy.elGamalPublicKey +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfigKt +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_INPUTS +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.computationParticipant +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.stageDetails +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt.waitSetupPhaseInputsDetails +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorResponse +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyInitializationPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyInitializationPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlySetupPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.copy +import org.wfanet.measurement.internal.duchy.protocol.globalReachDpNoiseBaseline +import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters +import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsV2NoiseConfig +import org.wfanet.measurement.internal.duchy.protocol.reachNoiseDifferentialPrivacyParams +import org.wfanet.measurement.internal.duchy.protocol.registerNoiseGenerationParameters +import org.wfanet.measurement.storage.Store.Blob +import org.wfanet.measurement.storage.filesystem.FileSystemStorageClient +import org.wfanet.measurement.system.v1alpha.AdvanceComputationRequest +import org.wfanet.measurement.system.v1alpha.AdvanceComputationRequestKt +import org.wfanet.measurement.system.v1alpha.AdvanceComputationResponse +import org.wfanet.measurement.system.v1alpha.Computation +import org.wfanet.measurement.system.v1alpha.ComputationControlGrpcKt.ComputationControlCoroutineImplBase +import org.wfanet.measurement.system.v1alpha.ComputationControlGrpcKt.ComputationControlCoroutineStub +import org.wfanet.measurement.system.v1alpha.ComputationKey +import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineImplBase as SystemComputationLogEntriesCoroutineImplBase +import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub as SystemComputationLogEntriesCoroutineStub +import org.wfanet.measurement.system.v1alpha.ComputationLogEntry +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.RequisitionParamsKt +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.requisitionParams +import org.wfanet.measurement.system.v1alpha.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineImplBase as SystemComputationParticipantsCoroutineImplBase +import org.wfanet.measurement.system.v1alpha.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineStub as SystemComputationParticipantsCoroutineStub +import org.wfanet.measurement.system.v1alpha.ComputationsGrpcKt.ComputationsCoroutineImplBase as SystemComputationsCoroutineImplBase +import org.wfanet.measurement.system.v1alpha.ComputationsGrpcKt.ComputationsCoroutineStub as SystemComputationsCoroutineStub +import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2 +import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2.Description.EXECUTION_PHASE_INPUT +import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2.Description.SETUP_PHASE_INPUT +import org.wfanet.measurement.system.v1alpha.Requisition +import org.wfanet.measurement.system.v1alpha.SetComputationResultRequest +import org.wfanet.measurement.system.v1alpha.advanceComputationRequest +import org.wfanet.measurement.system.v1alpha.confirmComputationParticipantRequest +import org.wfanet.measurement.system.v1alpha.failComputationParticipantRequest +import org.wfanet.measurement.system.v1alpha.reachOnlyLiquidLegionsV2 +import org.wfanet.measurement.system.v1alpha.setComputationResultRequest +import org.wfanet.measurement.system.v1alpha.setParticipantRequisitionParamsRequest +import org.wfanet.measurement.system.v1alpha.stageAttempt + +private const val PUBLIC_API_VERSION = "v2alpha" + +private const val WORKER_COUNT = 3 +private const val MILL_ID = "a nice mill" +private const val DUCHY_ONE_NAME = "DUCHY_ONE" +private const val DUCHY_TWO_NAME = "DUCHY_TWO" +private const val DUCHY_THREE_NAME = "DUCHY_THREE" +private const val DECAY_RATE = 12.0 +private const val SKETCH_SIZE = 100_000L +private const val CURVE_ID = 415L // NID_X9_62_prime256v1 +private const val PARALLELISM = 2 + +private const val LOCAL_ID = 1234L +private const val GLOBAL_ID = LOCAL_ID.toString() + +private const val CIPHERTEXT_SIZE = 66 +private const val NOISE_CIPHERTEXT = + "abcdefghijklmnopqrstuvwxyz0123456abcdefghijklmnopqrstuvwxyz0123456" +private val SERIALIZED_NOISE_CIPHERTEXT = ByteString.copyFromUtf8(NOISE_CIPHERTEXT) + +private val DUCHY_ONE_KEY_PAIR = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_1") + element = ByteString.copyFromUtf8("element_1") + } + secretKey = ByteString.copyFromUtf8("secret_key_1") +} +private val DUCHY_TWO_PUBLIC_KEY = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_2") + element = ByteString.copyFromUtf8("element_2") +} +private val DUCHY_THREE_PUBLIC_KEY = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_3") + element = ByteString.copyFromUtf8("element_3") +} +private val COMBINED_PUBLIC_KEY = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_1_generator_2_generator_3") + element = ByteString.copyFromUtf8("element_1_element_2_element_3") +} +private val PARTIALLY_COMBINED_PUBLIC_KEY = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_2_generator_3") + element = ByteString.copyFromUtf8("element_2_element_3") +} + +private val TEST_NOISE_CONFIG = liquidLegionsV2NoiseConfig { + reachNoiseConfig = + LiquidLegionsV2NoiseConfigKt.reachNoiseConfig { + blindHistogramNoise = differentialPrivacyParams { + epsilon = 1.0 + delta = 2.0 + } + noiseForPublisherNoise = differentialPrivacyParams { + epsilon = 3.0 + delta = 4.0 + } + globalReachDpNoise = differentialPrivacyParams { + epsilon = 5.0 + delta = 6.0 + } + } +} + +private val ROLLV2_PARAMETERS = parameters { + sketchParameters = liquidLegionsSketchParameters { + decayRate = DECAY_RATE + size = SKETCH_SIZE + } + noise = TEST_NOISE_CONFIG + ellipticCurveId = CURVE_ID.toInt() +} + +// In the test, use the same set of cert and encryption key for all parties. +private const val CONSENT_SIGNALING_CERT_NAME = "Just a name" +private val CONSENT_SIGNALING_CERT_DER = + TestData.FIXED_SERVER_CERT_DER_FILE.readBytes().toByteString() +private val CONSENT_SIGNALING_PRIVATE_KEY_DER = + TestData.FIXED_SERVER_KEY_DER_FILE.readBytes().toByteString() +private val CONSENT_SIGNALING_ROOT_CERT: X509Certificate = + readCertificate(TestData.FIXED_CA_CERT_PEM_FILE) +private val ENCRYPTION_PRIVATE_KEY = TinkPrivateKeyHandle.generateEcies() +private val ENCRYPTION_PUBLIC_KEY: org.wfanet.measurement.api.v2alpha.EncryptionPublicKey = + ENCRYPTION_PRIVATE_KEY.publicKey.toEncryptionPublicKey() + +/** A public Key used for consent signaling check. */ +private val CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY = V2AlphaElGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") +} +/** A pre-computed signature of the CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY. */ +private val CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE = + ByteString.copyFrom( + Base64.getDecoder() + .decode( + "MEUCIB2HWi/udHE9YlCH6n9lnGY12I9F1ra1SRWoJrIOXGiAAiEAm90wrJRqABFC5sjej+" + + "zjSBOMHcmFOEpfW9tXaCla6Qs=" + ) + ) + +private val COMPUTATION_PARTICIPANT_1 = computationParticipant { + duchyId = DUCHY_ONE_NAME + publicKey = DUCHY_ONE_KEY_PAIR.publicKey + elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() + elGamalPublicKeySignature = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE + duchyCertificateDer = CONSENT_SIGNALING_CERT_DER +} +private val COMPUTATION_PARTICIPANT_2 = computationParticipant { + duchyId = DUCHY_TWO_NAME + publicKey = DUCHY_TWO_PUBLIC_KEY + elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() + elGamalPublicKeySignature = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE + duchyCertificateDer = CONSENT_SIGNALING_CERT_DER +} +private val COMPUTATION_PARTICIPANT_3 = computationParticipant { + duchyId = DUCHY_THREE_NAME + publicKey = DUCHY_THREE_PUBLIC_KEY + elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() + elGamalPublicKeySignature = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY_SINGATURE + duchyCertificateDer = CONSENT_SIGNALING_CERT_DER +} + +private val TEST_REQUISITION_1 = TestRequisition("111") { SERIALIZED_MEASUREMENT_SPEC } +private val TEST_REQUISITION_2 = TestRequisition("222") { SERIALIZED_MEASUREMENT_SPEC } +private val TEST_REQUISITION_3 = TestRequisition("333") { SERIALIZED_MEASUREMENT_SPEC } + +private val MEASUREMENT_SPEC = measurementSpec { + nonceHashes += TEST_REQUISITION_1.nonceHash + nonceHashes += TEST_REQUISITION_2.nonceHash + nonceHashes += TEST_REQUISITION_3.nonceHash + reach = reach {} +} +private val SERIALIZED_MEASUREMENT_SPEC: ByteString = MEASUREMENT_SPEC.toByteString() + +private val MEASUREMENT_SPEC_WITH_VID_SAMPLING_WIDTH = measurementSpec { + nonceHashes += TEST_REQUISITION_1.nonceHash + nonceHashes += TEST_REQUISITION_2.nonceHash + nonceHashes += TEST_REQUISITION_3.nonceHash + reach = reach {} + vidSamplingInterval = vidSamplingInterval { width = 0.5f } +} + +private val SERIALIZED_MEASUREMENT_SPEC_WITH_VID_SAMPLING_WIDTH = + MEASUREMENT_SPEC_WITH_VID_SAMPLING_WIDTH.toByteString() + +private val REQUISITION_1 = + TEST_REQUISITION_1.toRequisitionMetadata(Requisition.State.FULFILLED, DUCHY_ONE_NAME).copy { + path = RequisitionBlobContext(GLOBAL_ID, externalKey.externalRequisitionId).blobKey + } +private val REQUISITION_2 = + TEST_REQUISITION_2.toRequisitionMetadata(Requisition.State.FULFILLED, DUCHY_TWO_NAME) +private val REQUISITION_3 = + TEST_REQUISITION_3.toRequisitionMetadata(Requisition.State.FULFILLED, DUCHY_THREE_NAME) +private val REQUISITIONS = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3) + +private val AGGREGATOR_COMPUTATION_DETAILS = computationDetails { + kingdomComputation = kingdomComputationDetails { + publicApiVersion = PUBLIC_API_VERSION + measurementPublicKey = ENCRYPTION_PUBLIC_KEY.toDuchyEncryptionPublicKey() + measurementSpec = SERIALIZED_MEASUREMENT_SPEC + } + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3, COMPUTATION_PARTICIPANT_1) + combinedPublicKey = COMBINED_PUBLIC_KEY + // partiallyCombinedPublicKey and combinedPublicKey are the same at the aggregator. + partiallyCombinedPublicKey = COMBINED_PUBLIC_KEY + localElgamalKey = DUCHY_ONE_KEY_PAIR + } +} + +private val NON_AGGREGATOR_COMPUTATION_DETAILS = computationDetails { + kingdomComputation = kingdomComputationDetails { + publicApiVersion = PUBLIC_API_VERSION + measurementPublicKey = ENCRYPTION_PUBLIC_KEY.toDuchyEncryptionPublicKey() + measurementSpec = SERIALIZED_MEASUREMENT_SPEC + } + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_1, COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3) + combinedPublicKey = COMBINED_PUBLIC_KEY + partiallyCombinedPublicKey = PARTIALLY_COMBINED_PUBLIC_KEY + localElgamalKey = DUCHY_ONE_KEY_PAIR + } +} + +@RunWith(JUnit4::class) +class ReachOnlyLiquidLegionsV2MillTest { + private val mockReachOnlyLiquidLegionsComputationControl: ComputationControlCoroutineImplBase = + mockService { + onBlocking { advanceComputation(any()) } + .thenAnswer { + val request: Flow = it.getArgument(0) + computationControlRequests = runBlocking { request.toList() } + AdvanceComputationResponse.getDefaultInstance() + } + } + private val mockSystemComputations: SystemComputationsCoroutineImplBase = mockService() + private val mockComputationParticipants: SystemComputationParticipantsCoroutineImplBase = + mockService() + private val mockComputationLogEntries: SystemComputationLogEntriesCoroutineImplBase = + mockService() + private val mockComputationStats: ComputationStatsCoroutineImplBase = mockService() + private val mockCryptoWorker: ReachOnlyLiquidLegionsV2Encryption = + mock(useConstructor = UseConstructor.parameterless()) { + on { combineElGamalPublicKeys(any()) } + .thenAnswer { + val cryptoRequest: CombineElGamalPublicKeysRequest = it.getArgument(0) + combineElGamalPublicKeysResponse { + elGamalKeys = AnySketchElGamalPublicKey { + generator = + ByteString.copyFromUtf8( + cryptoRequest.elGamalKeysList + .sortedBy { key -> key.generator.toStringUtf8() } + .joinToString(separator = "_") { key -> key.generator.toStringUtf8() } + ) + element = + ByteString.copyFromUtf8( + cryptoRequest.elGamalKeysList + .sortedBy { key -> key.element.toStringUtf8() } + .joinToString(separator = "_") { key -> key.element.toStringUtf8() } + ) + } + } + } + } + private val fakeComputationDb = FakeComputationsDatabase() + + private lateinit var computationDataClients: ComputationDataClients + private lateinit var computationStore: ComputationStore + private lateinit var requisitionStore: RequisitionStore + + private val tempDirectory = TemporaryFolder() + + private val grpcTestServerRule = GrpcTestServerRule { + val storageClient = FileSystemStorageClient(tempDirectory.root) + computationStore = ComputationStore(storageClient) + requisitionStore = RequisitionStore(storageClient) + computationDataClients = + ComputationDataClients.forTesting( + ComputationsCoroutineStub(channel), + computationStore, + requisitionStore + ) + addService(mockReachOnlyLiquidLegionsComputationControl) + addService(mockSystemComputations) + addService(mockComputationLogEntries) + addService(mockComputationParticipants) + addService(mockComputationStats) + addService( + ComputationsService( + fakeComputationDb, + systemComputationLogEntriesStub, + computationStore, + requisitionStore, + DUCHY_THREE_NAME, + Clock.systemUTC() + ) + ) + } + + @get:Rule val ruleChain = chainRulesSequentially(tempDirectory, grpcTestServerRule) + + private val workerStub: ComputationControlCoroutineStub by lazy { + ComputationControlCoroutineStub(grpcTestServerRule.channel) + } + + private val systemComputationStub: SystemComputationsCoroutineStub by lazy { + SystemComputationsCoroutineStub(grpcTestServerRule.channel) + } + + private val systemComputationLogEntriesStub: SystemComputationLogEntriesCoroutineStub by lazy { + SystemComputationLogEntriesCoroutineStub(grpcTestServerRule.channel) + } + + private val systemComputationParticipantsStub: + SystemComputationParticipantsCoroutineStub by lazy { + SystemComputationParticipantsCoroutineStub(grpcTestServerRule.channel) + } + + private val computationStatsStub: ComputationStatsCoroutineStub by lazy { + ComputationStatsCoroutineStub(grpcTestServerRule.channel) + } + + private lateinit var computationControlRequests: List + + // Just use the same workerStub for all other duchies, since it is not relevant to this test. + private val workerStubs = mapOf(DUCHY_TWO_NAME to workerStub, DUCHY_THREE_NAME to workerStub) + + private lateinit var aggregatorMill: ReachOnlyLiquidLegionsV2Mill + private lateinit var nonAggregatorMill: ReachOnlyLiquidLegionsV2Mill + + private fun buildAdvanceComputationRequests( + globalComputationId: String, + description: ReachOnlyLiquidLegionsV2.Description, + vararg chunkContents: String + ): List { + val header = advanceComputationRequest { + header = + AdvanceComputationRequestKt.header { + name = ComputationKey(globalComputationId).toName() + this.reachOnlyLiquidLegionsV2 = reachOnlyLiquidLegionsV2 { + this.description = description + } + } + } + val body = + chunkContents.asList().map { + advanceComputationRequest { + bodyChunk = + AdvanceComputationRequestKt.bodyChunk { partialData = ByteString.copyFromUtf8(it) } + } + } + return listOf(header) + body + } + + @Before + fun initializeMill() = runBlocking { + val throttler = MinimumIntervalThrottler(Clock.systemUTC(), Duration.ofSeconds(60)) + DuchyInfo.setForTest(setOf(DUCHY_ONE_NAME, DUCHY_TWO_NAME, DUCHY_THREE_NAME)) + val csX509Certificate = readCertificate(CONSENT_SIGNALING_CERT_DER) + val csSigningKey = + SigningKeyHandle( + csX509Certificate, + readPrivateKey(CONSENT_SIGNALING_PRIVATE_KEY_DER, csX509Certificate.publicKey.algorithm) + ) + val csCertificate = Certificate(CONSENT_SIGNALING_CERT_NAME, csX509Certificate) + val trustedCertificates = + DuchyInfo.entries.values.associateBy( + { it.rootCertificateSkid }, + { CONSENT_SIGNALING_ROOT_CERT } + ) + + aggregatorMill = + ReachOnlyLiquidLegionsV2Mill( + millId = MILL_ID, + duchyId = DUCHY_ONE_NAME, + signingKey = csSigningKey, + consentSignalCert = csCertificate, + trustedCertificates = trustedCertificates, + dataClients = computationDataClients, + systemComputationParticipantsClient = systemComputationParticipantsStub, + systemComputationsClient = systemComputationStub, + systemComputationLogEntriesClient = systemComputationLogEntriesStub, + computationStatsClient = computationStatsStub, + throttler = throttler, + workerStubs = workerStubs, + cryptoWorker = mockCryptoWorker, + workLockDuration = Duration.ofMinutes(5), + openTelemetry = GlobalOpenTelemetry.get(), + requestChunkSizeBytes = 20, + maximumAttempts = 2, + parallelism = PARALLELISM + ) + nonAggregatorMill = + ReachOnlyLiquidLegionsV2Mill( + millId = MILL_ID, + duchyId = DUCHY_ONE_NAME, + signingKey = csSigningKey, + consentSignalCert = csCertificate, + trustedCertificates = trustedCertificates, + dataClients = computationDataClients, + systemComputationParticipantsClient = systemComputationParticipantsStub, + systemComputationsClient = systemComputationStub, + systemComputationLogEntriesClient = systemComputationLogEntriesStub, + computationStatsClient = computationStatsStub, + throttler = throttler, + workerStubs = workerStubs, + cryptoWorker = mockCryptoWorker, + workLockDuration = Duration.ofMinutes(5), + openTelemetry = GlobalOpenTelemetry.get(), + requestChunkSizeBytes = 20, + maximumAttempts = 2, + parallelism = PARALLELISM + ) + } + + @Test + fun `exceeding max attempt should fail the computation`() = runBlocking { + // Stage 0. preparing the database and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = INITIALIZATION_PHASE.toProtocolStage() + ) + .build() + + val initialComputationDetails = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_1, COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3) + } + } + + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = initialComputationDetails, + requisitions = REQUISITIONS + ) + + whenever(mockCryptoWorker.completeReachOnlyInitializationPhase(any())).thenAnswer { + completeReachOnlyInitializationPhaseResponse { + this.elGamalKeyPair = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + } + + // This will result in TRANSIENT gRPC failure. + whenever(mockComputationParticipants.setParticipantRequisitionParams(any())) + .thenThrow(Status.UNKNOWN.asRuntimeException()) + + // First attempt fails, which doesn't change the computation stage. + nonAggregatorMill.pollAndProcessNextComputation() + + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = INITIALIZATION_PHASE.toProtocolStage() + version = 3 // claimTask + updateComputationDetails + enqueueComputation + this.computationDetails = computationDetails { + kingdomComputation = initialComputationDetails.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf( + COMPUTATION_PARTICIPANT_1, + COMPUTATION_PARTICIPANT_2, + COMPUTATION_PARTICIPANT_3 + ) + this.localElgamalKey = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + } + requisitions.addAll(REQUISITIONS) + } + ) + // Second attempt fails, which doesn't change the computation stage. + nonAggregatorMill.pollAndProcessNextComputation() + + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 2 + computationStage = INITIALIZATION_PHASE.toProtocolStage() + version = 5 // claimTask + updateComputationDetails + enqueueComputation + this.computationDetails = computationDetails { + kingdomComputation = initialComputationDetails.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf( + COMPUTATION_PARTICIPANT_1, + COMPUTATION_PARTICIPANT_2, + COMPUTATION_PARTICIPANT_3 + ) + this.localElgamalKey = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + } + requisitions.addAll(REQUISITIONS) + } + ) + // Third attempt fails, which will fail the computation. + nonAggregatorMill.pollAndProcessNextComputation() + + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 3 + computationStage = COMPLETE.toProtocolStage() + version = 8 // claimTask + updateComputationDetails + enqueueComputation + claimTask + + // EndComputation + this.computationDetails = computationDetails { + kingdomComputation = initialComputationDetails.kingdomComputation + endingState = CompletedReason.FAILED + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf( + COMPUTATION_PARTICIPANT_1, + COMPUTATION_PARTICIPANT_2, + COMPUTATION_PARTICIPANT_3 + ) + this.localElgamalKey = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + } + requisitions.addAll(REQUISITIONS) + } + ) + } + + @Test + fun `initialization phase`() = runBlocking { + // Stage 0. preparing the database and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = INITIALIZATION_PHASE.toProtocolStage() + ) + .build() + + val initialComputationDetails = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_1, COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3) + } + } + + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = initialComputationDetails, + requisitions = REQUISITIONS + ) + + var cryptoRequest = CompleteReachOnlyInitializationPhaseRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeReachOnlyInitializationPhase(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + completeReachOnlyInitializationPhaseResponse { + this.elGamalKeyPair = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + } + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage() + version = 3 // claimTask + updateComputationDetails + enqueueComputation + this.computationDetails = computationDetails { + kingdomComputation = initialComputationDetails.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf( + COMPUTATION_PARTICIPANT_1, + COMPUTATION_PARTICIPANT_2, + COMPUTATION_PARTICIPANT_3 + ) + this.localElgamalKey = elGamalKeyPair { + publicKey = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator-foo") + element = ByteString.copyFromUtf8("element-foo") + } + secretKey = ByteString.copyFromUtf8("secretKey-foo") + } + } + } + requisitions.addAll(REQUISITIONS) + } + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::setParticipantRequisitionParams + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + setParticipantRequisitionParamsRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + this.requisitionParams = requisitionParams { + duchyCertificate = CONSENT_SIGNALING_CERT_NAME + reachOnlyLiquidLegionsV2 = + RequisitionParamsKt.liquidLegionsV2 { + elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() + } + } + } + ) + + assertThat(cryptoRequest) + .isEqualTo(completeReachOnlyInitializationPhaseRequest { curveId = CURVE_ID }) + } + + @Test + fun `confirmation phase, failed due to missing local requisition`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val requisition1 = REQUISITION_1 + // requisition2 is fulfilled at Duchy One, but doesn't have path set. + val requisition2 = + REQUISITION_2.copy { details = details.copy { externalFulfillingDuchyId = DUCHY_ONE_NAME } } + val computationDetailsWithoutPublicKey = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3, COMPUTATION_PARTICIPANT_1) + // partiallyCombinedPublicKey and combinedPublicKey are the same at the aggregator. + partiallyCombinedPublicKey = COMBINED_PUBLIC_KEY + localElgamalKey = DUCHY_ONE_KEY_PAIR + } + } + fakeComputationDb.addComputation( + globalId = GLOBAL_ID, + stage = CONFIRMATION_PHASE.toProtocolStage(), + computationDetails = computationDetailsWithoutPublicKey, + requisitions = listOf(requisition1, requisition2) + ) + + whenever(mockComputationLogEntries.createComputationLogEntry(any())) + .thenReturn(ComputationLogEntry.getDefaultInstance()) + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + this.computationDetails = computationDetails { + kingdomComputation = computationDetailsWithoutPublicKey.kingdomComputation + reachOnlyLiquidLegionsV2 = computationDetailsWithoutPublicKey.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.FAILED + } + requisitions.addAll(listOf(requisition1, requisition2)) + } + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::failComputationParticipant + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + failComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failure = + ComputationParticipantKt.failure { + participantChildReferenceId = MILL_ID + errorMessage = + "PERMANENT error: java.lang.Exception: @Mill a nice mill, Computation 1234 " + + "failed due to:\n" + + "Cannot verify participation of all DataProviders.\n" + + "Missing expected data for requisition 222." + this.stageAttempt = stageAttempt { + stage = CONFIRMATION_PHASE.number + stageName = CONFIRMATION_PHASE.name + attemptNumber = 1 + } + } + } + ) + } + + @Test + fun `confirmation phase, passed at non-aggregator`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val computationDetailsWithoutPublicKey = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_1, COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3) + partiallyCombinedPublicKey = PARTIALLY_COMBINED_PUBLIC_KEY + localElgamalKey = DUCHY_ONE_KEY_PAIR + } + } + fakeComputationDb.addComputation( + globalId = GLOBAL_ID, + stage = CONFIRMATION_PHASE.toProtocolStage(), + computationDetails = computationDetailsWithoutPublicKey, + requisitions = REQUISITIONS + ) + + whenever(mockComputationLogEntries.createComputationLogEntry(any())) + .thenReturn(ComputationLogEntry.getDefaultInstance()) + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_TO_START.toProtocolStage() + version = 3 // claimTask + updateComputationDetail + transitionStage + computationDetails = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = NON_AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + } + requisitions.addAll(REQUISITIONS) + } + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::confirmComputationParticipant + ) + .isEqualTo( + confirmComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + } + ) + } + + @Test + fun `confirmation phase, passed at aggregator`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val computationDetailsWithoutPublicKey = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf(COMPUTATION_PARTICIPANT_2, COMPUTATION_PARTICIPANT_3, COMPUTATION_PARTICIPANT_1) + localElgamalKey = DUCHY_ONE_KEY_PAIR + } + } + fakeComputationDb.addComputation( + globalId = GLOBAL_ID, + stage = CONFIRMATION_PHASE.toProtocolStage(), + computationDetails = computationDetailsWithoutPublicKey, + requisitions = REQUISITIONS + ) + + whenever(mockComputationLogEntries.createComputationLogEntry(any())) + .thenReturn(ComputationLogEntry.getDefaultInstance()) + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_SETUP_PHASE_INPUTS.toProtocolStage() + version = 3 // claimTask + updateComputationDetails + transitionStage + blobs.addAll(listOf(newEmptyOutputBlobMetadata(0), newEmptyOutputBlobMetadata(1))) + stageSpecificDetails = computationStageDetails { + reachOnlyLiquidLegionsV2 = stageDetails { + waitSetupPhaseInputsDetails = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.waitSetupPhaseInputsDetails { + externalDuchyLocalBlobId.put("DUCHY_TWO", 0L) + externalDuchyLocalBlobId.put("DUCHY_THREE", 1L) + } + } + } + computationDetails = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + } + requisitions.addAll(REQUISITIONS) + } + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::confirmComputationParticipant + ) + .isEqualTo( + confirmComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + } + ) + } + + @Test + fun `confirmation phase, failed due to invalid nonce and ElGamal key signature`() = runBlocking { + val COMPUTATION_PARTICIPANT_2_WITH_INVALID_SIGNATURE = computationParticipant { + duchyId = DUCHY_TWO_NAME + publicKey = DUCHY_TWO_PUBLIC_KEY + elGamalPublicKey = CONSENT_SIGNALING_EL_GAMAL_PUBLIC_KEY.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("An invalid signature") + duchyCertificateDer = CONSENT_SIGNALING_CERT_DER + } + // Stage 0. preparing the storage and set up mock + val computationDetailsWithoutInvalidDuchySignature = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.AGGREGATOR + parameters = ROLLV2_PARAMETERS + participant += + listOf( + COMPUTATION_PARTICIPANT_2_WITH_INVALID_SIGNATURE, + COMPUTATION_PARTICIPANT_3, + COMPUTATION_PARTICIPANT_1 + ) + combinedPublicKey = COMBINED_PUBLIC_KEY + // partiallyCombinedPublicKey and combinedPublicKey are the same at the aggregator. + partiallyCombinedPublicKey = COMBINED_PUBLIC_KEY + localElgamalKey = DUCHY_ONE_KEY_PAIR + } + } + val requisitionWithInvalidNonce = REQUISITION_1.copy { details = details.copy { nonce = 404L } } + fakeComputationDb.addComputation( + globalId = GLOBAL_ID, + stage = CONFIRMATION_PHASE.toProtocolStage(), + computationDetails = computationDetailsWithoutInvalidDuchySignature, + requisitions = listOf(requisitionWithInvalidNonce) + ) + + whenever(mockComputationLogEntries.createComputationLogEntry(any())) + .thenReturn(ComputationLogEntry.getDefaultInstance()) + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = computationDetails { + kingdomComputation = computationDetailsWithoutInvalidDuchySignature.kingdomComputation + reachOnlyLiquidLegionsV2 = + computationDetailsWithoutInvalidDuchySignature.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.FAILED + } + requisitions.addAll(listOf(requisitionWithInvalidNonce)) + } + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::failComputationParticipant + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + failComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failure = + ComputationParticipantKt.failure { + participantChildReferenceId = MILL_ID + errorMessage = + "PERMANENT error: java.lang.Exception: @Mill a nice mill, Computation 1234 " + + "failed due to:\n" + + "Cannot verify participation of all DataProviders.\n" + + "Invalid ElGamal public key signature for Duchy $DUCHY_TWO_NAME" + this.stageAttempt = stageAttempt { + stage = CONFIRMATION_PHASE.number + stageName = CONFIRMATION_PHASE.name + attemptNumber = 1 + } + } + } + ) + } + + @Test + fun `setup phase at non-aggregator using cached result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = SETUP_PHASE.toProtocolStage() + ) + .build() + val requisitionBlobContext = + RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) + requisitionStore.writeString(requisitionBlobContext, "local_requisition") + val cachedBlobContext = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) + computationStore.writeString(cachedBlobContext, "cached result") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3), + blobs = listOf(cachedBlobContext.toMetadata(ComputationBlobDependency.OUTPUT)) + ) + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + blobs.addAll( + listOf( + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0L + path = cachedBlobContext.blobKey + }, + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1L + } + ) + ) + version = 2 // claimTask + transitionStage + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + ) + + assertThat(computationControlRequests) + .containsExactlyElementsIn( + buildAdvanceComputationRequests(GLOBAL_ID, SETUP_PHASE_INPUT, "cached result") + ) + .inOrder() + } + + @Test + fun `setup phase at non-aggregator using calculated result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = SETUP_PHASE.toProtocolStage() + ) + .build() + val requisitionBlobContext = + RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) + val calculatedBlobContext = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) + requisitionStore.writeString(requisitionBlobContext, "local_requisition") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3), + blobs = listOf(newEmptyOutputBlobMetadata(calculatedBlobContext.blobId)) + ) + + var cryptoRequest = CompleteReachOnlySetupPhaseRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeReachOnlySetupPhase(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + val postFix = ByteString.copyFromUtf8("-completeReachOnlySetupPhase") + completeReachOnlySetupPhaseResponse { + combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) + serializedExcessiveNoiseCiphertext = ByteString.copyFromUtf8("-encryptedNoise") + } + } + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + val blobKey = calculatedBlobContext.blobKey + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + blobs.addAll( + listOf( + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0L + path = blobKey + }, + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1L + } + ) + ) + version = 3 // claimTask + writeOutputBlob + transitionStage + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + ) + + assertThat(computationStore.get(blobKey)?.readToString()) + .isEqualTo("local_requisition-completeReachOnlySetupPhase-encryptedNoise") + + assertThat(computationControlRequests) + .containsExactlyElementsIn( + buildAdvanceComputationRequests( + GLOBAL_ID, + SETUP_PHASE_INPUT, + "local_requisition-co", + "mpleteReachOnlySetup", + "Phase-encryptedNoise" + ) + ) + .inOrder() + + assertThat(cryptoRequest) + .isEqualTo( + completeReachOnlySetupPhaseRequest { + combinedRegisterVector = ByteString.copyFromUtf8("local_requisition") + curveId = CURVE_ID + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = COMBINED_PUBLIC_KEY + curveId = CURVE_ID + contributorsCount = WORKER_COUNT + totalSketchesCount = REQUISITIONS.size + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = TEST_NOISE_CONFIG.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = TEST_NOISE_CONFIG.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = TEST_NOISE_CONFIG.reachNoiseConfig.globalReachDpNoise + } + } + compositeElGamalPublicKey = COMBINED_PUBLIC_KEY + parallelism = PARALLELISM + } + ) + } + + @Test + fun `setup phase at aggregator using cached result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = SETUP_PHASE.toProtocolStage() + ) + .build() + val requisitionBlobContext = + RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) + requisitionStore.writeString(requisitionBlobContext, "local_requisition") + val cachedBlobContext = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) + computationStore.writeString(cachedBlobContext, "cached result") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3), + blobs = listOf(cachedBlobContext.toMetadata(ComputationBlobDependency.OUTPUT)) + ) + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + blobs.addAll( + listOf( + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0L + path = cachedBlobContext.blobKey + }, + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1L + } + ) + ) + version = 2 // claimTask + transitionStage + computationDetails = AGGREGATOR_COMPUTATION_DETAILS + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + ) + + assertThat(computationControlRequests) + .containsExactlyElementsIn( + buildAdvanceComputationRequests(GLOBAL_ID, EXECUTION_PHASE_INPUT, "cached result") + ) + .inOrder() + } + + @Test + fun `setup phase at aggregator using calculated result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = SETUP_PHASE.toProtocolStage() + ) + .build() + val requisitionBlobContext = + RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) + requisitionStore.writeString(requisitionBlobContext, "local_requisition_") + val inputBlob0Context = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 0L) + computationStore.writeString(inputBlob0Context, "duchy_2_sketch_" + NOISE_CIPHERTEXT) + val inputBlob1Context = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) + computationStore.writeString(inputBlob1Context, "duchy_3_sketch_" + NOISE_CIPHERTEXT) + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + newInputBlobMetadata(0L, inputBlob0Context.blobKey), + newInputBlobMetadata(1L, inputBlob1Context.blobKey), + newEmptyOutputBlobMetadata(3L) + ), + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3) + ) + + var cryptoRequest = CompleteReachOnlySetupPhaseRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeReachOnlySetupPhaseAtAggregator(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + val postFix = ByteString.copyFromUtf8("-completeReachOnlySetupPhase") + completeReachOnlySetupPhaseResponse { + combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) + serializedExcessiveNoiseCiphertext = ByteString.copyFromUtf8("-encryptedNoise") + } + } + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + val blobKey = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 3L).blobKey + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + blobs.addAll( + listOf( + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0 + path = blobKey + }, + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1 + } + ) + ) + version = 3 // claimTask + writeOutputBlob + transitionStage + computationDetails = AGGREGATOR_COMPUTATION_DETAILS + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + ) + + assertThat(computationStore.get(blobKey)?.readToString()) + .isEqualTo( + "local_requisition_duchy_2_sketch_duchy_3_sketch_-completeReachOnlySetupPhase-encryptedNoise" + ) + + assertThat(computationControlRequests) + .containsExactlyElementsIn( + buildAdvanceComputationRequests( + GLOBAL_ID, + EXECUTION_PHASE_INPUT, + "local_requisition_du", + "chy_2_sketch_duchy_3", + "_sketch_-completeRea", + "chOnlySetupPhase-enc", + "ryptedNoise" + ) + ) + .inOrder() + + assertThat(cryptoRequest) + .isEqualTo( + completeReachOnlySetupPhaseRequest { + combinedRegisterVector = + ByteString.copyFromUtf8("local_requisition_duchy_2_sketch_duchy_3_sketch_") + curveId = CURVE_ID + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = COMBINED_PUBLIC_KEY + curveId = CURVE_ID + contributorsCount = WORKER_COUNT + totalSketchesCount = REQUISITIONS.size + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = TEST_NOISE_CONFIG.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = TEST_NOISE_CONFIG.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = TEST_NOISE_CONFIG.reachNoiseConfig.globalReachDpNoise + } + } + compositeElGamalPublicKey = COMBINED_PUBLIC_KEY + serializedExcessiveNoiseCiphertext = + SERIALIZED_NOISE_CIPHERTEXT.concat(SERIALIZED_NOISE_CIPHERTEXT) + parallelism = PARALLELISM + } + ) + } + + @Test + fun `setup phase at aggregator, failed due to invalid input blob size`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = SETUP_PHASE.toProtocolStage() + ) + .build() + val requisitionBlobContext = + RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) + requisitionStore.writeString(requisitionBlobContext, "local_requisition_") + val inputBlob0Context = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 0L) + computationStore.writeString(inputBlob0Context, "duchy_2_sketch_") + val inputBlob1Context = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 1L) + computationStore.writeString(inputBlob1Context, "duchy_3_sketch_" + NOISE_CIPHERTEXT) + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + newInputBlobMetadata(0L, inputBlob0Context.blobKey), + newInputBlobMetadata(1L, inputBlob1Context.blobKey), + newEmptyOutputBlobMetadata(3L) + ), + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3) + ) + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.FAILED + } + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::failComputationParticipant + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + failComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failure = + ComputationParticipantKt.failure { + participantChildReferenceId = MILL_ID + errorMessage = + "PERMANENT error: java.lang.IllegalArgumentException: Invalid input blob size. Input" + + " blob duchy_2_sketch_ has size 15 which is less than (66)." + this.stageAttempt = stageAttempt { + stage = SETUP_PHASE.number + stageName = SETUP_PHASE.name + attemptNumber = 1 + } + } + } + ) + } + + @Test + fun `execution phase at non-aggregator using cached result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = EXECUTION_PHASE.toProtocolStage() + ) + .build() + val inputBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + computationStore.writeString(inputBlobContext, "sketch" + NOISE_CIPHERTEXT) + val cachedBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + computationStore.writeString(cachedBlobContext, "cached result") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), + cachedBlobContext.toMetadata(ComputationBlobDependency.OUTPUT) + ), + requisitions = REQUISITIONS + ) + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = NON_AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.SUCCEEDED + } + requisitions += REQUISITIONS + } + ) + + assertThat(computationControlRequests) + .containsExactlyElementsIn( + buildAdvanceComputationRequests(GLOBAL_ID, EXECUTION_PHASE_INPUT, "cached result") + ) + .inOrder() + } + + @Test + fun `execution phase at non-aggregator using calculated result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = EXECUTION_PHASE.toProtocolStage() + ) + .build() + val inputBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val calculatedBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + computationStore.writeString(inputBlobContext, "data" + NOISE_CIPHERTEXT) + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), + newEmptyOutputBlobMetadata(calculatedBlobContext.blobId) + ), + requisitions = REQUISITIONS + ) + + var cryptoRequest = CompleteReachOnlyExecutionPhaseRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeReachOnlyExecutionPhase(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + val postFix = ByteString.copyFromUtf8("-completeReachOnlyExecutionPhase") + completeReachOnlyExecutionPhaseResponse { + combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) + serializedExcessiveNoiseCiphertext = ByteString.copyFromUtf8("-partiallyDecryptedNoise") + } + } + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + val blobKey = calculatedBlobContext.blobKey + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 3 // claimTask + writeOutputBlob + transitionStage + computationDetails = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = NON_AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.SUCCEEDED + } + requisitions.addAll(REQUISITIONS) + } + ) + assertThat(computationStore.get(blobKey)?.readToString()) + .isEqualTo("data-completeReachOnlyExecutionPhase-partiallyDecryptedNoise") + + assertThat(cryptoRequest) + .isEqualTo( + completeReachOnlyExecutionPhaseRequest { + combinedRegisterVector = ByteString.copyFromUtf8("data") + localElGamalKeyPair = DUCHY_ONE_KEY_PAIR + curveId = CURVE_ID + serializedExcessiveNoiseCiphertext = SERIALIZED_NOISE_CIPHERTEXT + parallelism = PARALLELISM + } + ) + } + + @Test + fun `execution phase at non-aggregator, failed due to invalid input blob size`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = EXECUTION_PHASE.toProtocolStage() + ) + .build() + val inputBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val calculatedBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + computationStore.writeString(inputBlobContext, "data") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), + newEmptyOutputBlobMetadata(calculatedBlobContext.blobId) + ), + requisitions = REQUISITIONS + ) + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = computationDetails { + kingdomComputation = NON_AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = NON_AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.FAILED + } + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::failComputationParticipant + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + failComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failure = + ComputationParticipantKt.failure { + participantChildReferenceId = MILL_ID + errorMessage = + "PERMANENT error: Invalid input blob size. Input blob data has size 4 which is less than (66)." + this.stageAttempt = stageAttempt { + stage = EXECUTION_PHASE.number + stageName = EXECUTION_PHASE.name + attemptNumber = 1 + } + } + } + ) + } + + @Test + fun `execution phase at aggregator using cached result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = EXECUTION_PHASE.toProtocolStage() + ) + .build() + val inputBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + computationStore.writeString(inputBlobContext, "sketch" + NOISE_CIPHERTEXT) + val cachedBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + computationStore.writeString(cachedBlobContext, "cached result") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), + cachedBlobContext.toMetadata(ComputationBlobDependency.OUTPUT) + ), + requisitions = REQUISITIONS + ) + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.SUCCEEDED + } + requisitions += REQUISITIONS + } + ) + } + + @Test + fun `execution phase at aggregator using calculated result`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = EXECUTION_PHASE.toProtocolStage() + ) + .build() + val computationDetailsWithVidSamplingWidth = + AGGREGATOR_COMPUTATION_DETAILS.copy { + kingdomComputation = + kingdomComputation.copy { + measurementSpec = SERIALIZED_MEASUREMENT_SPEC_WITH_VID_SAMPLING_WIDTH + } + } + val inputBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val calculatedBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + computationStore.writeString(inputBlobContext, "data" + NOISE_CIPHERTEXT) + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = computationDetailsWithVidSamplingWidth, + blobs = + listOf( + inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), + newEmptyOutputBlobMetadata(calculatedBlobContext.blobId) + ), + requisitions = REQUISITIONS + ) + val testReach = 123L + var cryptoRequest = CompleteReachOnlyExecutionPhaseAtAggregatorRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeReachOnlyExecutionPhaseAtAggregator(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + completeReachOnlyExecutionPhaseAtAggregatorResponse { reach = testReach } + } + var systemComputationResult = SetComputationResultRequest.getDefaultInstance() + whenever(mockSystemComputations.setComputationResult(any())).thenAnswer { + systemComputationResult = it.getArgument(0) + Computation.getDefaultInstance() + } + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + val blobKey = calculatedBlobContext.blobKey + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 3 // claimTask + writeOutputBlob + transitionStage + computationDetails = computationDetails { + kingdomComputation = computationDetailsWithVidSamplingWidth.kingdomComputation + reachOnlyLiquidLegionsV2 = + computationDetailsWithVidSamplingWidth.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.SUCCEEDED + } + requisitions.addAll(REQUISITIONS) + } + ) + assertThat(computationStore.get(blobKey)?.readToString()).isNotEmpty() + + assertThat(systemComputationResult.name).isEqualTo("computations/$GLOBAL_ID") + // The signature is non-deterministic, so we only verity the encryption is not empty. + assertThat(systemComputationResult.encryptedResult).isNotEmpty() + assertThat(systemComputationResult) + .comparingExpectedFieldsOnly() + .isEqualTo( + setComputationResultRequest { + name = "computations/$GLOBAL_ID" + aggregatorCertificate = CONSENT_SIGNALING_CERT_NAME + resultPublicKey = ENCRYPTION_PUBLIC_KEY.toByteString() + } + ) + + assertThat(cryptoRequest) + .isEqualTo( + completeReachOnlyExecutionPhaseAtAggregatorRequest { + combinedRegisterVector = ByteString.copyFromUtf8("data") + localElGamalKeyPair = DUCHY_ONE_KEY_PAIR + curveId = CURVE_ID + serializedExcessiveNoiseCiphertext = SERIALIZED_NOISE_CIPHERTEXT + liquidLegionsParameters = liquidLegionsSketchParameters { + decayRate = DECAY_RATE + size = SKETCH_SIZE + } + reachDpNoiseBaseline = globalReachDpNoiseBaseline { + contributorsCount = WORKER_COUNT + globalReachDpNoise = TEST_NOISE_CONFIG.reachNoiseConfig.globalReachDpNoise + } + vidSamplingIntervalWidth = 0.5f + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = COMBINED_PUBLIC_KEY + curveId = CURVE_ID + contributorsCount = WORKER_COUNT + totalSketchesCount = REQUISITIONS.size + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = TEST_NOISE_CONFIG.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = TEST_NOISE_CONFIG.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = TEST_NOISE_CONFIG.reachNoiseConfig.globalReachDpNoise + } + } + parallelism = PARALLELISM + } + ) + } + + @Test + fun `execution phase at aggregator, failed due to invalid input blob size`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = EXECUTION_PHASE.toProtocolStage() + ) + .build() + val inputBlobContext = ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 0L) + val calculatedBlobContext = + ComputationBlobContext(GLOBAL_ID, EXECUTION_PHASE.toProtocolStage(), 1L) + computationStore.writeString(inputBlobContext, "data") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf( + inputBlobContext.toMetadata(ComputationBlobDependency.INPUT), + newEmptyOutputBlobMetadata(calculatedBlobContext.blobId) + ), + requisitions = REQUISITIONS + ) + + // Stage 1. Process the above computation + nonAggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + assertThat(fakeComputationDb[LOCAL_ID]!!) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = COMPLETE.toProtocolStage() + version = 2 // claimTask + transitionStage + computationDetails = computationDetails { + kingdomComputation = AGGREGATOR_COMPUTATION_DETAILS.kingdomComputation + reachOnlyLiquidLegionsV2 = AGGREGATOR_COMPUTATION_DETAILS.reachOnlyLiquidLegionsV2 + endingState = CompletedReason.FAILED + } + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + ) + + verifyProtoArgument( + mockComputationParticipants, + SystemComputationParticipantsCoroutineImplBase::failComputationParticipant + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + failComputationParticipantRequest { + name = ComputationParticipantKey(GLOBAL_ID, DUCHY_ONE_NAME).toName() + failure = + ComputationParticipantKt.failure { + participantChildReferenceId = MILL_ID + errorMessage = + "PERMANENT error: Invalid input blob size. Input blob data has size 4 which is less than (66)." + this.stageAttempt = stageAttempt { + stage = EXECUTION_PHASE.number + stageName = EXECUTION_PHASE.name + attemptNumber = 1 + } + } + } + ) + } +} + +private fun ComputationBlobContext.toMetadata(dependencyType: ComputationBlobDependency) = + computationStageBlobMetadata { + blobId = this@toMetadata.blobId + path = blobKey + this.dependencyType = dependencyType + } + +private suspend fun Blob.readToString(): String = read().flatten().toStringUtf8() + +private suspend fun ComputationStore.writeString( + context: ComputationBlobContext, + content: String +): Blob = write(context, ByteString.copyFromUtf8(content)) + +private suspend fun RequisitionStore.writeString( + context: RequisitionBlobContext, + content: String +): Blob = write(context, ByteString.copyFromUtf8(content)) diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel index ff0ea7a100a..64d03902a74 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/BUILD.bazel @@ -30,3 +30,35 @@ kt_jvm_test( "@wfa_common_jvm//imports/kotlin/kotlin/test", ], ) + +kt_jvm_test( + name = "ReachOnlyLiquidLegionsV2EncryptionUtilityTest", + srcs = ["ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt"], + test_class = "org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto.ReachOnlyLiquidLegionsV2EncryptionUtilityTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/utils:computation_conversions", + "//src/main/proto/wfa/any_sketch:sketch_kt_jvm_proto", + "//src/main/proto/wfa/any_sketch/crypto:sketch_encryption_methods_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_kt_jvm_proto", + "//src/main/swig/protocol/reachonlyliquidlegionsv2:reach_only_liquid_legions_v2_encryption_utility", + "@any_sketch_java//src/main/java/org/wfanet/anysketch/crypto:sketch_encrypter_adapter", + "@any_sketch_java//src/main/java/org/wfanet/estimation:estimators", + "@wfa_common_jvm//imports/java/com/google/common/truth", + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//imports/kotlin/kotlin/test", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + ], +) + +kt_jvm_test( + name = "JniReachOnlyLiquidLegionsV2EncryptionTest", + srcs = ["JniReachOnlyLiquidLegionsV2EncryptionTest.kt"], + test_class = "org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto.JniReachOnlyLiquidLegionsV2EncryptionTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto:reachonlyliquidlegionsv2encryption", + "//src/main/proto/wfa/measurement/internal/duchy/protocol:reach_only_liquid_legions_v2_encryption_methods_kt_jvm_proto", + "@wfa_common_jvm//imports/java/com/google/common/truth", + "@wfa_common_jvm//imports/kotlin/kotlin/test", + ], +) diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2EncryptionTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2EncryptionTest.kt new file mode 100644 index 00000000000..20676ac6984 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/JniReachOnlyLiquidLegionsV2EncryptionTest.kt @@ -0,0 +1,43 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto + +import com.google.common.truth.Truth.assertThat +import kotlin.test.assertFailsWith +import org.junit.Test +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysRequest +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseRequest + +class JniReachOnlyLiquidLegionsV2EncryptionTest { + + @Test + fun `check JNI lib is loaded successfully`() { + // Send an invalid request and check if we can get the error thrown inside JNI. + val e1 = + assertFailsWith(RuntimeException::class) { + JniReachOnlyLiquidLegionsV2Encryption() + .completeReachOnlySetupPhase(CompleteReachOnlySetupPhaseRequest.getDefaultInstance()) + } + assertThat(e1.message).contains("ECGroup::CreateGroup() - Could not create group.") + + // Send an invalid request and check if we can get the error thrown inside JNI. + val e2 = + assertFailsWith(RuntimeException::class) { + JniReachOnlyLiquidLegionsV2Encryption() + .combineElGamalPublicKeys(CombineElGamalPublicKeysRequest.getDefaultInstance()) + } + assertThat(e2.message).contains("Keys cannot be empty") + } +} diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt new file mode 100644 index 00000000000..8e122c56429 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/crypto/ReachOnlyLiquidLegionsV2EncryptionUtilityTest.kt @@ -0,0 +1,295 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package org.wfanet.measurement.duchy.daemon.mill.liquidlegionsv2.crypto + +import com.google.common.truth.Truth.assertThat +import com.google.protobuf.ByteString +import java.nio.file.Paths +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.wfanet.anysketch.SketchKt +import org.wfanet.anysketch.crypto.CombineElGamalPublicKeysResponse +import org.wfanet.anysketch.crypto.EncryptSketchRequest.DestroyedRegisterStrategy.FLAGGED_KEY +import org.wfanet.anysketch.crypto.EncryptSketchResponse +import org.wfanet.anysketch.crypto.SketchEncrypterAdapter +import org.wfanet.anysketch.crypto.combineElGamalPublicKeysRequest +import org.wfanet.anysketch.crypto.encryptSketchRequest +import org.wfanet.anysketch.sketch +import org.wfanet.estimation.Estimators +import org.wfanet.measurement.common.loadLibrary +import org.wfanet.measurement.duchy.daemon.utils.toAnySketchElGamalPublicKey +import org.wfanet.measurement.duchy.daemon.utils.toCmmsElGamalPublicKey +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseAtAggregatorResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyExecutionPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlyInitializationPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.CompleteReachOnlySetupPhaseResponse +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseAtAggregatorRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyExecutionPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlyInitializationPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.completeReachOnlySetupPhaseRequest +import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters +import org.wfanet.measurement.internal.duchy.protocol.reachonlyliquidlegionsv2.ReachOnlyLiquidLegionsV2EncryptionUtility + +@RunWith(JUnit4::class) +class ReachOnlyLiquidLegionsV2EncryptionUtilityTest { + + // Helper function to go through the entire Liquid Legions V2 protocol using the input data. + // The final relative_frequency_distribution map are returned. + private fun goThroughEntireMpcProtocol( + encrypted_sketch: ByteString + ): CompleteReachOnlyExecutionPhaseAtAggregatorResponse { + // Setup phase at Duchy 1 (NON_AGGREGATOR). Duchy 1 receives all the sketches. + val completeReachOnlySetupPhaseRequest1 = completeReachOnlySetupPhaseRequest { + combinedRegisterVector = encrypted_sketch + curveId = CURVE_ID + compositeElGamalPublicKey = CLIENT_EL_GAMAL_KEYS + parallelism = PARALLELISM + } + val completeReachOnlySetupPhaseResponse1 = + CompleteReachOnlySetupPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase( + completeReachOnlySetupPhaseRequest1.toByteArray() + ) + ) + // Setup phase at Duchy 2 (NON_AGGREGATOR). Duchy 2 does not receive any sketche. + val completeReachOnlySetupPhaseRequest2 = completeReachOnlySetupPhaseRequest { + curveId = CURVE_ID + compositeElGamalPublicKey = CLIENT_EL_GAMAL_KEYS + parallelism = PARALLELISM + } + val completeReachOnlySetupPhaseResponse2 = + CompleteReachOnlySetupPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase( + completeReachOnlySetupPhaseRequest2.toByteArray() + ) + ) + // Setup phase at Duchy 3 (AGGREGATOR). Aggregator receives the combined register vector and + // the concatenated excessive noise ciphertexts. + val completeReachOnlySetupPhaseRequest3 = completeReachOnlySetupPhaseRequest { + combinedRegisterVector = + completeReachOnlySetupPhaseResponse1.combinedRegisterVector.concat( + completeReachOnlySetupPhaseResponse2.combinedRegisterVector + ) + curveId = CURVE_ID + compositeElGamalPublicKey = CLIENT_EL_GAMAL_KEYS + serializedExcessiveNoiseCiphertext = + completeReachOnlySetupPhaseResponse1.serializedExcessiveNoiseCiphertext.concat( + completeReachOnlySetupPhaseResponse2.serializedExcessiveNoiseCiphertext + ) + parallelism = PARALLELISM + } + val completeReachOnlySetupPhaseResponse3 = + CompleteReachOnlySetupPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase( + completeReachOnlySetupPhaseRequest3.toByteArray() + ) + ) + + // Execution phase at duchy 1 (non-aggregator). + val completeReachOnlyExecutionPhaseRequest1 = completeReachOnlyExecutionPhaseRequest { + combinedRegisterVector = completeReachOnlySetupPhaseResponse3.combinedRegisterVector + localElGamalKeyPair = DUCHY_1_EL_GAMAL_KEYS + curveId = CURVE_ID + serializedExcessiveNoiseCiphertext = + completeReachOnlySetupPhaseResponse3.serializedExcessiveNoiseCiphertext + parallelism = PARALLELISM + } + val completeReachOnlyExecutionPhaseResponse1 = + CompleteReachOnlyExecutionPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhase( + completeReachOnlyExecutionPhaseRequest1.toByteArray() + ) + ) + + // Execution phase at duchy 2 (non-aggregator). + val completeReachOnlyExecutionPhaseRequest2 = completeReachOnlyExecutionPhaseRequest { + combinedRegisterVector = completeReachOnlyExecutionPhaseResponse1.combinedRegisterVector + localElGamalKeyPair = DUCHY_2_EL_GAMAL_KEYS + curveId = CURVE_ID + serializedExcessiveNoiseCiphertext = + completeReachOnlyExecutionPhaseResponse1.serializedExcessiveNoiseCiphertext + parallelism = PARALLELISM + } + val completeReachOnlyExecutionPhaseResponse2 = + CompleteReachOnlyExecutionPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhase( + completeReachOnlyExecutionPhaseRequest2.toByteArray() + ) + ) + + // Execution phase at duchy 3 (aggregator). + val completeReachOnlyExecutionPhaseAtAggregatorRequest = + completeReachOnlyExecutionPhaseAtAggregatorRequest { + combinedRegisterVector = completeReachOnlyExecutionPhaseResponse2.combinedRegisterVector + localElGamalKeyPair = DUCHY_3_EL_GAMAL_KEYS + curveId = CURVE_ID + serializedExcessiveNoiseCiphertext = + completeReachOnlyExecutionPhaseResponse2.serializedExcessiveNoiseCiphertext + parallelism = PARALLELISM + liquidLegionsParameters = liquidLegionsSketchParameters { + decayRate = DECAY_RATE + size = LIQUID_LEGIONS_SIZE + } + vidSamplingIntervalWidth = VID_SAMPLING_INTERVAL_WIDTH + parallelism = PARALLELISM + } + return CompleteReachOnlyExecutionPhaseAtAggregatorResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhaseAtAggregator( + completeReachOnlyExecutionPhaseAtAggregatorRequest.toByteArray() + ) + ) + } + + @Test + fun endToEnd_basicBehavior() { + val rawSketch = sketch { + registers += SketchKt.register { index = 1L } + registers += SketchKt.register { index = 2L } + registers += SketchKt.register { index = 2L } + registers += SketchKt.register { index = 3L } + registers += SketchKt.register { index = 4L } + } + val request = encryptSketchRequest { + sketch = rawSketch + curveId = CURVE_ID + maximumValue = MAX_COUNTER_VALUE + elGamalKeys = CLIENT_EL_GAMAL_KEYS.toAnySketchElGamalPublicKey() + destroyedRegisterStrategy = FLAGGED_KEY + } + val response = + EncryptSketchResponse.parseFrom(SketchEncrypterAdapter.EncryptSketch(request.toByteArray())) + val encryptedSketch = response.encryptedSketch + val result = goThroughEntireMpcProtocol(encryptedSketch).reach + val expectedResult = + Estimators.EstimateCardinalityLiquidLegions( + DECAY_RATE, + LIQUID_LEGIONS_SIZE, + 4, + VID_SAMPLING_INTERVAL_WIDTH.toDouble() + ) + assertEquals(expectedResult, result) + } + + @Test + fun `completeReachOnlySetupPhase fails with invalid request message`() { + val exception = + assertFailsWith(RuntimeException::class) { + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlySetupPhase( + "something not a proto".toByteArray() + ) + } + assertThat(exception).hasMessageThat().contains("Failed to parse") + } + + @Test + fun `completeReachOnlyExecutionPhase fails with invalid request message`() { + val exception = + assertFailsWith(RuntimeException::class) { + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhase( + "something not a proto".toByteArray() + ) + } + assertThat(exception).hasMessageThat().contains("Failed to parse") + } + + @Test + fun `completeReachOnlyExecutionPhaseAtAggregator fails with invalid request message`() { + val exception = + assertFailsWith(RuntimeException::class) { + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyExecutionPhaseAtAggregator( + "something not a proto".toByteArray() + ) + } + assertThat(exception).hasMessageThat().contains("Failed to parse") + } + + companion object { + init { + loadLibrary( + "reach_only_liquid_legions_v2_encryption_utility", + Paths.get("wfa_measurement_system/src/main/swig/protocol/reachonlyliquidlegionsv2") + ) + loadLibrary( + "sketch_encrypter_adapter", + Paths.get("any_sketch_java/src/main/java/org/wfanet/anysketch/crypto") + ) + loadLibrary("estimators", Paths.get("any_sketch_java/src/main/java/org/wfanet/estimation")) + } + + private const val DECAY_RATE = 12.0 + private const val LIQUID_LEGIONS_SIZE = 100_000L + private const val MAXIMUM_FREQUENCY = 10 + private const val VID_SAMPLING_INTERVAL_WIDTH = 0.5f + + private const val CURVE_ID = 415L // NID_X9_62_prime256v1 + private const val PARALLELISM = 3 + private const val MAX_COUNTER_VALUE = 10 + + private val COMPLETE_INITIALIZATION_REQUEST = completeReachOnlyInitializationPhaseRequest { + curveId = CURVE_ID + } + private val DUCHY_1_EL_GAMAL_KEYS = + CompleteReachOnlyInitializationPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyInitializationPhase( + COMPLETE_INITIALIZATION_REQUEST.toByteArray() + ) + ) + .elGamalKeyPair + private val DUCHY_2_EL_GAMAL_KEYS = + CompleteReachOnlyInitializationPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyInitializationPhase( + COMPLETE_INITIALIZATION_REQUEST.toByteArray() + ) + ) + .elGamalKeyPair + private val DUCHY_3_EL_GAMAL_KEYS = + CompleteReachOnlyInitializationPhaseResponse.parseFrom( + ReachOnlyLiquidLegionsV2EncryptionUtility.completeReachOnlyInitializationPhase( + COMPLETE_INITIALIZATION_REQUEST.toByteArray() + ) + ) + .elGamalKeyPair + + private val CLIENT_EL_GAMAL_KEYS = + CombineElGamalPublicKeysResponse.parseFrom( + SketchEncrypterAdapter.CombineElGamalPublicKeys( + combineElGamalPublicKeysRequest { + curveId = CURVE_ID + elGamalKeys += DUCHY_1_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey() + elGamalKeys += DUCHY_2_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey() + elGamalKeys += DUCHY_3_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey() + } + .toByteArray() + ) + ) + .elGamalKeys + .toCmmsElGamalPublicKey() + private val DUCHY_2_3_COMBINED_EL_GAMAL_KEYS = + CombineElGamalPublicKeysResponse.parseFrom( + SketchEncrypterAdapter.CombineElGamalPublicKeys( + combineElGamalPublicKeysRequest { + curveId = CURVE_ID + elGamalKeys += DUCHY_2_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey() + elGamalKeys += DUCHY_3_EL_GAMAL_KEYS.publicKey.toAnySketchElGamalPublicKey() + } + .toByteArray() + ) + ) + .elGamalKeys + .toCmmsElGamalPublicKey() + } +}