diff --git a/src/main/k8s/testing/secretfiles/aggregator_protocols_setup_config.textproto b/src/main/k8s/testing/secretfiles/aggregator_protocols_setup_config.textproto index e4b95869b4c..3e78df01981 100644 --- a/src/main/k8s/testing/secretfiles/aggregator_protocols_setup_config.textproto +++ b/src/main/k8s/testing/secretfiles/aggregator_protocols_setup_config.textproto @@ -4,3 +4,7 @@ liquid_legions_v2 { role: AGGREGATOR external_aggregator_duchy_id: "aggregator" } +reach_only_liquid_legions_v2 { + role: AGGREGATOR + external_aggregator_duchy_id: "aggregator" +} diff --git a/src/main/k8s/testing/secretfiles/non_aggregator_protocols_setup_config.textproto b/src/main/k8s/testing/secretfiles/non_aggregator_protocols_setup_config.textproto index fd50191bd76..3688b0429fc 100644 --- a/src/main/k8s/testing/secretfiles/non_aggregator_protocols_setup_config.textproto +++ b/src/main/k8s/testing/secretfiles/non_aggregator_protocols_setup_config.textproto @@ -4,3 +4,7 @@ liquid_legions_v2 { role: NON_AGGREGATOR external_aggregator_duchy_id: "aggregator" } +reach_only_liquid_legions_v2 { + role: NON_AGGREGATOR + external_aggregator_duchy_id: "aggregator" +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/BUILD.bazel index 4ff006b7913..a8fd04376f4 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/BUILD.bazel @@ -24,6 +24,7 @@ kt_jvm_library( srcs = [ "Herald.kt", "LiquidLegionsV2Starter.kt", + "ReachOnlyLiquidLegionsV2Starter.kt", ], runtime_deps = ["@wfa_common_jvm//imports/java/io/grpc/netty"], deps = [ diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/Herald.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/Herald.kt index d6c9a51d697..2f1a6eacd51 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/Herald.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/Herald.kt @@ -32,9 +32,7 @@ import kotlinx.coroutines.launch import kotlinx.coroutines.sync.Semaphore import org.wfanet.measurement.common.grpc.grpcStatusCode import org.wfanet.measurement.common.protoTimestamp -import org.wfanet.measurement.duchy.daemon.utils.MeasurementType import org.wfanet.measurement.duchy.daemon.utils.key -import org.wfanet.measurement.duchy.daemon.utils.toMeasurementType import org.wfanet.measurement.duchy.service.internal.computations.toGetTokenRequest import org.wfanet.measurement.internal.duchy.ComputationDetails import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoroutineStub @@ -229,16 +227,22 @@ class Herald( val globalId: String = systemComputation.key.computationId logger.info("[id=$globalId] Creating Computation...") try { - when (systemComputation.toMeasurementType()) { - MeasurementType.REACH, - MeasurementType.REACH_AND_FREQUENCY -> { + when (systemComputation.mpcProtocolConfig.protocolCase) { + Computation.MpcProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2 -> LiquidLegionsV2Starter.createComputation( internalComputationsClient, systemComputation, protocolsSetupConfig.liquidLegionsV2, blobStorageBucket ) - } + Computation.MpcProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> + ReachOnlyLiquidLegionsV2Starter.createComputation( + internalComputationsClient, + systemComputation, + protocolsSetupConfig.reachOnlyLiquidLegionsV2, + blobStorageBucket + ) + else -> error("Unknown or unsupported protocol for creation.") } logger.info("[id=$globalId]: Created Computation") } catch (e: StatusException) { @@ -302,6 +306,13 @@ class Herald( systemComputation, protocolsSetupConfig.liquidLegionsV2.externalAggregatorDuchyId ) + ComputationDetails.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> + ReachOnlyLiquidLegionsV2Starter.updateRequisitionsAndKeySets( + token, + internalComputationsClient, + systemComputation, + protocolsSetupConfig.reachOnlyLiquidLegionsV2.externalAggregatorDuchyId + ) else -> error("Unknown or unsupported protocol.") } logger.info("[id=$globalId]: Confirmed Computation") @@ -317,6 +328,8 @@ class Herald( when (token.computationDetails.protocolCase) { ComputationDetails.ProtocolCase.LIQUID_LEGIONS_V2 -> LiquidLegionsV2Starter.startComputation(token, internalComputationsClient) + ComputationDetails.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> + ReachOnlyLiquidLegionsV2Starter.startComputation(token, internalComputationsClient) else -> error("Unknown or unsupported protocol.") } logger.info("[id=$globalId]: Started Computation") @@ -365,8 +378,10 @@ class Herald( } ?: return if ( - token.computationDetails.hasLiquidLegionsV2() && - token.computationStage == LiquidLegionsV2Starter.TERMINAL_STAGE + (token.computationDetails.hasLiquidLegionsV2() && + token.computationStage == LiquidLegionsV2Starter.TERMINAL_STAGE) || + (token.computationDetails.hasReachOnlyLiquidLegionsV2() && + token.computationStage == ReachOnlyLiquidLegionsV2Starter.TERMINAL_STAGE) ) { return } @@ -376,6 +391,8 @@ class Herald( endingComputationStage = when (token.computationDetails.protocolCase) { ComputationDetails.ProtocolCase.LIQUID_LEGIONS_V2 -> LiquidLegionsV2Starter.TERMINAL_STAGE + ComputationDetails.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> + ReachOnlyLiquidLegionsV2Starter.TERMINAL_STAGE else -> error { "Unknown or unsupported protocol." } } reason = ComputationDetails.CompletedReason.FAILED diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt index f84898a73af..da840b2e677 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/LiquidLegionsV2Starter.kt @@ -188,7 +188,7 @@ object LiquidLegionsV2Starter { Stage.WAIT_EXECUTION_PHASE_THREE_INPUTS, Stage.EXECUTION_PHASE_THREE, Stage.COMPLETE -> { - logger.info( + logger.warning( "[id=${token.globalComputationId}]: not updating," + " stage '$stage' is after WAIT_REQUISITIONS_AND_KEY_SET" ) @@ -245,7 +245,7 @@ object LiquidLegionsV2Starter { Stage.WAIT_EXECUTION_PHASE_THREE_INPUTS, Stage.EXECUTION_PHASE_THREE, Stage.COMPLETE -> { - logger.info( + logger.warning( "[id=${token.globalComputationId}]: not starting," + " stage '$stage' is after WAIT_TO_START" ) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/ReachOnlyLiquidLegionsV2Starter.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/ReachOnlyLiquidLegionsV2Starter.kt new file mode 100644 index 00000000000..09a1005e74d --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/herald/ReachOnlyLiquidLegionsV2Starter.kt @@ -0,0 +1,333 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.daemon.herald + +import java.util.logging.Logger +import org.wfanet.measurement.api.Version +import org.wfanet.measurement.api.v2alpha.MeasurementSpec +import org.wfanet.measurement.duchy.daemon.utils.key +import org.wfanet.measurement.duchy.daemon.utils.sha1Hash +import org.wfanet.measurement.duchy.daemon.utils.toDuchyDifferentialPrivacyParams +import org.wfanet.measurement.duchy.daemon.utils.toDuchyElGamalPublicKey +import org.wfanet.measurement.duchy.daemon.utils.toKingdomComputationDetails +import org.wfanet.measurement.duchy.daemon.utils.toRequisitionEntries +import org.wfanet.measurement.duchy.db.computation.advanceComputationStage +import org.wfanet.measurement.duchy.service.internal.computations.outputPathList +import org.wfanet.measurement.duchy.toProtocolStage +import org.wfanet.measurement.internal.duchy.ComputationToken +import org.wfanet.measurement.internal.duchy.ComputationTypeEnum +import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt +import org.wfanet.measurement.internal.duchy.computationDetails +import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig +import org.wfanet.measurement.internal.duchy.createComputationRequest +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfig +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfigKt +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt +import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters +import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsV2NoiseConfig +import org.wfanet.measurement.internal.duchy.updateComputationDetailsRequest +import org.wfanet.measurement.system.v1alpha.Computation +import org.wfanet.measurement.system.v1alpha.ComputationParticipant + +private const val MIN_REACH_EPSILON = 0.00001 +private const val MIN_FREQUENCY_EPSILON = 0.00001 + +object ReachOnlyLiquidLegionsV2Starter { + + private val logger: Logger = Logger.getLogger(this::class.java.name) + + val TERMINAL_STAGE = ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE.toProtocolStage() + + suspend fun createComputation( + computationStorageClient: ComputationsGrpcKt.ComputationsCoroutineStub, + systemComputation: Computation, + liquidLegionsV2SetupConfig: LiquidLegionsV2SetupConfig, + blobStorageBucket: String + ) { + require(systemComputation.name.isNotEmpty()) { "Resource name not specified" } + val globalId: String = systemComputation.key.computationId + val initialComputationDetails = computationDetails { + blobsStoragePrefix = "$blobStorageBucket/$globalId" + kingdomComputation = systemComputation.toKingdomComputationDetails() + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = liquidLegionsV2SetupConfig.role + parameters = systemComputation.toReachOnlyLiquidLegionsV2Parameters() + } + } + val requisitions = + systemComputation.requisitionsList.toRequisitionEntries(systemComputation.measurementSpec) + + computationStorageClient.createComputation( + createComputationRequest { + computationType = + ComputationTypeEnum.ComputationType.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 + globalComputationId = globalId + computationDetails = initialComputationDetails + this.requisitions += requisitions + } + ) + } + + /** + * Orders the list of computation participants by their roles in the computation. The + * non-aggregators are shuffled by the sha1Hash of their elgamal public keys and the global + * computation id, the aggregator is placed at the end of the list. This return order is also the + * order of all participants in the MPC ring structure. + */ + private fun List< + ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant + > + .orderByRoles( + globalComputationId: String, + aggregatorId: String + ): List { + val aggregator = + this.find { it.duchyId == aggregatorId } + ?: error("Aggregator duchy is missing from the participants.") + val nonAggregators = this.filter { it.duchyId != aggregatorId } + return nonAggregators.sortedBy { + sha1Hash(it.elGamalPublicKey.toStringUtf8() + globalComputationId) + } + aggregator + } + + private suspend fun updateRequisitionsAndKeySetsInternal( + token: ComputationToken, + computationStorageClient: ComputationsGrpcKt.ComputationsCoroutineStub, + systemComputation: Computation, + aggregatorId: String + ) { + val updatedDetails = computationDetails { + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + participant += + systemComputation.computationParticipantsList + .map { it.toDuchyComputationParticipant(systemComputation.publicApiVersion) } + .orderByRoles(token.globalComputationId, aggregatorId) + } + } + val requisitions = + systemComputation.requisitionsList.toRequisitionEntries(systemComputation.measurementSpec) + val updateComputationDetailsRequest = updateComputationDetailsRequest { + this.token = token + details = updatedDetails + this.requisitions += requisitions + } + + val newToken = + computationStorageClient.updateComputationDetails(updateComputationDetailsRequest).token + logger.info( + "[id=${token.globalComputationId}] " + "Requisitions and Duchy Elgamal Keys are now updated." + ) + + computationStorageClient.advanceComputationStage( + computationToken = newToken, + stage = ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE.toProtocolStage() + ) + } + + suspend fun updateRequisitionsAndKeySets( + token: ComputationToken, + computationStorageClient: ComputationsGrpcKt.ComputationsCoroutineStub, + systemComputation: Computation, + aggregatorId: String, + ) { + require(token.computationDetails.hasReachOnlyLiquidLegionsV2()) { + "Reach Only Liquid Legions V2 ComputationDetails required" + } + + val stage = token.computationStage.reachOnlyLiquidLegionsSketchAggregationV2 + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + when (stage) { + // We expect stage WAIT_REQUISITIONS_AND_KEY_SET. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET -> { + updateRequisitionsAndKeySetsInternal( + token, + computationStorageClient, + systemComputation, + aggregatorId + ) + return + } + + // For past stages, we throw. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE -> { + error( + "[id=${token.globalComputationId}]: cannot update requisitions and key sets for " + + "computation still in state ${stage.name}" + ) + } + + // For future stages, we log and exit. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_INPUTS, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE -> { + logger.warning( + "[id=${token.globalComputationId}]: not updating," + + " stage '$stage' is after WAIT_REQUISITIONS_AND_KEY_SET" + ) + return + } + + // For weird stages, we throw. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED -> { + error("[id=${token.globalComputationId}]: Unrecognized stage '$stage'") + } + } + } + + suspend fun startComputation( + token: ComputationToken, + computationStorageClient: ComputationsGrpcKt.ComputationsCoroutineStub + ) { + require(token.computationDetails.hasReachOnlyLiquidLegionsV2()) { + "Reach-Only Liquid Legions V2 computation required" + } + + val stage = token.computationStage.reachOnlyLiquidLegionsSketchAggregationV2 + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + when (stage) { + // We expect stage WAIT_TO_START. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START -> { + computationStorageClient.advanceComputationStage( + computationToken = token, + inputsToNextStage = token.outputPathList(), + stage = ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage() + ) + logger.info("[id=${token.globalComputationId}] Computation is now started") + return + } + + // For past stages, we throw. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE -> { + error( + "[id=${token.globalComputationId}]: cannot start a computation still" + + " in state ${stage.name}" + ) + } + + // For future stages, we log and exit. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_EXECUTION_PHASE_INPUTS, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.EXECUTION_PHASE, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE -> { + logger.warning( + "[id=${token.globalComputationId}]: not starting," + + " stage '$stage' is after WAIT_TO_START" + ) + return + } + + // For weird stages, we throw. + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.UNRECOGNIZED, + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED -> { + error("[id=${token.globalComputationId}]: Unrecognized stage '$stage'") + } + } + } + + private fun ComputationParticipant.toDuchyComputationParticipant( + publicApiVersion: String + ): ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant { + require(requisitionParams.hasReachOnlyLiquidLegionsV2()) { + "Missing reach-only liquid legions v2 requisition params." + } + return ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.computationParticipant { + duchyId = key.duchyId + publicKey = + requisitionParams.reachOnlyLiquidLegionsV2.elGamalPublicKey.toDuchyElGamalPublicKey( + Version.fromString(publicApiVersion) + ) + elGamalPublicKey = requisitionParams.reachOnlyLiquidLegionsV2.elGamalPublicKey + elGamalPublicKeySignature = + requisitionParams.reachOnlyLiquidLegionsV2.elGamalPublicKeySignature + duchyCertificateDer = requisitionParams.duchyCertificateDer + } + } + + private fun Computation.MpcProtocolConfig.NoiseMechanism.toInternalNoiseMechanism(): + LiquidLegionsV2NoiseConfig.NoiseMechanism { + return when (this) { + Computation.MpcProtocolConfig.NoiseMechanism.GEOMETRIC -> + LiquidLegionsV2NoiseConfig.NoiseMechanism.GEOMETRIC + Computation.MpcProtocolConfig.NoiseMechanism.DISCRETE_GAUSSIAN -> + LiquidLegionsV2NoiseConfig.NoiseMechanism.DISCRETE_GAUSSIAN + Computation.MpcProtocolConfig.NoiseMechanism.UNRECOGNIZED, + Computation.MpcProtocolConfig.NoiseMechanism.NOISE_MECHANISM_UNSPECIFIED -> + error("Invalid system NoiseMechanism") + } + } + + /** Creates a reach-only liquid legions v2 `Parameters` from the system Api computation. */ + private fun Computation.toReachOnlyLiquidLegionsV2Parameters(): + ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails.Parameters { + require(mpcProtocolConfig.hasReachOnlyLiquidLegionsV2()) { + "Missing reachOnlyLiquidLegionV2 in the duchy protocol config." + } + + return ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters { + sketchParameters = liquidLegionsSketchParameters { + decayRate = mpcProtocolConfig.reachOnlyLiquidLegionsV2.sketchParams.decayRate + size = mpcProtocolConfig.reachOnlyLiquidLegionsV2.sketchParams.maxSize + } + ellipticCurveId = mpcProtocolConfig.reachOnlyLiquidLegionsV2.ellipticCurveId + noise = liquidLegionsV2NoiseConfig { + noiseMechanism = + mpcProtocolConfig.reachOnlyLiquidLegionsV2.noiseMechanism.toInternalNoiseMechanism() + reachNoiseConfig = + LiquidLegionsV2NoiseConfigKt.reachNoiseConfig { + val mpcNoise = mpcProtocolConfig.reachOnlyLiquidLegionsV2.mpcNoise + blindHistogramNoise = mpcNoise.blindedHistogramNoise.toDuchyDifferentialPrivacyParams() + noiseForPublisherNoise = mpcNoise.publisherNoise.toDuchyDifferentialPrivacyParams() + + when (Version.fromString(publicApiVersion)) { + Version.V2_ALPHA -> { + val measurementSpec = MeasurementSpec.parseFrom(measurementSpec) + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + when (measurementSpec.measurementTypeCase) { + MeasurementSpec.MeasurementTypeCase.REACH -> { + val reach = measurementSpec.reach + require(reach.privacyParams.delta > 0) { + "RoLLv2 requires that privacy_params.delta be greater than 0" + } + require(reach.privacyParams.epsilon > MIN_REACH_EPSILON) { + "RoLLv2 requires that privacy_params.epsilon be greater than $MIN_REACH_EPSILON" + } + globalReachDpNoise = reach.privacyParams.toDuchyDifferentialPrivacyParams() + } + MeasurementSpec.MeasurementTypeCase.REACH_AND_FREQUENCY, + MeasurementSpec.MeasurementTypeCase.IMPRESSION, + MeasurementSpec.MeasurementTypeCase.DURATION, + MeasurementSpec.MeasurementTypeCase.MEASUREMENTTYPE_NOT_SET -> { + throw IllegalArgumentException("Missing Reach in the measurementSpec.") + } + } + } + Version.VERSION_UNSPECIFIED -> error("Public api version is invalid or unspecified.") + } + } + } + } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers/RequisitionReader.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers/RequisitionReader.kt index 999c205396d..171ad148eb5 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers/RequisitionReader.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers/RequisitionReader.kt @@ -230,6 +230,9 @@ class RequisitionReader : BaseSpannerReader() { ComputationParticipant.Details.ProtocolCase.LIQUID_LEGIONS_V2 -> { liquidLegionsV2 = participantDetails.liquidLegionsV2 } + ComputationParticipant.Details.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> { + reachOnlyLiquidLegionsV2 = participantDetails.reachOnlyLiquidLegionsV2 + } // Protocol may only be set after computation participant sets requisition params. ComputationParticipant.Details.ProtocolCase.PROTOCOL_NOT_SET -> Unit } diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto index 0568f525341..a7d99d3a154 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_sketch_aggregation_v2.proto @@ -119,6 +119,7 @@ message LiquidLegionsSketchAggregationV2 { // The maximum frequency to reveal in the histogram. int32 maximum_frequency = 1; // Parameters used for liquidLegions sketch creation and estimation. + // TODO(@ple): rename to sketch_parameters. LiquidLegionsSketchParameters liquid_legions_sketch = 2; // Noise parameters selected for the LiquidLegionV2 MPC protocol. LiquidLegionsV2NoiseConfig noise = 3; diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_v2_noise_config.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_v2_noise_config.proto index c46bc0d5fac..9ef947f8d25 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_v2_noise_config.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/liquid_legions_v2_noise_config.proto @@ -22,7 +22,7 @@ option java_package = "org.wfanet.measurement.internal.duchy.protocol"; option java_multiple_files = true; // Configuration for various noises added by the MPC workers in the -// LiquidLegionV2 protocol. +// LiquidLegionV2 protocols. Also used by reach-only protocol. message LiquidLegionsV2NoiseConfig { message ReachNoiseConfig { // DP params for the blind histogram noise register. @@ -44,6 +44,7 @@ message LiquidLegionsV2NoiseConfig { // Differential privacy parameters for noise tuples. // Same value is used for both (0, R, R) and (R, R, R) tuples. + // Ignored by reach-only protocol. DifferentialPrivacyParams frequency_noise_config = 2; // The mechanism used to generate noise in computations. diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto index 8e754dc8039..490b60d0516 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/reach_only_liquid_legions_sketch_aggregation_v2.proto @@ -99,9 +99,10 @@ message ReachOnlyLiquidLegionsSketchAggregationV2 { // Parameters used in this computation. message Parameters { - // Parameters used for liquidLegions sketch creation and estimation. - LiquidLegionsSketchParameters liquid_legions_sketch = 1; - // Noise parameters selected for the LiquidLegionV2 MPC protocol. + // Parameters used for reachOnlyLiquidLegions sketch creation and + // estimation. + LiquidLegionsSketchParameters sketch_parameters = 1; + // Noise parameters selected for the ReachOnlyLiquidLegionV2 MPC protocol. LiquidLegionsV2NoiseConfig noise = 2; // ID of the OpenSSL built-in elliptic curve. For example, 415 for the // prime256v1 curve. Required. Immutable. diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt index 1e60c457138..055b6c572e6 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt @@ -49,6 +49,7 @@ import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reach import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reachAndFrequency import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams as cmmsDifferentialPrivacyParams +import org.wfanet.measurement.api.v2alpha.elGamalPublicKey import org.wfanet.measurement.api.v2alpha.encryptionPublicKey import org.wfanet.measurement.api.v2alpha.measurementSpec import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule @@ -67,7 +68,6 @@ import org.wfanet.measurement.duchy.service.internal.computations.newPassThrough import org.wfanet.measurement.duchy.storage.ComputationStore import org.wfanet.measurement.duchy.storage.RequisitionStore import org.wfanet.measurement.duchy.toProtocolStage -import org.wfanet.measurement.internal.duchy.ComputationDetails import org.wfanet.measurement.internal.duchy.ComputationDetailsKt.kingdomComputationDetails import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoroutineImplBase as InternalComputationsCoroutineImplBase import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoroutineStub as InternalComputationsCoroutineStub @@ -79,27 +79,26 @@ import org.wfanet.measurement.internal.duchy.GetComputationTokenRequest import org.wfanet.measurement.internal.duchy.computationDetails import org.wfanet.measurement.internal.duchy.computationToken import org.wfanet.measurement.internal.duchy.config.LiquidLegionsV2SetupConfig.RoleInComputation -import org.wfanet.measurement.internal.duchy.config.ProtocolsSetupConfig +import org.wfanet.measurement.internal.duchy.config.liquidLegionsV2SetupConfig +import org.wfanet.measurement.internal.duchy.config.protocolsSetupConfig import org.wfanet.measurement.internal.duchy.deleteComputationRequest import org.wfanet.measurement.internal.duchy.differentialPrivacyParams as duchyDifferentialPrivacyParams +import org.wfanet.measurement.internal.duchy.elGamalPublicKey as internalElgamalPublicKey import org.wfanet.measurement.internal.duchy.getComputationTokenResponse import org.wfanet.measurement.internal.duchy.getContinuationTokenResponse -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2Kt -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfig import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsV2NoiseConfigKt.reachNoiseConfig +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2Kt import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsSketchParameters import org.wfanet.measurement.internal.duchy.protocol.liquidLegionsV2NoiseConfig import org.wfanet.measurement.internal.duchy.setContinuationTokenRequest import org.wfanet.measurement.storage.StorageClient import org.wfanet.measurement.storage.testing.InMemoryStorageClient import org.wfanet.measurement.system.v1alpha.Computation +import org.wfanet.measurement.system.v1alpha.Computation.MpcProtocolConfig import org.wfanet.measurement.system.v1alpha.Computation.MpcProtocolConfig.NoiseMechanism as SystemNoiseMechanism import org.wfanet.measurement.system.v1alpha.ComputationKey import org.wfanet.measurement.system.v1alpha.ComputationKt.MpcProtocolConfigKt.LiquidLegionsV2Kt.liquidLegionsSketchParams @@ -112,6 +111,8 @@ import org.wfanet.measurement.system.v1alpha.ComputationLogEntry import org.wfanet.measurement.system.v1alpha.ComputationParticipant as SystemComputationParticipant import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.RequisitionParamsKt +import org.wfanet.measurement.system.v1alpha.ComputationParticipantKt.requisitionParams import org.wfanet.measurement.system.v1alpha.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineImplBase as SystemComputationParticipantsCoroutineImplBase import org.wfanet.measurement.system.v1alpha.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineStub as SystemComputationParticipantsCoroutineStub import org.wfanet.measurement.system.v1alpha.ComputationsGrpcKt.ComputationsCoroutineImplBase as SystemComputationsCoroutineImplBase @@ -120,6 +121,7 @@ import org.wfanet.measurement.system.v1alpha.FailComputationParticipantRequest import org.wfanet.measurement.system.v1alpha.Requisition import org.wfanet.measurement.system.v1alpha.StreamActiveComputationsResponse import org.wfanet.measurement.system.v1alpha.computation +import org.wfanet.measurement.system.v1alpha.computationParticipant as systemComputationParticipant import org.wfanet.measurement.system.v1alpha.computationParticipant import org.wfanet.measurement.system.v1alpha.copy import org.wfanet.measurement.system.v1alpha.differentialPrivacyParams as systemDifferentialPrivacyParams @@ -176,7 +178,7 @@ private val PUBLIC_API_REACH_ONLY_MEASUREMENT_SPEC = measurementSpec { private val SERIALIZED_REACH_ONLY_MEASUREMENT_SPEC: ByteString = PUBLIC_API_REACH_ONLY_MEASUREMENT_SPEC.toByteString() -private val MPC_PROTOCOL_CONFIG = mpcProtocolConfig { +private val LLV2_MPC_PROTOCOL_CONFIG = mpcProtocolConfig { liquidLegionsV2 = liquidLegionsV2 { sketchParams = liquidLegionsSketchParams { decayRate = 12.0 @@ -198,38 +200,79 @@ private val MPC_PROTOCOL_CONFIG = mpcProtocolConfig { } } +private val RO_LLV2_MPC_PROTOCOL_CONFIG = mpcProtocolConfig { + reachOnlyLiquidLegionsV2 = liquidLegionsV2 { + sketchParams = liquidLegionsSketchParams { + decayRate = 12.0 + maxSize = 100_000 + } + mpcNoise = mpcNoise { + blindedHistogramNoise = systemDifferentialPrivacyParams { + epsilon = 3.1 + delta = 3.2 + } + publisherNoise = systemDifferentialPrivacyParams { + epsilon = 4.1 + delta = 4.2 + } + } + ellipticCurveId = 415 + noiseMechanism = SystemNoiseMechanism.GEOMETRIC + } +} + private const val AGGREGATOR_DUCHY_ID = "aggregator_duchy" private const val AGGREGATOR_HERALD_ID = "aggregator_herald" private const val NON_AGGREGATOR_DUCHY_ID = "worker_duchy" private const val NON_AGGREGATOR_HERALD_ID = "worker_herald" -private val AGGREGATOR_PROTOCOLS_SETUP_CONFIG = - ProtocolsSetupConfig.newBuilder() - .apply { - liquidLegionsV2Builder.apply { - role = RoleInComputation.AGGREGATOR - externalAggregatorDuchyId = DUCHY_ONE - } +private val AGGREGATOR_PROTOCOLS_SETUP_CONFIG = protocolsSetupConfig { + liquidLegionsV2 = liquidLegionsV2SetupConfig { + role = RoleInComputation.AGGREGATOR + externalAggregatorDuchyId = DUCHY_ONE + } + reachOnlyLiquidLegionsV2 = liquidLegionsV2SetupConfig { + role = RoleInComputation.AGGREGATOR + externalAggregatorDuchyId = DUCHY_ONE + } +} + +private val NON_AGGREGATOR_PROTOCOLS_SETUP_CONFIG = protocolsSetupConfig { + liquidLegionsV2 = liquidLegionsV2SetupConfig { + role = RoleInComputation.NON_AGGREGATOR + externalAggregatorDuchyId = DUCHY_ONE + } + reachOnlyLiquidLegionsV2 = liquidLegionsV2SetupConfig { + role = RoleInComputation.NON_AGGREGATOR + externalAggregatorDuchyId = DUCHY_ONE + } +} + +private val LLV2_AGGREGATOR_COMPUTATION_DETAILS = computationDetails { + liquidLegionsV2 = + LiquidLegionsSketchAggregationV2Kt.computationDetails { role = RoleInComputation.AGGREGATOR } +} + +private val LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS = computationDetails { + liquidLegionsV2 = + LiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR } - .build() -private val NON_AGGREGATOR_PROTOCOLS_SETUP_CONFIG = - ProtocolsSetupConfig.newBuilder() - .apply { - liquidLegionsV2Builder.apply { - role = RoleInComputation.NON_AGGREGATOR - externalAggregatorDuchyId = DUCHY_ONE - } +} + +private val RO_LLV2_AGGREGATOR_COMPUTATION_DETAILS = computationDetails { + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.AGGREGATOR } - .build() +} -private val AGGREGATOR_COMPUTATION_DETAILS = - ComputationDetails.newBuilder() - .apply { liquidLegionsV2Builder.apply { role = RoleInComputation.AGGREGATOR } } - .build() -private val NON_AGGREGATOR_COMPUTATION_DETAILS = - ComputationDetails.newBuilder() - .apply { liquidLegionsV2Builder.apply { role = RoleInComputation.NON_AGGREGATOR } } - .build() +private val RO_LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS = computationDetails { + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + } +} private const val COMPUTATION_GLOBAL_ID = "123" @@ -358,7 +401,7 @@ class HeraldTest { } @Test - fun `syncStatuses creates new computations`() = runTest { + fun `syncStatuses creates new llv2 computations`() = runTest { val confirmingKnown = buildComputationAtKingdom("1", Computation.State.PENDING_REQUISITION_PARAMS) @@ -376,8 +419,8 @@ class HeraldTest { fakeComputationDatabase.addComputation( globalId = confirmingKnown.key.computationId, - stage = INITIALIZATION_PHASE.toProtocolStage(), - computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + stage = LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), + computationDetails = LLV2_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newInputBlobMetadata(0L, "input-blob"), newEmptyOutputBlobMetadata(1L)) ) @@ -393,9 +436,9 @@ class HeraldTest { ) .containsExactly( confirmingKnown.key.computationId.toLong(), - INITIALIZATION_PHASE.toProtocolStage(), + LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), confirmingUnknown.key.computationId.toLong(), - INITIALIZATION_PHASE.toProtocolStage() + LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage() ) assertThat( @@ -419,42 +462,43 @@ class HeraldTest { liquidLegionsV2 = LiquidLegionsSketchAggregationV2Kt.computationDetails { role = RoleInComputation.AGGREGATOR - parameters = parameters { - maximumFrequency = 10 - liquidLegionsSketch = liquidLegionsSketchParameters { - decayRate = 12.0 - size = 100_000L - } - noise = liquidLegionsV2NoiseConfig { - reachNoiseConfig = reachNoiseConfig { - blindHistogramNoise = duchyDifferentialPrivacyParams { - epsilon = 3.1 - delta = 3.2 - } - noiseForPublisherNoise = duchyDifferentialPrivacyParams { - epsilon = 4.1 - delta = 4.2 + parameters = + LiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters { + maximumFrequency = 10 + liquidLegionsSketch = liquidLegionsSketchParameters { + decayRate = 12.0 + size = 100_000L + } + noise = liquidLegionsV2NoiseConfig { + reachNoiseConfig = reachNoiseConfig { + blindHistogramNoise = duchyDifferentialPrivacyParams { + epsilon = 3.1 + delta = 3.2 + } + noiseForPublisherNoise = duchyDifferentialPrivacyParams { + epsilon = 4.1 + delta = 4.2 + } + globalReachDpNoise = duchyDifferentialPrivacyParams { + epsilon = 1.1 + delta = 1.2 + } } - globalReachDpNoise = duchyDifferentialPrivacyParams { - epsilon = 1.1 - delta = 1.2 + frequencyNoiseConfig = duchyDifferentialPrivacyParams { + epsilon = 2.1 + delta = 2.2 } + noiseMechanism = LiquidLegionsV2NoiseConfig.NoiseMechanism.GEOMETRIC } - frequencyNoiseConfig = duchyDifferentialPrivacyParams { - epsilon = 2.1 - delta = 2.2 - } - noiseMechanism = LiquidLegionsV2NoiseConfig.NoiseMechanism.GEOMETRIC + ellipticCurveId = 415 } - ellipticCurveId = 415 - } } } ) } @Test - fun `syncStatuses creates new computations for reach-only`() = runTest { + fun `syncStatuses creates new llv2 computations for reach-only`() = runTest { val confirmingKnown = buildComputationAtKingdom( "1", @@ -477,8 +521,8 @@ class HeraldTest { fakeComputationDatabase.addComputation( globalId = confirmingKnown.key.computationId, - stage = INITIALIZATION_PHASE.toProtocolStage(), - computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + stage = LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), + computationDetails = LLV2_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newInputBlobMetadata(0L, "input-blob"), newEmptyOutputBlobMetadata(1L)) ) @@ -494,9 +538,9 @@ class HeraldTest { ) .containsExactly( confirmingKnown.key.computationId.toLong(), - INITIALIZATION_PHASE.toProtocolStage(), + LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), confirmingUnknown.key.computationId.toLong(), - INITIALIZATION_PHASE.toProtocolStage() + LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage() ) assertThat( @@ -520,38 +564,139 @@ class HeraldTest { liquidLegionsV2 = LiquidLegionsSketchAggregationV2Kt.computationDetails { role = RoleInComputation.AGGREGATOR - parameters = parameters { - maximumFrequency = 10 - liquidLegionsSketch = liquidLegionsSketchParameters { - decayRate = 12.0 - size = 100_000L - } - noise = liquidLegionsV2NoiseConfig { - noiseMechanism = LiquidLegionsV2NoiseConfig.NoiseMechanism.GEOMETRIC - reachNoiseConfig = reachNoiseConfig { - blindHistogramNoise = duchyDifferentialPrivacyParams { - epsilon = 3.1 - delta = 3.2 - } - noiseForPublisherNoise = duchyDifferentialPrivacyParams { - epsilon = 4.1 - delta = 4.2 + parameters = + LiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters { + maximumFrequency = 10 + liquidLegionsSketch = liquidLegionsSketchParameters { + decayRate = 12.0 + size = 100_000L + } + noise = liquidLegionsV2NoiseConfig { + noiseMechanism = LiquidLegionsV2NoiseConfig.NoiseMechanism.GEOMETRIC + reachNoiseConfig = reachNoiseConfig { + blindHistogramNoise = duchyDifferentialPrivacyParams { + epsilon = 3.1 + delta = 3.2 + } + noiseForPublisherNoise = duchyDifferentialPrivacyParams { + epsilon = 4.1 + delta = 4.2 + } + globalReachDpNoise = duchyDifferentialPrivacyParams { + epsilon = 1.1 + delta = 1.2 + } } - globalReachDpNoise = duchyDifferentialPrivacyParams { - epsilon = 1.1 - delta = 1.2 + } + ellipticCurveId = 415 + } + } + } + ) + } + + @Test + fun `syncStatuses creates new rollv2 computations for reach-only`() = runTest { + val confirmingKnown = + buildComputationAtKingdom( + "1", + Computation.State.PENDING_REQUISITION_PARAMS, + serializedMeasurementSpec = SERIALIZED_REACH_ONLY_MEASUREMENT_SPEC, + mpcProtocolConfig = RO_LLV2_MPC_PROTOCOL_CONFIG + ) + + val systemApiRequisitions1 = + REACH_ONLY_REQUISITION_1.toSystemRequisition("2", Requisition.State.UNFULFILLED) + val systemApiRequisitions2 = + REACH_ONLY_REQUISITION_2.toSystemRequisition("2", Requisition.State.UNFULFILLED) + val confirmingUnknown = + buildComputationAtKingdom( + "2", + Computation.State.PENDING_REQUISITION_PARAMS, + listOf(systemApiRequisitions1, systemApiRequisitions2), + serializedMeasurementSpec = SERIALIZED_REACH_ONLY_MEASUREMENT_SPEC, + mpcProtocolConfig = RO_LLV2_MPC_PROTOCOL_CONFIG + ) + mockStreamActiveComputationsToReturn(confirmingKnown, confirmingUnknown) + + fakeComputationDatabase.addComputation( + globalId = confirmingKnown.key.computationId, + stage = + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), + computationDetails = RO_LLV2_AGGREGATOR_COMPUTATION_DETAILS, + blobs = listOf(newInputBlobMetadata(0L, "input-blob"), newEmptyOutputBlobMetadata(1L)) + ) + + aggregatorHerald.syncStatuses() + + verifyBlocking(continuationTokensService, atLeastOnce()) { + setContinuationToken(eq(setContinuationTokenRequest { this.token = "2" })) + } + assertThat( + fakeComputationDatabase.mapValues { (_, fakeComputation) -> + fakeComputation.computationStage + } + ) + .containsExactly( + confirmingKnown.key.computationId.toLong(), + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), + confirmingUnknown.key.computationId.toLong(), + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage() + ) + + assertThat( + fakeComputationDatabase[confirmingUnknown.key.computationId.toLong()]?.requisitionsList + ) + .containsExactly( + REACH_ONLY_REQUISITION_1.toRequisitionMetadata(Requisition.State.UNFULFILLED), + REACH_ONLY_REQUISITION_2.toRequisitionMetadata(Requisition.State.UNFULFILLED) + ) + assertThat( + fakeComputationDatabase[confirmingUnknown.key.computationId.toLong()]?.computationDetails + ) + .isEqualTo( + computationDetails { + blobsStoragePrefix = "computation-blob-storage/2" + kingdomComputation = kingdomComputationDetails { + publicApiVersion = PUBLIC_API_VERSION + measurementSpec = SERIALIZED_REACH_ONLY_MEASUREMENT_SPEC + measurementPublicKey = PUBLIC_API_ENCRYPTION_PUBLIC_KEY.toDuchyEncryptionPublicKey() + } + reachOnlyLiquidLegionsV2 = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { + role = RoleInComputation.AGGREGATOR + parameters = + ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.parameters { + sketchParameters = liquidLegionsSketchParameters { + decayRate = 12.0 + size = 100_000L + } + noise = liquidLegionsV2NoiseConfig { + noiseMechanism = LiquidLegionsV2NoiseConfig.NoiseMechanism.GEOMETRIC + reachNoiseConfig = reachNoiseConfig { + blindHistogramNoise = duchyDifferentialPrivacyParams { + epsilon = 3.1 + delta = 3.2 + } + noiseForPublisherNoise = duchyDifferentialPrivacyParams { + epsilon = 4.1 + delta = 4.2 + } + globalReachDpNoise = duchyDifferentialPrivacyParams { + epsilon = 1.1 + delta = 1.2 + } } } + ellipticCurveId = 415 } - ellipticCurveId = 415 - } } } ) } @Test - fun `syncStatuses update llv2 computations in WAIT_REQUISITIONS_AND_KEY_SET`() = runTest { + fun `syncStatuses confirms participants for llv2 computations`() = runTest { val globalId = "123456" val systemApiRequisitions1 = REQUISITION_1.toSystemRequisition(globalId, Requisition.State.FULFILLED, DUCHY_ONE) @@ -636,8 +781,9 @@ class HeraldTest { fakeComputationDatabase.addComputation( globalId = globalId, - stage = WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage(), - computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + stage = + LiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage(), + computationDetails = LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, requisitions = listOf( REQUISITION_1.toRequisitionMetadata(Requisition.State.UNFULFILLED), @@ -659,11 +805,11 @@ class HeraldTest { val duchyComputationToken = fakeComputationDatabase.readComputationToken(globalId)!! assertThat(duchyComputationToken.computationStage) - .isEqualTo(CONFIRMATION_PHASE.toProtocolStage()) + .isEqualTo(LiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE.toProtocolStage()) assertThat(duchyComputationToken.computationDetails.liquidLegionsV2.participantList) .isEqualTo( mutableListOf( - ComputationParticipant.newBuilder() + LiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant.newBuilder() .apply { duchyId = DUCHY_THREE publicKeyBuilder.apply { @@ -675,7 +821,7 @@ class HeraldTest { duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_3") } .build(), - ComputationParticipant.newBuilder() + LiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant.newBuilder() .apply { duchyId = DUCHY_TWO publicKeyBuilder.apply { @@ -687,7 +833,7 @@ class HeraldTest { duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_2") } .build(), - ComputationParticipant.newBuilder() + LiquidLegionsSketchAggregationV2.ComputationDetails.ComputationParticipant.newBuilder() .apply { duchyId = DUCHY_ONE publicKeyBuilder.apply { @@ -709,7 +855,188 @@ class HeraldTest { } @Test - fun `syncStatuses starts computations in wait_to_start`() = runTest { + fun `syncStatuses confirms participants for rollv2 computations`() = runTest { + val globalId = "123456" + val systemApiRequisitions1 = + REQUISITION_1.toSystemRequisition(globalId, Requisition.State.FULFILLED, DUCHY_ONE) + val systemApiRequisitions2 = + REQUISITION_2.toSystemRequisition(globalId, Requisition.State.FULFILLED, DUCHY_TWO) + val v2alphaApiElgamalPublicKey1 = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_1") + element = ByteString.copyFromUtf8("element_1") + } + val v2alphaApiElgamalPublicKey2 = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_2") + element = ByteString.copyFromUtf8("element_2") + } + val v2alphaApiElgamalPublicKey3 = elGamalPublicKey { + generator = ByteString.copyFromUtf8("generator_3") + element = ByteString.copyFromUtf8("element_3") + } + val systemComputationParticipant1 = systemComputationParticipant { + name = ComputationParticipantKey(globalId, DUCHY_ONE).toName() + requisitionParams = requisitionParams { + duchyCertificate = "duchyCertificate_1" + duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_1") + reachOnlyLiquidLegionsV2 = + RequisitionParamsKt.liquidLegionsV2 { + elGamalPublicKey = v2alphaApiElgamalPublicKey1.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("elGamalPublicKeySignature_1") + } + } + } + val systemComputationParticipant2 = systemComputationParticipant { + name = ComputationParticipantKey(globalId, DUCHY_TWO).toName() + requisitionParams = requisitionParams { + duchyCertificate = "duchyCertificate_2" + duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_2") + reachOnlyLiquidLegionsV2 = + RequisitionParamsKt.liquidLegionsV2 { + elGamalPublicKey = v2alphaApiElgamalPublicKey2.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("elGamalPublicKeySignature_2") + } + } + } + val systemComputationParticipant3 = systemComputationParticipant { + name = ComputationParticipantKey(globalId, DUCHY_THREE).toName() + requisitionParams = requisitionParams { + duchyCertificate = "duchyCertificate_3" + duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_3") + reachOnlyLiquidLegionsV2 = + RequisitionParamsKt.liquidLegionsV2 { + elGamalPublicKey = v2alphaApiElgamalPublicKey3.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("elGamalPublicKeySignature_3") + } + } + } + val waitingRequisitionsAndKeySet = + buildComputationAtKingdom( + globalId, + Computation.State.PENDING_PARTICIPANT_CONFIRMATION, + listOf(systemApiRequisitions1, systemApiRequisitions2), + listOf( + systemComputationParticipant1, + systemComputationParticipant2, + systemComputationParticipant3 + ) + ) + + mockStreamActiveComputationsToReturn(waitingRequisitionsAndKeySet) + + fakeComputationDatabase.addComputation( + globalId = globalId, + stage = + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_REQUISITIONS_AND_KEY_SET + .toProtocolStage(), + computationDetails = RO_LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, + requisitions = + listOf( + REQUISITION_1.toRequisitionMetadata(Requisition.State.UNFULFILLED), + REQUISITION_2.toRequisitionMetadata(Requisition.State.UNFULFILLED) + ) + ) + + aggregatorHerald.syncStatuses() + + verifyBlocking(continuationTokensService, atLeastOnce()) { + setContinuationToken( + eq( + setContinuationTokenRequest { + this.token = waitingRequisitionsAndKeySet.continuationToken() + } + ) + ) + } + + val duchyComputationToken = fakeComputationDatabase.readComputationToken(globalId)!! + assertThat(duchyComputationToken.computationStage) + .isEqualTo( + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.CONFIRMATION_PHASE.toProtocolStage() + ) + assertThat(duchyComputationToken.computationDetails.reachOnlyLiquidLegionsV2.participantList) + .isEqualTo( + listOf( + ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.computationParticipant { + duchyId = DUCHY_THREE + publicKey = internalElgamalPublicKey { + generator = ByteString.copyFromUtf8("generator_3") + element = ByteString.copyFromUtf8("element_3") + } + elGamalPublicKey = v2alphaApiElgamalPublicKey3.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("elGamalPublicKeySignature_3") + duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_3") + }, + ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.computationParticipant { + duchyId = DUCHY_TWO + publicKey = internalElgamalPublicKey { + generator = ByteString.copyFromUtf8("generator_2") + element = ByteString.copyFromUtf8("element_2") + } + elGamalPublicKey = v2alphaApiElgamalPublicKey2.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("elGamalPublicKeySignature_2") + duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_2") + }, + ReachOnlyLiquidLegionsSketchAggregationV2Kt.ComputationDetailsKt.computationParticipant { + duchyId = DUCHY_ONE + publicKey = internalElgamalPublicKey { + generator = ByteString.copyFromUtf8("generator_1") + element = ByteString.copyFromUtf8("element_1") + } + elGamalPublicKey = v2alphaApiElgamalPublicKey1.toByteString() + elGamalPublicKeySignature = ByteString.copyFromUtf8("elGamalPublicKeySignature_1") + duchyCertificateDer = ByteString.copyFromUtf8("duchyCertificateDer_1") + } + ) + ) + assertThat(duchyComputationToken.requisitionsList) + .containsExactly( + REQUISITION_1.toRequisitionMetadata(Requisition.State.FULFILLED, DUCHY_ONE), + REQUISITION_2.toRequisitionMetadata(Requisition.State.FULFILLED, DUCHY_TWO) + ) + } + + @Test + fun `syncStatuses starts llv2 computations`() = runTest { + val waitingToStart = + buildComputationAtKingdom(COMPUTATION_GLOBAL_ID, Computation.State.PENDING_COMPUTATION) + val addingNoise = buildComputationAtKingdom("231313", Computation.State.PENDING_COMPUTATION) + mockStreamActiveComputationsToReturn(waitingToStart, addingNoise) + + fakeComputationDatabase.addComputation( + globalId = waitingToStart.key.computationId, + stage = LiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START.toProtocolStage(), + computationDetails = LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, + blobs = listOf(newPassThroughBlobMetadata(0L, "local-copy-of-sketches")) + ) + + fakeComputationDatabase.addComputation( + globalId = addingNoise.key.computationId, + stage = LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage(), + computationDetails = LLV2_AGGREGATOR_COMPUTATION_DETAILS, + blobs = + listOf(newInputBlobMetadata(0L, "inputs-to-add-noise"), newEmptyOutputBlobMetadata(1L)) + ) + + aggregatorHerald.syncStatuses() + + verifyBlocking(continuationTokensService, atLeastOnce()) { + setContinuationToken(eq(setContinuationTokenRequest { this.token = "231313" })) + } + assertThat( + fakeComputationDatabase.mapValues { (_, fakeComputation) -> + fakeComputation.computationStage + } + ) + .containsExactly( + waitingToStart.key.computationId.toLong(), + LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage(), + addingNoise.key.computationId.toLong(), + LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage() + ) + } + + @Test + fun `syncStatuses starts rollv2 computations`() = runTest { val waitingToStart = buildComputationAtKingdom(COMPUTATION_GLOBAL_ID, Computation.State.PENDING_COMPUTATION) val addingNoise = buildComputationAtKingdom("231313", Computation.State.PENDING_COMPUTATION) @@ -717,15 +1044,15 @@ class HeraldTest { fakeComputationDatabase.addComputation( globalId = waitingToStart.key.computationId, - stage = WAIT_TO_START.toProtocolStage(), - computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + stage = ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START.toProtocolStage(), + computationDetails = RO_LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newPassThroughBlobMetadata(0L, "local-copy-of-sketches")) ) fakeComputationDatabase.addComputation( globalId = addingNoise.key.computationId, - stage = SETUP_PHASE.toProtocolStage(), - computationDetails = AGGREGATOR_COMPUTATION_DETAILS, + stage = ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage(), + computationDetails = RO_LLV2_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newInputBlobMetadata(0L, "inputs-to-add-noise"), newEmptyOutputBlobMetadata(1L)) ) @@ -742,9 +1069,9 @@ class HeraldTest { ) .containsExactly( waitingToStart.key.computationId.toLong(), - SETUP_PHASE.toProtocolStage(), + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage(), addingNoise.key.computationId.toLong(), - SETUP_PHASE.toProtocolStage() + ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage() ) } @@ -769,8 +1096,8 @@ class HeraldTest { } fakeComputationDatabase.addComputation( globalId = computation.key.computationId, - stage = INITIALIZATION_PHASE.toProtocolStage(), - computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + stage = LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), + computationDetails = LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newInputBlobMetadata(0L, "local-copy-of-sketches")) ) @@ -786,22 +1113,23 @@ class HeraldTest { ) .containsExactly( computation.key.computationId.toLong(), - INITIALIZATION_PHASE.toProtocolStage() + LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage() ) // Update the state. fakeComputationDatabase.remove(computation.key.computationId.toLong()) fakeComputationDatabase.addComputation( globalId = computation.key.computationId, - stage = WAIT_TO_START.toProtocolStage(), - computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + stage = LiquidLegionsSketchAggregationV2.Stage.WAIT_TO_START.toProtocolStage(), + computationDetails = LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newPassThroughBlobMetadata(0L, "local-copy-of-sketches")) ) // Verify that next attempt succeeds. syncResult.await() val finalComputation = assertNotNull(fakeComputationDatabase[computation.key.computationId.toLong()]) - assertThat(finalComputation.computationStage).isEqualTo(SETUP_PHASE.toProtocolStage()) + assertThat(finalComputation.computationStage) + .isEqualTo(LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE.toProtocolStage()) } @Test @@ -825,8 +1153,8 @@ class HeraldTest { fakeComputationDatabase.addComputation( globalId = computation.key.computationId, - stage = INITIALIZATION_PHASE.toProtocolStage(), - computationDetails = NON_AGGREGATOR_COMPUTATION_DETAILS, + stage = LiquidLegionsSketchAggregationV2.Stage.INITIALIZATION_PHASE.toProtocolStage(), + computationDetails = LLV2_NON_AGGREGATOR_COMPUTATION_DETAILS, blobs = listOf(newInputBlobMetadata(0L, "local-copy-of-sketches")) ) @@ -952,7 +1280,7 @@ class HeraldTest { token = computationToken { globalComputationId = request.globalComputationId localComputationId = request.globalComputationId.toLong() - computationDetails = AGGREGATOR_COMPUTATION_DETAILS + computationDetails = LLV2_AGGREGATOR_COMPUTATION_DETAILS } } } @@ -1015,6 +1343,7 @@ class HeraldTest { systemApiRequisitions: List = listOf(), systemComputationParticipant: List = listOf(), serializedMeasurementSpec: ByteString = SERIALIZED_MEASUREMENT_SPEC, + mpcProtocolConfig: MpcProtocolConfig = LLV2_MPC_PROTOCOL_CONFIG ): Computation { return computation { name = ComputationKey(globalId).toName() @@ -1023,7 +1352,7 @@ class HeraldTest { state = stateAtKingdom requisitions += systemApiRequisitions computationParticipants += systemComputationParticipant - mpcProtocolConfig = MPC_PROTOCOL_CONFIG + this.mpcProtocolConfig = mpcProtocolConfig } }