From ff5757c98151c57d6c304c1167c9ac317f9ae597 Mon Sep 17 00:00:00 2001 From: Sanjay Vasandani Date: Mon, 26 Feb 2024 13:07:14 -0800 Subject: [PATCH 1/4] Fix simulator service account not being created for GKE. (#1501) This addresses an issue that was missed in #1324. --- src/main/k8s/dev/BUILD.bazel | 6 +++--- src/main/k8s/dev/bigquery_edp_simulator_gke.cue | 11 ----------- src/main/k8s/dev/edp_simulator_gke.cue | 11 ++++++++++- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/main/k8s/dev/BUILD.bazel b/src/main/k8s/dev/BUILD.bazel index beb7ba11b9d..bac8ecf9606 100644 --- a/src/main/k8s/dev/BUILD.bazel +++ b/src/main/k8s/dev/BUILD.bazel @@ -1,3 +1,4 @@ +load("@wfa_common_jvm//build:defs.bzl", "expand_template") load("@wfa_rules_cue//cue:defs.bzl", "cue_library") load( "//build:variables.bzl", @@ -8,9 +9,8 @@ load( "KINGDOM_K8S_SETTINGS", "SIMULATOR_K8S_SETTINGS", ) -load("@wfa_common_jvm//build:defs.bzl", "expand_template") -load("//src/main/k8s:macros.bzl", "cue_dump") load("//build/k8s:defs.bzl", "kustomization_dir") +load("//src/main/k8s:macros.bzl", "cue_dump") SECRET_NAME = "certs-and-configs" @@ -339,13 +339,13 @@ EDP_SIMULATOR_TAGS = { "image_tag": IMAGE_REPOSITORY_SETTINGS.image_tag, "kingdom_public_api_target": KINGDOM_K8S_SETTINGS.public_api_target, "duchy_public_api_target": DUCHY_K8S_SETTINGS.public_api_target, + "google_cloud_project": GCLOUD_SETTINGS.project, } cue_dump( name = "bigquery_edp_simulator_gke", srcs = ["bigquery_edp_simulator_gke.cue"], cue_tags = dict(EDP_SIMULATOR_TAGS.items() + { - "google_cloud_project": GCLOUD_SETTINGS.project, "bigquery_dataset": SIMULATOR_K8S_SETTINGS.bigquery_dataset, "bigquery_table": SIMULATOR_K8S_SETTINGS.bigquery_table, }.items()), diff --git a/src/main/k8s/dev/bigquery_edp_simulator_gke.cue b/src/main/k8s/dev/bigquery_edp_simulator_gke.cue index 6ac980c8608..4490e475c1d 100644 --- a/src/main/k8s/dev/bigquery_edp_simulator_gke.cue +++ b/src/main/k8s/dev/bigquery_edp_simulator_gke.cue @@ -14,8 +14,6 @@ package k8s -#SimulatorServiceAccount: "simulator" - _bigQueryConfig: #BigQueryConfig & { dataset: string @tag("bigquery_dataset") table: string @tag("bigquery_table") @@ -41,16 +39,7 @@ edp_simulators: { _container: { resources: _resourceRequirements } - spec: template: spec: #ServiceAccountPodSpec & { - serviceAccountName: #SimulatorServiceAccount - } } } } } - -serviceAccounts: { - "\(#SimulatorServiceAccount)": #WorkloadIdentityServiceAccount & { - _iamServiceAccountName: "simulator" - } -} diff --git a/src/main/k8s/dev/edp_simulator_gke.cue b/src/main/k8s/dev/edp_simulator_gke.cue index 77459e879de..752a563c97c 100644 --- a/src/main/k8s/dev/edp_simulator_gke.cue +++ b/src/main/k8s/dev/edp_simulator_gke.cue @@ -33,6 +33,8 @@ _secret_name: string @tag("secret_name") _kingdomPublicApiTarget: string @tag("kingdom_public_api_target") _duchyPublicApiTarget: string @tag("duchy_public_api_target") +#SimulatorServiceAccount: "simulator" + objectSets: [ serviceAccounts, configMaps, @@ -62,7 +64,9 @@ edp_simulators: { _mc_resource_name: _mc_name deployment: { - spec: template: spec: #SpotVmPodSpec + spec: template: spec: #SpotVmPodSpec & #ServiceAccountPodSpec & { + serviceAccountName: #SimulatorServiceAccount + } } } } @@ -71,6 +75,11 @@ edp_simulators: { serviceAccounts: [Name=string]: #ServiceAccount & { metadata: name: Name } +serviceAccounts: { + "\(#SimulatorServiceAccount)": #WorkloadIdentityServiceAccount & { + _iamServiceAccountName: "simulator" + } +} configMaps: [Name=string]: #ConfigMap & { metadata: name: Name From 05abea3bf89b0347362fa070cfb722fcf87657c8 Mon Sep 17 00:00:00 2001 From: Rieman Date: Mon, 26 Feb 2024 15:09:56 -0800 Subject: [PATCH 2/4] Output uniformly random guess at frequency distribution when reach is too small (#1498) The variance calculation of frequency distribution will output `NaN` when reach is zero. Moreover, the estimated variance is not accurate when reach is impractically small. The solution is to check whether the reach is too small using its confidence interval. If the confidence interval of the reach contains values <= 0, we claim the reach is too small for an accurate variance estimate of frequency distribution, and output the variance of uniformly random draw from [0, 1]. --- .../stats/LiquidLegions.kt | 24 +- .../stats/MeasurementStatistics.kt | 38 + .../measurementconsumer/stats/Variances.kt | 74 +- .../stats/VariancesTest.kt | 814 +++++++++++------- 4 files changed, 601 insertions(+), 349 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/LiquidLegions.kt b/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/LiquidLegions.kt index 2e298ee08eb..c65f32cc9f6 100644 --- a/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/LiquidLegions.kt +++ b/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/LiquidLegions.kt @@ -253,11 +253,16 @@ object LiquidLegions { sketchParams: LiquidLegionsSketchParams, collisionResolution: Boolean, frequencyNoiseVariance: Double, - totalReach: Long, - reachRatio: Double, - frequencyMeasurementParams: FrequencyMeasurementParams, - multiplier: Int, + relativeFrequencyMeasurementVarianceParams: RelativeFrequencyMeasurementVarianceParams, ): Double { + val ( + totalReach: Long, + reachMeasurementVariance: Double, + reachRatio: Double, + frequencyMeasurementParams: FrequencyMeasurementParams, + multiplier: Int) = + relativeFrequencyMeasurementVarianceParams + val expectedRegisterNum = expectedNumberOfNonDestroyedRegisters( sketchParams, @@ -265,8 +270,15 @@ object LiquidLegions { totalReach, frequencyMeasurementParams.vidSamplingInterval.width, ) - if (expectedRegisterNum < 1.0) { - return 0.0 + + // When reach is too small, we have little info to estimate frequency, and thus the estimate of + // relative frequency is equivalent to a uniformly random guess of a probability in [0, 1]. + if ( + isReachTooSmallForComputingRelativeFrequencyVariance(totalReach, reachMeasurementVariance) || + expectedRegisterNum < 1.0 + ) { + return if (frequencyMeasurementParams.maximumFrequency == multiplier) 0.0 + else VARIANCE_OF_UNIFORMLY_RANDOM_PROBABILITY } val registerNumVariance = diff --git a/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/MeasurementStatistics.kt b/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/MeasurementStatistics.kt index 6bdd535b618..64cc1fbfff9 100644 --- a/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/MeasurementStatistics.kt +++ b/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/MeasurementStatistics.kt @@ -16,6 +16,7 @@ package org.wfanet.measurement.measurementconsumer.stats +import kotlin.math.sqrt import org.wfanet.measurement.eventdataprovider.noiser.DpParams /** Noise mechanism enums. */ @@ -72,6 +73,43 @@ data class FrequencyMeasurementVarianceParams( val measurementParams: FrequencyMeasurementParams, ) +/** + * The parameters used to compute the variance of a reach ratio at a certain frequency in a relative + * frequency measurement. + */ +data class RelativeFrequencyMeasurementVarianceParams( + val totalReach: Long, + val reachMeasurementVariance: Double, + val reachRatio: Double, + val measurementParams: FrequencyMeasurementParams, + val multiplier: Int, +) + +/** + * A reach result is considered too small when computing variances of relative frequency if the 95% + * confidence interval of the reach covers 0 or negative values. The 95% confidence interval = + * reach_result +/- 1.96 * reach_std. + */ +private const val REACH_THRESHOLD_CONSTANT_FOR_RELATIVE_FREQUENCY_VARIANCE = 1.96 + +/** + * A uniformly random number from [0, 1] has a variance equal to 1 / 12 + * (en.wikipedia.org/wiki/Continuous_uniform_distribution). + */ +const val VARIANCE_OF_UNIFORMLY_RANDOM_PROBABILITY = 1.0 / 12.0 + +/** Determines if a reach is too small for computing relative frequency variance. */ +fun isReachTooSmallForComputingRelativeFrequencyVariance( + reach: Long, + reachVariance: Double, +): Boolean { + // A reach result is considered too small for computing variances of relative frequency if the + // confidence interval lower bound of the reach <= 0. + val reachConfidenceIntervalLowerBound = + reach - REACH_THRESHOLD_CONSTANT_FOR_RELATIVE_FREQUENCY_VARIANCE * sqrt(reachVariance) + return reachConfidenceIntervalLowerBound <= 0 +} + /** The parameters used to compute the variance of an impression measurement. */ data class ImpressionMeasurementVarianceParams( val impression: Long, diff --git a/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/Variances.kt b/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/Variances.kt index 455a76731ce..d0aca6c4db9 100644 --- a/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/Variances.kt +++ b/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/Variances.kt @@ -146,11 +146,26 @@ object VariancesImpl : Variances { * Different types of frequency histograms have different values of [multiplier]. */ private fun deterministicFrequencyRelativeVariance( - totalReach: Long, - reachRatio: Double, - measurementParams: FrequencyMeasurementParams, - multiplier: Int, + relativeFrequencyMeasurementVarianceParams: RelativeFrequencyMeasurementVarianceParams ): Double { + + val ( + totalReach: Long, + reachMeasurementVariance: Double, + reachRatio: Double, + measurementParams: FrequencyMeasurementParams, + multiplier: Int) = + relativeFrequencyMeasurementVarianceParams + + // When reach is too small, we have little info to estimate frequency, and thus the estimate of + // relative frequency is equivalent to a uniformly random guess at probability. + if ( + isReachTooSmallForComputingRelativeFrequencyVariance(totalReach, reachMeasurementVariance) + ) { + return if (measurementParams.maximumFrequency == multiplier) 0.0 + else VARIANCE_OF_UNIFORMLY_RANDOM_PROBABILITY + } + val frequencyNoiseVariance: Double = computeNoiseVariance(measurementParams.dpParams, measurementParams.noiseMechanism) val varPart1 = @@ -257,22 +272,16 @@ object VariancesImpl : Variances { sketchParams: LiquidLegionsSketchParams, measurementParams: FrequencyMeasurementParams, ): ( - totalReach: Long, - reachRatio: Double, - measurementParams: FrequencyMeasurementParams, - multiplier: Int, + relativeFrequencyMeasurementVarianceParams: RelativeFrequencyMeasurementVarianceParams ) -> Double { val frequencyNoiseVariance: Double = computeNoiseVariance(measurementParams.dpParams, measurementParams.noiseMechanism) - return { totalReach, reachRatio, freqParams, multiplier -> + return { relativeFrequencyMeasurementVarianceParams -> LiquidLegions.liquidLegionsFrequencyRelativeVariance( sketchParams = sketchParams, collisionResolution = true, frequencyNoiseVariance = frequencyNoiseVariance, - totalReach = totalReach, - reachRatio = reachRatio, - frequencyMeasurementParams = freqParams, - multiplier = multiplier, + relativeFrequencyMeasurementVarianceParams = relativeFrequencyMeasurementVarianceParams, ) } } @@ -326,23 +335,17 @@ object VariancesImpl : Variances { sketchParams: LiquidLegionsSketchParams, measurementParams: FrequencyMeasurementParams, ): ( - totalReach: Long, - reachRatio: Double, - measurementParams: FrequencyMeasurementParams, - multiplier: Int, + relativeFrequencyMeasurementVarianceParams: RelativeFrequencyMeasurementVarianceParams ) -> Double { val frequencyNoiseVariance: Double = computeDistributedNoiseVariance(measurementParams.dpParams, measurementParams.noiseMechanism) - return { totalReach, reachRatio, freqParams, multiplier -> + return { relativeFrequencyMeasurementVarianceParams -> LiquidLegions.liquidLegionsFrequencyRelativeVariance( sketchParams = sketchParams, collisionResolution = false, frequencyNoiseVariance = frequencyNoiseVariance, - totalReach = totalReach, - reachRatio = reachRatio, - frequencyMeasurementParams = freqParams, - multiplier = multiplier, + relativeFrequencyMeasurementVarianceParams = relativeFrequencyMeasurementVarianceParams, ) } } @@ -387,10 +390,7 @@ object VariancesImpl : Variances { params: FrequencyMeasurementVarianceParams, frequencyRelativeVarianceFun: ( - totalReach: Long, - reachRatio: Double, - measurementParams: FrequencyMeasurementParams, - multiplier: Int, + relativeFrequencyMeasurementVarianceParams: RelativeFrequencyMeasurementVarianceParams ) -> Double, frequencyCountVarianceFun: ( @@ -415,20 +415,26 @@ object VariancesImpl : Variances { val relativeVariances: Map = (1..maximumFrequency).associateWith { frequency -> frequencyRelativeVarianceFun( - params.totalReach, - params.relativeFrequencyDistribution.getOrDefault(frequency, 0.0), - params.measurementParams, - 1, + RelativeFrequencyMeasurementVarianceParams( + params.totalReach, + params.reachMeasurementVariance, + params.relativeFrequencyDistribution.getOrDefault(frequency, 0.0), + params.measurementParams, + 1, + ) ) } val kPlusRelativeVariances: Map = (1..maximumFrequency).associateWith { frequency -> frequencyRelativeVarianceFun( - params.totalReach, - kPlusRelativeFrequencyDistribution.getValue(frequency), - params.measurementParams, - maximumFrequency - frequency + 1, + RelativeFrequencyMeasurementVarianceParams( + params.totalReach, + params.reachMeasurementVariance, + kPlusRelativeFrequencyDistribution.getValue(frequency), + params.measurementParams, + maximumFrequency - frequency + 1, + ) ) } diff --git a/src/test/kotlin/org/wfanet/measurement/measurementconsumer/stats/VariancesTest.kt b/src/test/kotlin/org/wfanet/measurement/measurementconsumer/stats/VariancesTest.kt index 6f00bc0cec4..5baab270000 100644 --- a/src/test/kotlin/org/wfanet/measurement/measurementconsumer/stats/VariancesTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/measurementconsumer/stats/VariancesTest.kt @@ -452,9 +452,9 @@ class VariancesTest { @Test fun `computeMeasurementVariance returns for deterministic reach-frequency when total reach is small and sampling width is small`() { - val vidSamplingIntervalWidth = 1e-4 - val totalReach = 1L - val reachDpParams = DpParams(0.05, 1e-15) + val vidSamplingIntervalWidth = 5e-2 + val totalReach = 1000L + val reachDpParams = DpParams(0.5, 1e-15) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -495,29 +495,36 @@ class VariancesTest { ) val expectedRK = - listOf(130523240799.76, 110944754739.79, 104418592319.84, 110944753539.91, 130523238400.0) + listOf( + 0.5270081502877656, + 0.4480709277446008, + 0.4209985202302125, + 0.4457909277446008, + 0.5224481502877656, + ) val expectedRKPlus = - listOf(0.0, 130523240799.75995, 215363345459.78998, 215363344259.90997, 130523238400.0) + listOf(0.0, 0.5270081502877656, 0.8660294479748132, 0.8637494479748131, 0.5224481502877656) val expectedNK = listOf( - 2.5828737279268425e+23, - 2.195442669924104e+23, - 2.06629897600801e+23, - 2.1954426461785614e+23, - 2.582873680435757e+23, + 599711.6131995119, + 505012.9301397436, + 469784.2425013215, + 494025.5502842458, + 577736.8534885163, ) val expectedNKPlus = listOf( - 1978861168399.0, - 2.5828737279307992e+23, - 4.261741614272709e+23, - 4.2617415905271664e+23, - 2.582873680435757e+23, + 105826.2014523311, + 620876.8534899781, + 967202.4129305566, + 956215.0330750588, + 577736.8534885163, ) for (frequency in 1..maximumFrequency) { assertThat(rKVars.getValue(frequency)) .isWithin(computeErrorTolerance(rKVars.getValue(frequency), expectedRK[frequency - 1])) + .of(expectedRK[frequency - 1]) } for (frequency in 1..maximumFrequency) { assertThat(rKPlusVars.getValue(frequency)) @@ -543,8 +550,8 @@ class VariancesTest { @Test fun `computeMeasurementVariance returns for deterministic reach-frequency when total reach is small and sampling width is large`() { val vidSamplingIntervalWidth = 0.9 - val totalReach = 10L - val reachDpParams = DpParams(0.05, 1e-15) + val totalReach = 1000L + val reachDpParams = DpParams(0.5, 1e-15) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -586,29 +593,35 @@ class VariancesTest { val expectedRK = listOf( - 16.116646716049377, - 13.69921637530864, - 12.89296181728395, - 13.697883041975308, - 16.113980049382715, + 0.0016391609576782886, + 0.0013939534806932123, + 0.0013077732105870755, + 0.0013806201473598788, + 0.001612494291011622, ) val expectedRKPlus = - listOf(0.0, 16.116646716049377, 26.590400414814813, 26.58906708148148, 16.113980049382715) + listOf( + 0.0, + 0.0016391609576782882, + 0.0026839489135025095, + 0.002670615580169177, + 0.001612494291011622, + ) val expectedNK = listOf( - 399274.49027152435, - 338261.103357843, - 317260.89827872894, - 336273.87503418204, - 395300.03362420265, + 1700.4372667720345, + 1428.6003082911895, + 1323.4327071114915, + 1384.9344632329405, + 1613.1055766555369, ) val expectedNKPlus = listOf( - 24431.495782716047, - 404160.78942806745, - 654501.1302572051, - 652513.9019335442, - 395300.03362420265, + 379.0932143590467, + 1776.2559096438433, + 2719.0847696156193, + 2675.418924557371, + 1613.1055766555369, ) for (frequency in 1..maximumFrequency) { @@ -843,11 +856,114 @@ class VariancesTest { } } + @Test + fun `computeMeasurementVariance returns for deterministic reach-frequency when total reach is too small`() { + val vidSamplingIntervalWidth = 5e-2 + val totalReach = 1L + val reachDpParams = DpParams(0.5, 1e-15) + val reachMeasurementParams = + ReachMeasurementParams( + VidSamplingInterval(0.0, vidSamplingIntervalWidth), + reachDpParams, + NoiseMechanism.GAUSSIAN, + ) + val reachMeasurementVarianceParams = + ReachMeasurementVarianceParams(totalReach, reachMeasurementParams) + val reachMeasurementVariance = + VariancesImpl.computeMeasurementVariance( + DeterministicMethodology, + reachMeasurementVarianceParams, + ) + + val maximumFrequency = 5 + val relativeFrequencyDistribution = + (1..maximumFrequency).associateWith { (maximumFrequency - it) / 10.0 } + val frequencyDpParams = DpParams(0.2, 1e-15) + val frequencyMeasurementParams = + FrequencyMeasurementParams( + VidSamplingInterval(0.0, vidSamplingIntervalWidth), + frequencyDpParams, + NoiseMechanism.GAUSSIAN, + maximumFrequency, + ) + val frequencyMeasurementVarianceParams = + FrequencyMeasurementVarianceParams( + totalReach, + reachMeasurementVariance, + relativeFrequencyDistribution, + frequencyMeasurementParams, + ) + + val (rKVars, rKPlusVars, nKVars, nKPlusVars) = + VariancesImpl.computeMeasurementVariance( + DeterministicMethodology, + frequencyMeasurementVarianceParams, + ) + + val expectedRK = + listOf( + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + ) + val expectedRKPlus = + listOf( + 0.0, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + ) + val expectedNK = + listOf( + 21132.41568673391, + 15053.25158507073, + 10710.991512454175, + 8105.63546888424, + 7237.183454360929, + ) + val expectedNKPlus = + listOf( + 86845.20145233112, + 38501.45597720013, + 15053.251585070733, + 8105.63546888424, + 7237.183454360929, + ) + + for (frequency in 1..maximumFrequency) { + assertThat(rKVars.getValue(frequency)) + .isWithin(computeErrorTolerance(rKVars.getValue(frequency), expectedRK[frequency - 1])) + .of(expectedRK[frequency - 1]) + } + for (frequency in 1..maximumFrequency) { + assertThat(rKPlusVars.getValue(frequency)) + .isWithin( + computeErrorTolerance(rKPlusVars.getValue(frequency), expectedRKPlus[frequency - 1]) + ) + .of(expectedRKPlus[frequency - 1]) + } + for (frequency in 1..maximumFrequency) { + assertThat(nKVars.getValue(frequency)) + .isWithin(computeErrorTolerance(nKVars.getValue(frequency), expectedNK[frequency - 1])) + .of(expectedNK[frequency - 1]) + } + for (frequency in 1..maximumFrequency) { + assertThat(nKPlusVars.getValue(frequency)) + .isWithin( + computeErrorTolerance(nKPlusVars.getValue(frequency), expectedNKPlus[frequency - 1]) + ) + .of(expectedNKPlus[frequency - 1]) + } + } + @Test fun `computeMeasurementVariance returns for deterministic reach-frequency when maximum frequency is 1`() { - val vidSamplingIntervalWidth = 1e-3 - val totalReach = 100L - val reachDpParams = DpParams(0.05, 1e-15) + val vidSamplingIntervalWidth = 0.9 + val totalReach = 1000L + val reachDpParams = DpParams(0.5, 1e-15) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -888,8 +1004,8 @@ class VariancesTest { val expectedRK = 0.0 val expectedRKPlus = 0.0 - val expectedNK = 19788711484.000004 - val expectedNKPlus = 19788711484.000004 + val expectedNK = 379.0932143590467 + val expectedNKPlus = 379.0932143590467 assertThat(rKVars.getValue(1)) .isWithin(computeErrorTolerance(rKVars.getValue(1), expectedRK)) @@ -1401,11 +1517,11 @@ class VariancesTest { @Test fun `computeMeasurementVariance returns for LiquidLegionsSketch reach-frequency when small total reach, small sampling width, and small decay rate`() { - val decayRate = 1e-3 + val decayRate = 1e-2 val sketchSize = 100000L - val vidSamplingIntervalWidth = 1e-2 - val totalReach = 10L - val reachDpParams = DpParams(0.1, 1e-9) + val vidSamplingIntervalWidth = 1e-1 + val totalReach = 1000L + val reachDpParams = DpParams(0.5, 1e-9) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -1445,12 +1561,38 @@ class VariancesTest { frequencyMeasurementVarianceParams, ) - val expectedRK = listOf(0.0, 0.0, 0.0, 0.0, 0.0) - val expectedRKPlus = listOf(0.0, 0.0, 0.0, 0.0, 0.0) + val expectedRK = + listOf( + 0.03260806784967439, + 0.027770887646892062, + 0.02579829441483908, + 0.026690288153515446, + 0.030446868862921157, + ) + val expectedRKPlus = + listOf( + 0.0, + 0.032608067849674405, + 0.05212838273722897, + 0.05104778324385237, + 0.030446868862921157, + ) val expectedNK = - listOf(4033588.314191154, 2268893.4267325234, 1008397.0785477886, 252099.26963694714, 0.0) + listOf( + 36541.340051824605, + 30175.900914382, + 27141.9691905643, + 27439.544880371428, + 31068.62798380339, + ) val expectedNKPlus = - listOf(25209926.963694707, 9075573.706930095, 2268893.426732524, 252099.26963694714, 0.0) + listOf( + 20421.118627387797, + 40625.56377730211, + 55030.80330142469, + 52294.447267414136, + 31068.62798380339, + ) for (frequency in 1..maximumFrequency) { assertThat(rKVars.getValue(frequency)) @@ -1482,9 +1624,9 @@ class VariancesTest { fun `computeMeasurementVariance returns for LiquidLegionsSketch reach-frequency when small total reach, small sampling width, and large decay rate`() { val decayRate = 100.0 val sketchSize = 100000L - val vidSamplingIntervalWidth = 1e-2 - val totalReach = 10L - val reachDpParams = DpParams(0.1, 1e-9) + val vidSamplingIntervalWidth = 1e-1 + val totalReach = 1000L + val reachDpParams = DpParams(0.5, 1e-9) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -1524,12 +1666,38 @@ class VariancesTest { frequencyMeasurementVarianceParams, ) - val expectedRK = listOf(0.0, 0.0, 0.0, 0.0, 0.0) - val expectedRKPlus = listOf(0.0, 0.0, 0.0, 0.0, 0.0) + val expectedRK = + listOf( + 0.0341805974697979, + 0.02910901042381123, + 0.02704846424526232, + 0.027998958934151177, + 0.031960494490477796, + ) + val expectedRKPlus = + listOf( + 0.0, + 0.0341805974697979, + 0.05467740601619346, + 0.05356735452653342, + 0.031960494490477796, + ) val expectedNK = - listOf(4037545.1990081007, 2271119.174442056, 1009386.2997520252, 252346.5749380063, 0.0) + listOf( + 38421.913999295386, + 31710.604101578978, + 28512.945060473943, + 28828.936875980187, + 32658.579548097718, + ) val expectedNKPlus = - listOf(25234657.493800625, 9084476.697768226, 2271119.1744420566, 252346.5749380063, 0.0) + listOf( + 21842.123182039944, + 42790.338635703316, + 57837.46774005718, + 54955.80051445836, + 32658.579548097718, + ) for (frequency in 1..maximumFrequency) { assertThat(rKVars.getValue(frequency)) @@ -1559,12 +1727,11 @@ class VariancesTest { @Test fun `computeMeasurementVariance returns for LiquidLegionsSketch reach-frequency when small total reach, large sampling width, and small decay rate`() { - val decayRate = 1e-3 + val decayRate = 1e-2 val sketchSize = 100000L - val vidSamplingIntervalWidth = 1.0 - val totalReach = 10L - val reachDpParams = DpParams(0.1, 1e-9) + val totalReach = 1000L + val reachDpParams = DpParams(0.5, 1e-9) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -1606,29 +1773,35 @@ class VariancesTest { val expectedRK = listOf( - 3.047425140662341, - 2.5903113959637496, - 2.4379399717258194, - 2.590310867948552, - 3.047424084631945, + 0.0003084195788141115, + 0.00026218666238117175, + 0.0002465755543090122, + 0.00026158625459763276, + 0.0003072187632470337, ) val expectedRKPlus = - listOf(0.0, 3.0474251406623356, 5.028250663669304, 5.028250135654106, 3.047424084631945) + listOf( + 0.0, + 0.0003084195788141115, + 0.0005079616729787987, + 0.0005073612651952599, + 0.0003072187632470337, + ) val expectedNK = listOf( - 8391.77774412893, - 7017.029139893044, - 6491.543194894815, - 6815.319909134243, - 7988.3592826113245, + 327.8586634488892, + 273.1319033763575, + 251.45587557046386, + 262.8305800311573, + 307.25601675845763, ) val expectedNKPlus = listOf( - 2521.348083089697, - 8896.047360746856, - 13407.7165659821, - 13206.007335223296, - 7988.3592826113245, + 121.26053444866557, + 352.1107703385642, + 518.9367167831224, + 508.635393437924, + 307.25601675845763, ) for (frequency in 1..maximumFrequency) { @@ -1661,10 +1834,9 @@ class VariancesTest { fun `computeMeasurementVariance returns for LiquidLegionsSketch reach-frequency when small total reach, large sampling width, and large decay rate`() { val decayRate = 100.0 val sketchSize = 100000L - val vidSamplingIntervalWidth = 1.0 - val totalReach = 10L - val reachDpParams = DpParams(0.1, 1e-9) + val totalReach = 1000L + val reachDpParams = DpParams(0.5, 1e-9) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -1706,29 +1878,35 @@ class VariancesTest { val expectedRK = listOf( - 3.062431068650644, - 2.60306772923472, - 2.449937810218249, - 2.6030413116012294, - 3.062378233383659, + 0.0005405568786010254, + 0.00046100419546627093, + 0.00042428097671869016, + 0.000430387222358283, + 0.0004793229323850494, ) val expectedRKPlus = - listOf(0.0, 3.062431068650639, 5.052970315941649, 5.052943898308155, 3.062378233383659) + listOf( + 0.0, + 0.000540556878601025, + 0.0008444625413743105, + 0.0008138455682663226, + 0.0004793229323850494, + ) val expectedNK = listOf( - 8511.182542906958, - 7117.383664496014, - 6584.86219003212, - 6913.618119515269, - 8103.65145294546, + 635.9104781849892, + 514.7338108340482, + 448.29111332634784, + 436.582385661839, + 479.6076278405448, ) val expectedNKPlus = listOf( - 2546.195484478623, - 9020.42163980267, - 13600.304826852385, - 13396.539281871635, - 8103.65145294546, + 593.9533376357285, + 754.7011457121116, + 898.4199131064815, + 820.2684879342742, + 479.6076278405448, ) for (frequency in 1..maximumFrequency) { @@ -1759,12 +1937,12 @@ class VariancesTest { @Test fun `computeMeasurementVariance returns for LiquidLegionsSketch reach-frequency when large total reach, small sampling width, and small decay rate`() { - val decayRate = 1e-3 - val sketchSize = 100000L + val decayRate = 1e-2 + val sketchSize = 3_000_000L - val vidSamplingIntervalWidth = 0.01 + val vidSamplingIntervalWidth = 0.1 val totalReach = 3e8.toLong() - val reachDpParams = DpParams(0.1, 1e-9) + val reachDpParams = DpParams(0.5, 1e-9) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -1806,35 +1984,35 @@ class VariancesTest { val expectedRK = listOf( - 2.4296711936002306e-06, - 2.1252005145602013e-06, - 1.6238436215468216e-06, - 9.256005145600891e-07, - 3.04711936000057e-08, + 7.92374327869203e-08, + 6.933190871070413e-08, + 5.282946173981984e-08, + 2.9730091874267457e-08, + 3.3799114047009024e-11, ) val expectedRKPlus = listOf( 0.0, - 2.4296711936002306e-06, - 2.1495774694402064e-06, - 9.499774694400937e-07, - 3.04711936000057e-08, + 7.92374327869203e-08, + 6.935894800194172e-08, + 2.9757131165505068e-08, + 3.3799114047009024e-11, ) val expectedNK = listOf( - 4.606184563605248e+32, - 2.5910006531914206e+32, - 1.1515754019878232e+32, - 2.8790880999445434e+31, - 8.77211314124454e+25, + 1947873584904.0, + 1097907668595.9999, + 489940606067.0, + 123972397310.87502, + 3042330.23507461, ) val expectedNKPlus = listOf( - 2.8788216360657756e+33, - 1.0363827835736802e+33, - 2.591001354960473e+32, - 2.879095117635056e+31, - 8.77211314124454e+25, + 12129632842688.0, + 4373800153436.0005, + 1097910102459.9999, + 123974831175.0, + 3042330.23507461, ) for (frequency in 1..maximumFrequency) { @@ -2077,12 +2255,12 @@ class VariancesTest { @Test fun `computeMeasurementVariance returns for LiquidLegionsSketch reach-frequency when maximum frequency is 1`() { - val decayRate = 1e-3 + val decayRate = 1e-2 val sketchSize = 100000L val vidSamplingIntervalWidth = 0.1 - val totalReach = 100L - val reachDpParams = DpParams(0.1, 1e-9) + val totalReach = 1000L + val reachDpParams = DpParams(0.5, 1e-9) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -2123,8 +2301,8 @@ class VariancesTest { val expectedRK = 0.0 val expectedRKPlus = 0.0 - val expectedNK = 253034.8083089697 - val expectedNKPlus = 253034.8083089697 + val expectedNK = 20421.118627387797 + val expectedNKPlus = 20421.118627387797 assertThat(rKVars.getValue(1)) .isWithin(computeErrorTolerance(rKVars.getValue(1), expectedRK)) @@ -2141,13 +2319,12 @@ class VariancesTest { } @Test - fun `computeMeasurementVariance returns for LiquidLegionsSketch reach-frequency when reach is less than 3`() { + fun `computeMeasurementVariance returns for LiquidLegionsSketch reach-frequency when reach is too small`() { val decayRate = 100.0 val sketchSize = 100000L - - val vidSamplingIntervalWidth = 1e-3 + val vidSamplingIntervalWidth = 1e-1 val totalReach = 1L - val reachDpParams = DpParams(0.1, 1e-9) + val reachDpParams = DpParams(0.5, 1e-9) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -2187,12 +2364,38 @@ class VariancesTest { frequencyMeasurementVarianceParams, ) - val expectedRK = listOf(0.0, 0.0, 0.0, 0.0, 0.0) - val expectedRKPlus = listOf(0.0, 0.0, 0.0, 0.0, 0.0) + val expectedRK = + listOf( + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + ) + val expectedRKPlus = + listOf( + 0.0, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + ) val expectedNK = - listOf(403738839.74081016, 227103097.35420564, 100934709.93520254, 25233677.483800635, 0.0) + listOf( + 2777.413131500171, + 1978.455244356285, + 1407.7710392535107, + 1065.360516191846, + 951.2236751712913, + ) val expectedNKPlus = - listOf(2523367748.380063, 908412389.416823, 227103097.35420576, 25233677.483800635, 0.0) + listOf( + 11413.684102055491, + 5060.14995191127, + 1978.4552443562857, + 1065.360516191846, + 951.2236751712913, + ) for (frequency in 1..maximumFrequency) { assertThat(rKVars.getValue(frequency)) @@ -2222,12 +2425,12 @@ class VariancesTest { @Test fun `computeMeasurementVariance returns for LiquidLegionsV2 reach-frequency when small total reach, small sampling width, and small decay rate`() { - val decayRate = 1e-3 + val decayRate = 1e-2 val sketchSize = 100000L - val vidSamplingIntervalWidth = 1e-2 - val totalReach = 10L - val reachDpParams = DpParams(0.1, 1e-9) + val vidSamplingIntervalWidth = 0.1 + val totalReach = 1000L + val reachDpParams = DpParams(0.5, 1e-9) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -2267,12 +2470,38 @@ class VariancesTest { frequencyMeasurementVarianceParams, ) - val expectedRK = listOf(0.0, 0.0, 0.0, 0.0, 0.0) - val expectedRKPlus = listOf(0.0, 0.0, 0.0, 0.0, 0.0) + val expectedRK = + listOf( + 0.05034342984868496, + 0.042845920790728145, + 0.039986714975769756, + 0.04176581240380985, + 0.04818321307484837, + ) + val expectedRKPlus = + listOf( + 0.0, + 0.05034342984868497, + 0.08139249125060682, + 0.08031238286368855, + 0.04818321307484837, + ) val expectedNK = - listOf(6924955.020569592, 3895287.199070394, 1731238.755142398, 432809.6887855995, 0.0) + listOf( + 55886.2748486677, + 46346.59750952415, + 42094.47733130149, + 43129.91431399983, + 49452.90845761906, + ) val expectedNKPlus = - listOf(43280968.87855994, 15581148.79628158, 3895287.199070395, 432809.6887855995, 0.0) + listOf( + 26351.405432392614, + 61156.55593514616, + 85908.92427561936, + 82692.2410800951, + 49452.90845761906, + ) for (frequency in 1..maximumFrequency) { assertThat(rKVars.getValue(frequency)) @@ -2305,9 +2534,9 @@ class VariancesTest { val decayRate = 100.0 val sketchSize = 100000L - val vidSamplingIntervalWidth = 1e-2 - val totalReach = 10L - val reachDpParams = DpParams(0.1, 1e-9) + val vidSamplingIntervalWidth = 1e-1 + val totalReach = 1000L + val reachDpParams = DpParams(0.5, 1e-9) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -2347,12 +2576,32 @@ class VariancesTest { frequencyMeasurementVarianceParams, ) - val expectedRK = listOf(0.0, 0.0, 0.0, 0.0, 0.0) - val expectedRKPlus = listOf(0.0, 0.0, 0.0, 0.0, 0.0) + val expectedRK = + listOf( + 0.05532921404123075, + 0.04708679728314121, + 0.043959556043144256, + 0.045947490321239945, + 0.05305060011742826, + ) + val expectedRKPlus = + listOf(0.0, 0.0553292140412307, 0.0895272773770838, 0.08838797041518255, 0.05305060011742826) val expectedNK = - listOf(6931747.5574386, 3899108.001059211, 1732936.88935965, 433234.2223399125, 0.0) + listOf( + 61439.77763656265, + 50977.015322569, + 46342.14128092452, + 47535.15551162926, + 54556.058014683185, + ) val expectedNKPlus = - listOf(43323422.233991235, 15596432.004236847, 3899108.001059212, 433234.2223399125, 0.0) + listOf( + 28377.77318112482, + 67115.33227278745, + 94621.86173431553, + 91180.00192337578, + 54556.058014683185, + ) for (frequency in 1..maximumFrequency) { assertThat(rKVars.getValue(frequency)) @@ -2382,12 +2631,12 @@ class VariancesTest { @Test fun `computeMeasurementVariance returns for LiquidLegionsV2 reach-frequency when small total reach, large sampling width, and small decay rate`() { - val decayRate = 1e-3 + val decayRate = 1e-2 val sketchSize = 100000L val vidSamplingIntervalWidth = 1.0 - val totalReach = 10L - val reachDpParams = DpParams(0.1, 1e-9) + val totalReach = 1000L + val reachDpParams = DpParams(0.5, 1e-9) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -2429,29 +2678,35 @@ class VariancesTest { val expectedRK = listOf( - 4.809657736908107, - 4.088209119574676, - 3.84772595911162, - 4.088208255518936, - 4.809656008796624, + 0.0004929886184465759, + 0.0004191004472982807, + 0.00039407024612424105, + 0.00041789801492445744, + 0.0004905837536989296, ) val expectedRKPlus = - listOf(0.0, 4.809657736908105, 7.935933926611972, 7.935933062556234, 4.809656008796624) + listOf( + 0.0, + 0.0004929886184465762, + 0.0008115674502574244, + 0.000810365017883601, + 0.0004905837536989296, + ) val expectedNK = listOf( - 21993.508778450716, - 18495.358884218687, - 17213.873575034922, - 18149.05285089939, - 21300.89671181211, + 522.1406505274354, + 435.52422094502253, + 401.4074465952872, + 419.79032747817297, + 490.6728635936937, ) val expectedNKPlus = listOf( - 4328.777582607534, - 22859.264294972218, - 35536.07625366837, - 35189.77022034907, - 21300.89671181211, + 181.6405335321324, + 558.4687572338153, + 828.0625118199823, + 812.3286183531291, + 490.6728635936937, ) for (frequency in 1..maximumFrequency) { @@ -2486,8 +2741,8 @@ class VariancesTest { val sketchSize = 100000L val vidSamplingIntervalWidth = 1.0 - val totalReach = 10L - val reachDpParams = DpParams(0.1, 1e-9) + val totalReach = 1000L + val reachDpParams = DpParams(0.5, 1e-9) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -2529,29 +2784,35 @@ class VariancesTest { val expectedRK = listOf( - 4.852333214777005, - 4.1244853983815455, - 3.8818550207757827, - 4.124442081959716, - 4.852246581933344, + 0.0013421266113186165, + 0.0011442914928402196, + 0.0010551206318847811, + 0.0010746140284523004, + 0.0012027716825427776, ) val expectedRKPlus = - listOf(0.0, 4.852333214777005, 8.00628266392822, 8.006239347506389, 4.852246581933344) + listOf( + 0.0, + 0.0013421266113186165, + 0.002106508838874441, + 0.0020368313744865223, + 0.0012027716825427776, + ) val expectedNK = listOf( - 22396.225589524824, - 18835.716311720287, - 17532.244237694387, - 18485.809367447113, - 21696.411700978482, + 1461.8600223995454, + 1211.9304731251468, + 1085.5879979647943, + 1082.832596918464, + 1203.6642699861634, ) val expectedNKPlus = listOf( - 4371.415731789478, - 23270.508735882726, - 36192.84567250306, - 35842.938728229885, - 21696.411700978482, + 742.1087944958126, + 1610.2817812986902, + 2174.861889114065, + 2045.7640129073934, + 1203.6642699861634, ) for (frequency in 1..maximumFrequency) { @@ -2582,12 +2843,12 @@ class VariancesTest { @Test fun `computeMeasurementVariance returns for LiquidLegionsV2 reach-frequency when large total reach, small sampling width, and small decay rate`() { - val decayRate = 1e-3 - val sketchSize = 100000L + val decayRate = 1e-2 + val sketchSize = 3_000_000L - val vidSamplingIntervalWidth = 0.01 + val vidSamplingIntervalWidth = 0.1 val totalReach = 3e8.toLong() - val reachDpParams = DpParams(0.1, 1e-9) + val reachDpParams = DpParams(0.5, 1e-9) val reachMeasurementParams = ReachMeasurementParams( VidSamplingInterval(0.0, vidSamplingIntervalWidth), @@ -2627,23 +2888,37 @@ class VariancesTest { frequencyMeasurementVarianceParams, ) - val expectedRK = listOf(0.0, 0.0, 0.0, 0.0, 0.0) - val expectedRKPlus = listOf(0.0, 0.0, 0.0, 0.0, 0.0) - val expectedNK = + val expectedRK = + listOf( + 0.0004352090325906098, + 0.0003743314887583703, + 0.000324680233771944, + 0.00028625526763133123, + 0.00025905659033653163, + ) + val expectedRKPlus = listOf( - 7.908010724473936e+32, - 4.448256032516588e+32, - 1.977002681118484e+32, - 4.94250670279621e+31, 0.0, + 0.00043520903259060994, + 0.0005815767610275956, + 0.0004935005399005565, + 0.00025905659033653163, + ) + val expectedNK = + listOf( + 41575204563504.01, + 35045372066235.004, + 29826056933018.504, + 25917259163848.25, + 23318978758726.46, ) val expectedNKPlus = listOf( - 4.942506702796209e+33, - 1.779302413006636e+33, - 4.44825603251659e+32, - 4.94250670279621e+31, - 0.0, + 14999149157200.002, + 44575034394936.01, + 53700555073216.01, + 44572442170829.375, + 23318978758726.46, ) for (frequency in 1..maximumFrequency) { @@ -2950,12 +3225,12 @@ class VariancesTest { } @Test - fun `computeMeasurementVariance returns for LiquidLegionsV2 reach-frequency when reach is less than 3`() { - val decayRate = 100.0 + fun `computeMeasurementVariance returns for LiquidLegionsV2 reach-frequency when reach is too small`() { + val decayRate = 1e-2 val sketchSize = 100000L - val vidSamplingIntervalWidth = 1e-3 - val totalReach = 1L + val vidSamplingIntervalWidth = 1e-2 + val totalReach = 10L val reachDpParams = DpParams(0.1, 1e-9) val reachMeasurementParams = ReachMeasurementParams( @@ -2996,12 +3271,38 @@ class VariancesTest { frequencyMeasurementVarianceParams, ) - val expectedRK = listOf(0.0, 0.0, 0.0, 0.0, 0.0) - val expectedRKPlus = listOf(0.0, 0.0, 0.0, 0.0, 0.0) + val expectedRK = + listOf( + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + ) + val expectedRKPlus = + listOf( + 0.0, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + ) val expectedNK = - listOf(692479821.8969592, 389519899.81703943, 173119955.4742398, 43279988.86855995, 0.0) + listOf( + 10531710.756153584, + 7502042.935890221, + 5337994.492844966, + 4039565.427017812, + 3606755.738408761, + ) val expectedNKPlus = - listOf(4327998886.855994, 1558079599.2681584, 389519899.8170396, 43279988.86855995, 0.0) + listOf( + 43280968.86090512, + 19187904.528334606, + 7502042.935890224, + 4039565.427017812, + 3606755.738408761, + ) for (frequency in 1..maximumFrequency) { assertThat(rKVars.getValue(frequency)) @@ -3330,111 +3631,6 @@ class VariancesTest { } } - @Test - fun `computeMetricVariance returns for reach-frequency`() { - val vidSamplingIntervalWidth = 1e-4 - val totalReach = 1L - val reachDpParams = DpParams(0.05, 1e-15) - val reachMeasurementParams = - ReachMeasurementParams( - VidSamplingInterval(0.0, vidSamplingIntervalWidth), - reachDpParams, - NoiseMechanism.GAUSSIAN, - ) - val reachMeasurementVarianceParams = - ReachMeasurementVarianceParams(totalReach, reachMeasurementParams) - val reachMeasurementVariance = - VariancesImpl.computeMeasurementVariance( - DeterministicMethodology, - reachMeasurementVarianceParams, - ) - - val maximumFrequency = 5 - val relativeFrequencyDistribution = - (1..maximumFrequency).associateWith { (maximumFrequency - it) / 10.0 } - val frequencyDpParams = DpParams(0.2, 1e-15) - val frequencyMeasurementParams = - FrequencyMeasurementParams( - VidSamplingInterval(0.0, vidSamplingIntervalWidth), - frequencyDpParams, - NoiseMechanism.GAUSSIAN, - maximumFrequency, - ) - val frequencyMeasurementVarianceParams = - FrequencyMeasurementVarianceParams( - totalReach, - reachMeasurementVariance, - relativeFrequencyDistribution, - frequencyMeasurementParams, - ) - - val weight = 2 - val coefficient = weight * weight.toDouble() - - val weightedFrequencyMeasurementVarianceParams = - WeightedFrequencyMeasurementVarianceParams( - binaryRepresentation = 1, - weight = weight, - measurementVarianceParams = frequencyMeasurementVarianceParams, - methodology = DeterministicMethodology, - ) - - val (rKVars, rKPlusVars, nKVars, nKPlusVars) = - VariancesImpl.computeMetricVariance( - FrequencyMetricVarianceParams(listOf(weightedFrequencyMeasurementVarianceParams)) - ) - - val expectedRK = - listOf(130523240799.76, 110944754739.79, 104418592319.84, 110944753539.91, 130523238400.0) - .map { it * coefficient } - val expectedRKPlus = - listOf(0.0, 130523240799.75995, 215363345459.78998, 215363344259.90997, 130523238400.0).map { - it * coefficient - } - val expectedNK = - listOf( - 2.5828737279268425e+23, - 2.195442669924104e+23, - 2.06629897600801e+23, - 2.1954426461785614e+23, - 2.582873680435757e+23, - ) - .map { it * coefficient } - val expectedNKPlus = - listOf( - 1978861168399.0, - 2.5828737279307992e+23, - 4.261741614272709e+23, - 4.2617415905271664e+23, - 2.582873680435757e+23, - ) - .map { it * coefficient } - - for (frequency in 1..maximumFrequency) { - assertThat(rKVars.getValue(frequency)) - .isWithin(computeErrorTolerance(rKVars.getValue(frequency), expectedRK[frequency - 1])) - } - for (frequency in 1..maximumFrequency) { - assertThat(rKPlusVars.getValue(frequency)) - .isWithin( - computeErrorTolerance(rKPlusVars.getValue(frequency), expectedRKPlus[frequency - 1]) - ) - .of(expectedRKPlus[frequency - 1]) - } - for (frequency in 1..maximumFrequency) { - assertThat(nKVars.getValue(frequency)) - .isWithin(computeErrorTolerance(nKVars.getValue(frequency), expectedNK[frequency - 1])) - .of(expectedNK[frequency - 1]) - } - for (frequency in 1..maximumFrequency) { - assertThat(nKPlusVars.getValue(frequency)) - .isWithin( - computeErrorTolerance(nKPlusVars.getValue(frequency), expectedNKPlus[frequency - 1]) - ) - .of(expectedNKPlus[frequency - 1]) - } - } - @Test fun `computeMetricVariance for reach-frequency throws IllegalArgumentException when no measurement params`() { assertFailsWith { From d69e6a816a153833ea8a37aacd2794492b38324e Mon Sep 17 00:00:00 2001 From: Sanjay Vasandani Date: Mon, 26 Feb 2024 17:56:18 -0800 Subject: [PATCH 3/4] Specify auto_minor_version_upgrade = false for AWS Postgres DB. (#1504) This avoids issues where the engine version is upgraded outside of Terraform, causing future Terraforming to fail unless engine_version is updated. --- src/main/terraform/aws/modules/rds-postgres/main.tf | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/main/terraform/aws/modules/rds-postgres/main.tf b/src/main/terraform/aws/modules/rds-postgres/main.tf index ffaee7322c8..c928eac22e0 100644 --- a/src/main/terraform/aws/modules/rds-postgres/main.tf +++ b/src/main/terraform/aws/modules/rds-postgres/main.tf @@ -39,9 +39,10 @@ module "db" { identifier = var.name # All available versions: https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/CHAP_PostgreSQL.html#PostgreSQL.Concepts - engine = "postgres" - engine_version = "15.4" - instance_class = var.instance_class + engine = "postgres" + engine_version = "15.5" + auto_minor_version_upgrade = false + instance_class = var.instance_class allocated_storage = 20 From 6fa65e15c7cc3802875de0eadac9c7de892fae75 Mon Sep 17 00:00:00 2001 From: renjiezh <94721804+renjiezh@users.noreply.github.com> Date: Mon, 26 Feb 2024 20:06:47 -0800 Subject: [PATCH 4/4] Update HMSS stages. Revert Duchy ControlService into Blob-Only Pattern. (#1476) --- .../HonestMajorityShareShuffleProtocol.kt | 45 ++- .../AsyncComputationControlService.kt | 79 ++--- .../computationcontrol/ProtocolStages.kt | 189 +---------- .../AdvanceComputationRequestHeaders.kt | 84 +---- .../v1alpha/ComputationControlService.kt | 44 ++- .../async_computation_control_service.proto | 2 - .../measurement/internal/duchy/crypto.proto | 8 + .../internal/duchy/protocol/BUILD.bazel | 1 + .../honest_majority_share_shuffle.proto | 84 +++-- .../v1alpha/computation_control_service.proto | 18 +- ...orityShareShuffleProtocolEnumStagesTest.kt | 24 +- .../AsyncComputationControlServiceTest.kt | 312 +++++++----------- .../HonestMajorityShareShuffleStagesTest.kt | 200 +---------- .../v1alpha/ComputationControlServiceTest.kt | 73 ++-- 14 files changed, 334 insertions(+), 829 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/HonestMajorityShareShuffleProtocol.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/HonestMajorityShareShuffleProtocol.kt index 310685fdf9a..9b39b9e2d53 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/HonestMajorityShareShuffleProtocol.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/HonestMajorityShareShuffleProtocol.kt @@ -30,7 +30,9 @@ import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.STAGE_UNSPECIFIED import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.UNRECOGNIZED import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_ON_AGGREGATION_INPUT -import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT +import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE +import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_TWO +import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage.WAIT_TO_START import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.stageDetails import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.waitOnAggregationInputDetails @@ -57,11 +59,11 @@ object HonestMajorityShareShuffleProtocol { override val validSuccessors = mapOf( - INITIALIZED to setOf(SETUP_PHASE), - // A Non-aggregator will skip WAIT_ON_SHUFFLE_INPUT into SHUFFLE_PHASE if the requisition - // data from EDPs and seed from the peer worker have been received. - SETUP_PHASE to setOf(WAIT_ON_SHUFFLE_INPUT, SHUFFLE_PHASE), - WAIT_ON_SHUFFLE_INPUT to setOf(SHUFFLE_PHASE), + INITIALIZED to setOf(WAIT_TO_START, WAIT_ON_SHUFFLE_INPUT_PHASE_ONE), + WAIT_TO_START to setOf(SETUP_PHASE), + WAIT_ON_SHUFFLE_INPUT_PHASE_ONE to setOf(SETUP_PHASE), + SETUP_PHASE to setOf(WAIT_ON_SHUFFLE_INPUT_PHASE_TWO, SHUFFLE_PHASE), + WAIT_ON_SHUFFLE_INPUT_PHASE_TWO to setOf(SHUFFLE_PHASE), WAIT_ON_AGGREGATION_INPUT to setOf(AGGREGATION_PHASE), SHUFFLE_PHASE to setOf(COMPLETE), AGGREGATION_PHASE to setOf(COMPLETE), @@ -91,10 +93,16 @@ object HonestMajorityShareShuffleProtocol { ): Boolean { return when (stage) { INITIALIZED, + WAIT_TO_START, + WAIT_ON_SHUFFLE_INPUT_PHASE_ONE, + WAIT_ON_SHUFFLE_INPUT_PHASE_TWO, SETUP_PHASE, SHUFFLE_PHASE -> details.role == RoleInComputation.NON_AGGREGATOR + WAIT_ON_AGGREGATION_INPUT, AGGREGATION_PHASE -> details.role == RoleInComputation.AGGREGATOR - else -> true /* Stage can be executed at either primary or non-primary */ + COMPLETE -> true /* Stage can be executed at either AGGREGATOR or NON_AGGREGATOR */ + STAGE_UNSPECIFIED, + UNRECOGNIZED -> error("Invalid Stage. $stage") } } @@ -106,7 +114,9 @@ object HonestMajorityShareShuffleProtocol { SETUP_PHASE, SHUFFLE_PHASE, AGGREGATION_PHASE -> AfterTransition.ADD_UNCLAIMED_TO_QUEUE - WAIT_ON_SHUFFLE_INPUT, + WAIT_TO_START, + WAIT_ON_SHUFFLE_INPUT_PHASE_ONE, + WAIT_ON_SHUFFLE_INPUT_PHASE_TWO, WAIT_ON_AGGREGATION_INPUT -> AfterTransition.DO_NOT_ADD_TO_QUEUE COMPLETE -> error("Computation should be ended with call to endComputation(...)") // Stages that we can't transition to ever. @@ -122,19 +132,18 @@ object HonestMajorityShareShuffleProtocol { ): Int { return when (stage) { SETUP_PHASE, - WAIT_ON_SHUFFLE_INPUT, - SHUFFLE_PHASE -> 0 + WAIT_TO_START -> 0 + // The output of these stages are the data received from the peer non-aggregator duchy: + WAIT_ON_SHUFFLE_INPUT_PHASE_ONE, + WAIT_ON_SHUFFLE_INPUT_PHASE_TWO, + // The output of these stages are the computed intermediate data: + SHUFFLE_PHASE, + AGGREGATION_PHASE -> 1 WAIT_ON_AGGREGATION_INPUT -> 2 - AGGREGATION_PHASE -> - // The output is the intermediate computation result either received from another duchy - // or computed locally. - 1 - // Mill have nothing to do for this stage. COMPLETE -> error("Computation should be ended with call to endComputation(...)") - // Stages that we can't transition to ever. + INITIALIZED, UNRECOGNIZED, - STAGE_UNSPECIFIED, - INITIALIZED -> error("Cannot make transition function to stage $stage") + STAGE_UNSPECIFIED -> error("Cannot make transition function to stage $stage") } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/AsyncComputationControlService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/AsyncComputationControlService.kt index c1cc46da7db..394135136d9 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/AsyncComputationControlService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/AsyncComputationControlService.kt @@ -35,7 +35,6 @@ import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoro import org.wfanet.measurement.internal.duchy.GetOutputBlobMetadataRequest import org.wfanet.measurement.internal.duchy.getComputationTokenRequest import org.wfanet.measurement.internal.duchy.recordOutputBlobPathRequest -import org.wfanet.measurement.internal.duchy.updateComputationDetailsRequest /** Implementation of the internal Async Computation Control Service. */ class AsyncComputationControlService( @@ -90,7 +89,7 @@ class AsyncComputationControlService( .asRuntimeException() val computationStage = token.computationStage - if (!stages.isValidStage(computationStage, request.computationStage)) { + if (computationStage != request.computationStage) { if (computationStage == stages.nextStage(request.computationStage)) { // This is technically an error, but it should be safe to treat as a no-op. logger.warning { "[id=$globalComputationId]: Computation stage has already been advanced" } @@ -103,57 +102,43 @@ class AsyncComputationControlService( .asRuntimeException() } - if (stages.expectBlob(computationStage)) { - val outputBlob = - token.blobsList.firstOrNull { - it.blobId == request.blobId && it.dependencyType == ComputationBlobDependency.OUTPUT - } ?: failGrpc(Status.FAILED_PRECONDITION) { "No output blob with ID ${request.blobId}" } - if (outputBlob.path.isNotEmpty()) { - if (outputBlob.path != request.blobPath) { - throw Status.FAILED_PRECONDITION.withDescription( - "Output blob ${outputBlob.blobId} already has a different path recorded" - ) - .asRuntimeException() - } - logger.info { - "[id=$globalComputationId]: Path already recorded for output blob ${outputBlob.blobId}" - } - } else { - val response = - try { - computationsClient.recordOutputBlobPath( - recordOutputBlobPathRequest { - this.token = token - outputBlobId = outputBlob.blobId - blobPath = request.blobPath - } - ) - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.ABORTED -> RetryableException(e) - else -> Status.UNKNOWN.withCause(e).asRuntimeException() - } - } - - // Computation has changed, so use the new token. - token = response.token + val outputBlob = + token.blobsList.firstOrNull { + it.blobId == request.blobId && it.dependencyType == ComputationBlobDependency.OUTPUT + } ?: failGrpc(Status.FAILED_PRECONDITION) { "No output blob with ID ${request.blobId}" } + if (outputBlob.path.isNotEmpty()) { + if (outputBlob.path != request.blobPath) { + throw Status.FAILED_PRECONDITION.withDescription( + "Output blob ${outputBlob.blobId} already has a different path recorded" + ) + .asRuntimeException() } - } - - if (stages.expectStageInput(token)) { - val computationDetails = - stages.updateComputationDetails(token.computationDetails, request.computationStageInput) + logger.info { + "[id=$globalComputationId]: Path already recorded for output blob ${outputBlob.blobId}" + } + } else { val response = - computationsClient.updateComputationDetails( - updateComputationDetailsRequest { - this.token = token - details = computationDetails + try { + computationsClient.recordOutputBlobPath( + recordOutputBlobPathRequest { + this.token = token + outputBlobId = outputBlob.blobId + blobPath = request.blobPath + } + ) + } catch (e: StatusException) { + throw when (e.status.code) { + Status.Code.ABORTED -> RetryableException(e) + else -> Status.UNKNOWN.withCause(e).asRuntimeException() } - ) + } + + // Computation has changed, so use the new token. token = response.token } - if (stages.readyForNextStage(token)) { + // Advance the computation to next stage if all blob paths are present. + if (!token.outputPathList().any(String::isEmpty)) { try { computationsClient.advanceComputationStage( computationToken = token, diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ProtocolStages.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ProtocolStages.kt index 62a0fcde7c6..7be6daa3310 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ProtocolStages.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/ProtocolStages.kt @@ -15,19 +15,14 @@ package org.wfanet.measurement.duchy.service.internal.computationcontrol import org.wfanet.measurement.duchy.db.computation.singleOutputBlobMetadata -import org.wfanet.measurement.duchy.service.internal.computations.outputPathList import org.wfanet.measurement.duchy.toProtocolStage import org.wfanet.measurement.internal.duchy.ComputationBlobDependency -import org.wfanet.measurement.internal.duchy.ComputationDetails import org.wfanet.measurement.internal.duchy.ComputationStage import org.wfanet.measurement.internal.duchy.ComputationStageBlobMetadata -import org.wfanet.measurement.internal.duchy.ComputationStageInput import org.wfanet.measurement.internal.duchy.ComputationToken -import org.wfanet.measurement.internal.duchy.copy import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 -import org.wfanet.measurement.internal.duchy.protocol.copy class IllegalStageException(val computationStage: ComputationStage, buildMessage: () -> String) : IllegalArgumentException(buildMessage()) @@ -48,29 +43,6 @@ sealed class ProtocolStages(val stageType: ComputationStage.StageCase) { */ abstract fun nextStage(stage: ComputationStage): ComputationStage - /** Returns whether the current stage is valid to process the advance request */ - abstract fun isValidStage(currentStage: ComputationStage, requestStage: ComputationStage): Boolean - - /** Returns whether the stage expects the advance request with a blob. */ - abstract fun expectBlob(stage: ComputationStage): Boolean - - /** - * Returns whether the [ComputationStage] of the [ComputationToken] expects the advance request - * with protocol specific input. - * - * If the [token] has the fields set already, return false to skip. - */ - abstract fun expectStageInput(token: ComputationToken): Boolean - - /** Returns the updated [ComputationDetails] with values in [ComputationStageInput]. */ - abstract fun updateComputationDetails( - details: ComputationDetails, - input: ComputationStageInput, - ): ComputationDetails - - /** Returns whether the [ComputationToken] is in the state to advance to the next stage. */ - abstract fun readyForNextStage(token: ComputationToken): Boolean - companion object { fun forStageType(stageType: ComputationStage.StageCase): ProtocolStages? { return when (stageType) { @@ -124,7 +96,6 @@ class LiquidLegionsV2Stages() : override fun nextStage(stage: ComputationStage): ComputationStage { require(stage.stageCase == ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2) - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enums fields cannot be null. return when (val protocolStage = stage.liquidLegionsSketchAggregationV2) { LiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS -> LiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE @@ -148,28 +119,6 @@ class LiquidLegionsV2Stages() : throw IllegalStageException(stage) { "Next $stageType stage unknown for $protocolStage" } }.toProtocolStage() } - - override fun isValidStage( - currentStage: ComputationStage, - requestStage: ComputationStage, - ): Boolean = currentStage == requestStage - - override fun expectBlob(stage: ComputationStage): Boolean = true - - override fun expectStageInput(token: ComputationToken): Boolean = false - - override fun updateComputationDetails( - details: ComputationDetails, - input: ComputationStageInput, - ): ComputationDetails = - throw IllegalStageException( - LiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED.toProtocolStage() - ) { - "Invalid $stageType to update ComputationDetails." - } - - override fun readyForNextStage(token: ComputationToken): Boolean = - !token.outputPathList().any(String::isEmpty) } /** [ProtocolStages] for the Reach-Only Liquid Legions v2 protocol. */ @@ -210,7 +159,6 @@ class ReachOnlyLiquidLegionsV2Stages() : stage.stageCase == ComputationStage.StageCase.REACH_ONLY_LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 ) - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enums fields cannot be null. return when (val protocolStage = stage.reachOnlyLiquidLegionsSketchAggregationV2) { ReachOnlyLiquidLegionsSketchAggregationV2.Stage.WAIT_SETUP_PHASE_INPUTS -> ReachOnlyLiquidLegionsSketchAggregationV2.Stage.SETUP_PHASE @@ -228,28 +176,6 @@ class ReachOnlyLiquidLegionsV2Stages() : throw IllegalStageException(stage) { "Next $stageType stage unknown for $protocolStage" } }.toProtocolStage() } - - override fun isValidStage( - currentStage: ComputationStage, - requestStage: ComputationStage, - ): Boolean = currentStage == requestStage - - override fun expectBlob(stage: ComputationStage): Boolean = true - - override fun expectStageInput(token: ComputationToken): Boolean = false - - override fun updateComputationDetails( - details: ComputationDetails, - input: ComputationStageInput, - ): ComputationDetails = - throw IllegalStageException( - ReachOnlyLiquidLegionsSketchAggregationV2.Stage.STAGE_UNSPECIFIED.toProtocolStage() - ) { - "Invalid $stageType to update ComputationDetails" - } - - override fun readyForNextStage(token: ComputationToken): Boolean = - !token.outputPathList().any(String::isEmpty) } /** [ProtocolStages] for the Honest Majority Share Shuffle protocol. */ @@ -272,9 +198,13 @@ class HonestMajorityShareShuffleStages() : it.dependencyType == ComputationBlobDependency.OUTPUT && it.blobId == blobId } } + HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE, + HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_TWO -> { + token.singleOutputBlobMetadata() + } HonestMajorityShareShuffle.Stage.INITIALIZED, + HonestMajorityShareShuffle.Stage.WAIT_TO_START, HonestMajorityShareShuffle.Stage.SETUP_PHASE, - HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT, HonestMajorityShareShuffle.Stage.SHUFFLE_PHASE, HonestMajorityShareShuffle.Stage.AGGREGATION_PHASE, HonestMajorityShareShuffle.Stage.COMPLETE, @@ -292,11 +222,14 @@ class HonestMajorityShareShuffleStages() : @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enums fields cannot be null. return when (protocolStage) { - HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT -> + HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE -> + HonestMajorityShareShuffle.Stage.SETUP_PHASE.toProtocolStage() + HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_TWO -> HonestMajorityShareShuffle.Stage.SHUFFLE_PHASE.toProtocolStage() HonestMajorityShareShuffle.Stage.WAIT_ON_AGGREGATION_INPUT -> HonestMajorityShareShuffle.Stage.AGGREGATION_PHASE.toProtocolStage() HonestMajorityShareShuffle.Stage.INITIALIZED, + HonestMajorityShareShuffle.Stage.WAIT_TO_START, HonestMajorityShareShuffle.Stage.SETUP_PHASE, HonestMajorityShareShuffle.Stage.SHUFFLE_PHASE, HonestMajorityShareShuffle.Stage.AGGREGATION_PHASE, @@ -306,108 +239,4 @@ class HonestMajorityShareShuffleStages() : throw IllegalStageException(stage) { "Next $stageType stage invalid for $protocolStage" } } } - - override fun isValidStage( - currentStage: ComputationStage, - requestStage: ComputationStage, - ): Boolean { - require(currentStage.hasHonestMajorityShareShuffle()) - require(requestStage.hasHonestMajorityShareShuffle()) - val currentProtocolStage = currentStage.honestMajorityShareShuffle - val requestProtocolStage = requestStage.honestMajorityShareShuffle - - return when (requestProtocolStage) { - HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT -> - // Non-aggregators execute SETUP phase simultaneously. It is supposed to tolerate current - // stage in INITIALIZED or SETUP_PHASE although the WAIT_ON_SHUFFLE_INPUT is the desired - // one. - setOf( - HonestMajorityShareShuffle.Stage.INITIALIZED, - HonestMajorityShareShuffle.Stage.SETUP_PHASE, - HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT, - ) - .contains(currentProtocolStage) - else -> currentProtocolStage == requestProtocolStage - } - } - - override fun expectBlob(stage: ComputationStage): Boolean { - require(stage.hasHonestMajorityShareShuffle()) - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enums fields cannot be null. - return when (stage.honestMajorityShareShuffle) { - HonestMajorityShareShuffle.Stage.WAIT_ON_AGGREGATION_INPUT -> true - HonestMajorityShareShuffle.Stage.INITIALIZED, - HonestMajorityShareShuffle.Stage.SETUP_PHASE, - HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT, - HonestMajorityShareShuffle.Stage.SHUFFLE_PHASE, - HonestMajorityShareShuffle.Stage.AGGREGATION_PHASE, - HonestMajorityShareShuffle.Stage.COMPLETE, - HonestMajorityShareShuffle.Stage.STAGE_UNSPECIFIED, - HonestMajorityShareShuffle.Stage.UNRECOGNIZED -> false - } - } - - override fun expectStageInput(token: ComputationToken): Boolean { - require(token.computationStage.hasHonestMajorityShareShuffle()) - val stage = token.computationStage - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enums fields cannot be null. - return when (stage.honestMajorityShareShuffle) { - HonestMajorityShareShuffle.Stage.INITIALIZED, - HonestMajorityShareShuffle.Stage.SETUP_PHASE, - HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT -> { - token.computationDetails.honestMajorityShareShuffle.seeds.commonRandomSeedFromPeer.isEmpty - } - HonestMajorityShareShuffle.Stage.WAIT_ON_AGGREGATION_INPUT, - HonestMajorityShareShuffle.Stage.SHUFFLE_PHASE, - HonestMajorityShareShuffle.Stage.AGGREGATION_PHASE, - HonestMajorityShareShuffle.Stage.COMPLETE, - HonestMajorityShareShuffle.Stage.STAGE_UNSPECIFIED, - HonestMajorityShareShuffle.Stage.UNRECOGNIZED -> false - } - } - - override fun updateComputationDetails( - details: ComputationDetails, - input: ComputationStageInput, - ): ComputationDetails { - require(details.hasHonestMajorityShareShuffle()) - require(!input.honestMajorityShareShuffleShufflePhaseInput.peerRandomSeed.isEmpty) - return details.copy { - honestMajorityShareShuffle = - honestMajorityShareShuffle.copy { - seeds = - seeds.copy { - commonRandomSeedFromPeer = - input.honestMajorityShareShuffleShufflePhaseInput.peerRandomSeed - } - } - } - } - - private fun ComputationToken.hasPeerSeed(): Boolean = - !computationDetails.honestMajorityShareShuffle.seeds.commonRandomSeedFromPeer.isEmpty - - private fun ComputationToken.requisitionsFulfilled(): Boolean = - requisitionsList.all { !it.secretSeedCiphertext.isEmpty && it.path.isNotBlank() } - - override fun readyForNextStage(token: ComputationToken): Boolean { - require(token.computationStage.hasHonestMajorityShareShuffle()) - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enums fields cannot be null. - return when (token.computationStage.honestMajorityShareShuffle) { - HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT -> - token.hasPeerSeed() && token.requisitionsFulfilled() - HonestMajorityShareShuffle.Stage.WAIT_ON_AGGREGATION_INPUT -> - token.outputPathList().all(String::isNotEmpty) - HonestMajorityShareShuffle.Stage.INITIALIZED, - HonestMajorityShareShuffle.Stage.SETUP_PHASE, - HonestMajorityShareShuffle.Stage.SHUFFLE_PHASE, - HonestMajorityShareShuffle.Stage.AGGREGATION_PHASE, - HonestMajorityShareShuffle.Stage.COMPLETE, - HonestMajorityShareShuffle.Stage.STAGE_UNSPECIFIED, - HonestMajorityShareShuffle.Stage.UNRECOGNIZED -> false - } - } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/AdvanceComputationRequestHeaders.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/AdvanceComputationRequestHeaders.kt index 0512c579232..bb467d3f3bb 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/AdvanceComputationRequestHeaders.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/AdvanceComputationRequestHeaders.kt @@ -14,14 +14,10 @@ package org.wfanet.measurement.duchy.service.system.v1alpha -import com.google.protobuf.ByteString import org.wfanet.measurement.common.grpc.failGrpc import org.wfanet.measurement.duchy.toProtocolStage import org.wfanet.measurement.internal.duchy.ComputationStage -import org.wfanet.measurement.internal.duchy.ComputationStageInput -import org.wfanet.measurement.internal.duchy.computationStageInput import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle as HonestMajorityShareShuffleProtocol -import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.shufflePhaseInput as internalShufflePhaseInput import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 import org.wfanet.measurement.system.v1alpha.AdvanceComputationRequest @@ -29,7 +25,6 @@ import org.wfanet.measurement.system.v1alpha.AdvanceComputationRequest.Header.Pr import org.wfanet.measurement.system.v1alpha.AdvanceComputationRequestKt import org.wfanet.measurement.system.v1alpha.ComputationKey import org.wfanet.measurement.system.v1alpha.HonestMajorityShareShuffle -import org.wfanet.measurement.system.v1alpha.HonestMajorityShareShuffleKt.shufflePhaseInput import org.wfanet.measurement.system.v1alpha.LiquidLegionsV2 import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2 import org.wfanet.measurement.system.v1alpha.honestMajorityShareShuffle @@ -56,40 +51,6 @@ fun AdvanceComputationRequest.Header.stageExpectingInput(): ComputationStage = ProtocolCase.PROTOCOL_NOT_SET -> failGrpc { "Unknown protocol $protocolCase" } } -/** Returns true if the stage expects blob input from the request. */ -fun AdvanceComputationRequest.Header.doesExpectBlobInput(): Boolean = - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (protocolCase) { - ProtocolCase.LIQUID_LEGIONS_V2, - ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> true - ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (honestMajorityShareShuffle.description) { - HonestMajorityShareShuffle.Description.SHUFFLE_PHASE_INPUT -> false - HonestMajorityShareShuffle.Description.AGGREGATION_PHASE_INPUT -> true - HonestMajorityShareShuffle.Description.DESCRIPTION_UNSPECIFIED, - HonestMajorityShareShuffle.Description.UNRECOGNIZED -> failGrpc { "Invalid description." } - } - ProtocolCase.PROTOCOL_NOT_SET -> failGrpc { "Unknown protocol $protocolCase" } - } - -/** Returns true if the stage expects protocol specific input from the header of the request. */ -fun AdvanceComputationRequest.Header.doesExpectProtocolSpecificInput(): Boolean = - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (protocolCase) { - ProtocolCase.LIQUID_LEGIONS_V2, - ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> false - ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (honestMajorityShareShuffle.description) { - HonestMajorityShareShuffle.Description.SHUFFLE_PHASE_INPUT -> true - HonestMajorityShareShuffle.Description.AGGREGATION_PHASE_INPUT -> false - HonestMajorityShareShuffle.Description.DESCRIPTION_UNSPECIFIED, - HonestMajorityShareShuffle.Description.UNRECOGNIZED -> failGrpc { "Invalid description." } - } - ProtocolCase.PROTOCOL_NOT_SET -> failGrpc { "Unknown protocol $protocolCase" } - } - private fun LiquidLegionsV2.stageExpectingInput(): ComputationStage = when (description) { LiquidLegionsV2.Description.SETUP_PHASE_INPUT -> @@ -114,35 +75,15 @@ private fun ReachOnlyLiquidLegionsV2.stageExpectingInput(): ComputationStage = private fun HonestMajorityShareShuffle.stageExpectingInput(): ComputationStage = when (description) { - HonestMajorityShareShuffle.Description.SHUFFLE_PHASE_INPUT -> - HonestMajorityShareShuffleProtocol.Stage.WAIT_ON_SHUFFLE_INPUT + HonestMajorityShareShuffle.Description.SHUFFLE_PHASE_INPUT_ONE -> + HonestMajorityShareShuffleProtocol.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE + HonestMajorityShareShuffle.Description.SHUFFLE_PHASE_INPUT_TWO -> + HonestMajorityShareShuffleProtocol.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_TWO HonestMajorityShareShuffle.Description.AGGREGATION_PHASE_INPUT -> HonestMajorityShareShuffleProtocol.Stage.WAIT_ON_AGGREGATION_INPUT else -> failGrpc { "Unknown HonestMajorityShareShuffle payload description '$description'." } }.toProtocolStage() -fun AdvanceComputationRequest.Header.computationStageInput(): ComputationStageInput { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (protocolCase) { - ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> honestMajorityShareShuffle.computationStageInput() - ProtocolCase.LIQUID_LEGIONS_V2, - ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2, - ProtocolCase.PROTOCOL_NOT_SET -> - failGrpc { "Protocol $protocolCase does not have ComputationStageInput" } - } -} - -private fun HonestMajorityShareShuffle.computationStageInput(): ComputationStageInput = - when (description) { - HonestMajorityShareShuffle.Description.SHUFFLE_PHASE_INPUT -> - computationStageInput { - honestMajorityShareShuffleShufflePhaseInput = internalShufflePhaseInput { - peerRandomSeed = shufflePhaseInput.peerRandomSeed - } - } - else -> failGrpc { "Unknown ReachOnlyLiquidLegionsV2 payload description '$description'." } - } - /** Creates an [AdvanceComputationRequest.Header] for a liquid legions v2 computation. */ fun advanceComputationHeader( liquidLegionsV2ContentDescription: LiquidLegionsV2.Description, @@ -179,20 +120,3 @@ fun advanceComputationHeader( description = honestMajorityShareShuffleContentDescription } } - -/** - * Creates an [AdvanceComputationRequest.Header] for the honest majority share shuffle computation - * with a seed. - */ -fun advanceComputationHeader( - honestMajorityShareShuffleContentDescription: HonestMajorityShareShuffle.Description, - globalComputationId: String, - seed: ByteString, -): AdvanceComputationRequest.Header = - AdvanceComputationRequestKt.header { - name = ComputationKey(globalComputationId).toName() - honestMajorityShareShuffle = honestMajorityShareShuffle { - description = honestMajorityShareShuffleContentDescription - shufflePhaseInput = shufflePhaseInput { peerRandomSeed = seed } - } - } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/ComputationControlService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/ComputationControlService.kt index f477e44e47f..73cba858eaf 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/ComputationControlService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/ComputationControlService.kt @@ -59,10 +59,13 @@ class ComputationControlService( grpcRequireNotNull(ComputationKey.fromName(header.name)?.computationId) { "Resource name unspecified or invalid." } + grpcRequire(consumed.hasRemaining) { "Request stream has no body" } if (header.isForAsyncComputation()) { - val remaining = - if (consumed.hasRemaining) consumed.remaining.map { it.bodyChunk.partialData } else null - handleAsyncRequest(header, remaining, globalComputationId) + handleAsyncRequest( + header, + consumed.remaining.map { it.bodyChunk.partialData }, + globalComputationId, + ) } else { failGrpc { "Synchronous computations are not yet supported." } } @@ -73,35 +76,26 @@ class ComputationControlService( /** Write the payload data stream as a blob, and advance the stage via the async service. */ private suspend fun handleAsyncRequest( header: AdvanceComputationRequest.Header, - content: Flow?, + content: Flow, globalId: String, ) { val stage = header.stageExpectingInput() + val blobMetadata = + asyncComputationControlClient.getOutputBlobMetadata( + getOutputBlobMetadataRequest { + globalComputationId = globalId + dataOrigin = duchyIdentityProvider().id + } + ) + val blob = + computationStore.write(ComputationBlobContext(globalId, stage, blobMetadata.blobId), content) + val request = advanceComputationRequest { globalComputationId = globalId computationStage = stage - - if (header.doesExpectBlobInput()) { - grpcRequire(content != null) { "Request stream has no body" } - val blobMetadata = - asyncComputationControlClient.getOutputBlobMetadata( - getOutputBlobMetadataRequest { - globalComputationId = globalId - dataOrigin = duchyIdentityProvider().id - } - ) - val blob = - computationStore.write( - ComputationBlobContext(globalId, stage, blobMetadata.blobId), - content!!, - ) - blobId = blobMetadata.blobId - blobPath = blob.blobKey - } - if (header.doesExpectProtocolSpecificInput()) { - computationStageInput = header.computationStageInput() - } + blobId = blobMetadata.blobId + blobPath = blob.blobKey } try { asyncComputationControlClient.advanceComputation(request) diff --git a/src/main/proto/wfa/measurement/internal/duchy/async_computation_control_service.proto b/src/main/proto/wfa/measurement/internal/duchy/async_computation_control_service.proto index be79920eca9..fa75429e637 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/async_computation_control_service.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/async_computation_control_service.proto @@ -49,8 +49,6 @@ message AdvanceComputationRequest { int64 blob_id = 3; // The path of the blob newly added. string blob_path = 4; - // The protocol specific input for a certain stage. - ComputationStageInput computation_stage_input = 5; } message AdvanceComputationResponse { diff --git a/src/main/proto/wfa/measurement/internal/duchy/crypto.proto b/src/main/proto/wfa/measurement/internal/duchy/crypto.proto index 3b4c3061b62..82e7e0e322c 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/crypto.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/crypto.proto @@ -59,3 +59,11 @@ message EncryptionPublicKey { // decrypt messages given a private key. bytes data = 2; } + +// Keys of the encryption cipher. +message EncryptionKeyPair { + // Secret key of the cipher. + bytes secret_key = 1; + // Public key of the cipher. + EncryptionPublicKey public_key = 2; +} diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/BUILD.bazel b/src/main/proto/wfa/measurement/internal/duchy/protocol/BUILD.bazel index 395b7bdae97..87b3b10727e 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/BUILD.bazel +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/BUILD.bazel @@ -139,6 +139,7 @@ proto_library( strip_import_prefix = IMPORT_PREFIX, deps = [ ":share_shuffle_sketch_params_proto", + "//src/main/proto/wfa/measurement/internal/duchy:crypto_proto", "//src/main/proto/wfa/measurement/internal/duchy:noise_mechanism_proto", "//src/main/proto/wfa/measurement/internal/duchy/config:protocols_setup_config_proto", ], diff --git a/src/main/proto/wfa/measurement/internal/duchy/protocol/honest_majority_share_shuffle.proto b/src/main/proto/wfa/measurement/internal/duchy/protocol/honest_majority_share_shuffle.proto index b9096daee17..4693ba88ae6 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/protocol/honest_majority_share_shuffle.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/protocol/honest_majority_share_shuffle.proto @@ -17,6 +17,7 @@ syntax = "proto3"; package wfa.measurement.internal.duchy.protocol; import "wfa/measurement/internal/duchy/config/protocols_setup_config.proto"; +import "wfa/measurement/internal/duchy/crypto.proto"; import "wfa/measurement/internal/duchy/noise_mechanism.proto"; import "wfa/measurement/internal/duchy/protocol/share_shuffle_sketch_params.proto"; @@ -26,11 +27,14 @@ option java_multiple_files = true; message HonestMajorityShareShuffle { // Stages of the HonestMajorityShareShuffle computation. // - // For non-aggregators, the normal stage transition is: - // INITIALIZED -> SETUP_PHASE -> (WAIT_ON_SHUFFLE_INPUT) -> SHUFFLE_PHASE + // For the first non-aggregator, the normal stage transition is: + // INITIALIZED -> WAIT_TO_START -> SETUP_PHASE -> + // WAIT_ON_SHUFFLE_INPUT_PHASE_TWO -> SHUFFLE_PHASE // -> COMPLETE - // The WAIT_ON_INPUT stage will be skipped if all the input (requisition data - // and peer worker input) has been received before the end of SETUP_PHASE. + // + // For the second non-aggregator, the normal stage transition is: + // INITIALIZED -> WAIT_ON_SHUFFLE_INPUT_PHASE_ONE -> SETUP_PHASE + // -> SHUFFLE_PHASE -> COMPLETE // // For the aggregator, the normal stage transition is: // WAIT_ON_AGGREGATION_INPUT -> AGGREGATION_PHASE -> COMPLETE @@ -38,25 +42,37 @@ message HonestMajorityShareShuffle { // The computation stage is unknown. This is never set intentionally. STAGE_UNSPECIFIED = 0; - // Computation is created by the non-aggregator. + // The computation is created by the non-aggregator populated with sampled + // randomness seed and encryption key pair. // - // Non-aggregators are ready to accept requisition fulfillments. The - // randomness seed is sampled and stored. Peer's random seed can be accepted - // at this stage. + // In this stage, non-aggregators send the participant params to the kingdom + // by the mill. INITIALIZED = 1; - // Non-aggregator reads the randomness seed, then sends it to the peer. - // Seeds are used for generating either noise or permutation. Peer's random - // seed can be accepted at this stage. - SETUP_PHASE = 2; + // The first non-aggregator waits on requisition fulfillment and the signal + // from the kingdom to proceed. + // + // In this stage, requisition fulfillment are received from EDPs and + // reported to the kingdom. + WAIT_TO_START = 2; + + // The second non-aggregator waits on requisition fulfillment and the data + // from the peer non-aggregator to proceed. + // + // In this stage, requisition fulfillment are received from EDPs and + // reported to the kingdom. + WAIT_ON_SHUFFLE_INPUT_PHASE_ONE = 3; - // The non-aggregators only stage waiting for input of seeds from the peer - // as well as requisition data from EDPs. - WAIT_ON_SHUFFLE_INPUT = 3; + // Non-aggregator sends random seed and secret seed from DataProviders to + // the peer. Seeds are used for generating noise and permutation. + SETUP_PHASE = 4; - // The aggregator only stage waiting for input of combined shares from + // The first non-aggregator waits for input from the peer non-aggregator. + WAIT_ON_SHUFFLE_INPUT_PHASE_TWO = 5; + + // A aggregator only stage waiting for input of combined shares from // non-aggregators. It is the initial stage for the aggregator. - WAIT_ON_AGGREGATION_INPUT = 4; + WAIT_ON_AGGREGATION_INPUT = 6; // Non-aggregators execute following steps: // @@ -72,16 +88,16 @@ message HonestMajorityShareShuffle { // permutation based on the combined seed. Shuffle the shares. // // 4. Send the result of the share to the aggregator. - SHUFFLE_PHASE = 5; + SHUFFLE_PHASE = 7; // The aggregator adds up shares from non-aggregators, subtracts noise // offset, and calculates the reach and the frequency histogram, then // reports the encrypted result to the kingdom. - AGGREGATION_PHASE = 6; + AGGREGATION_PHASE = 8; // The computation is completed or failed. The worker can remove BLOBs that // are no longer needed. - COMPLETE = 7; + COMPLETE = 9; } message ComputationDetails { @@ -99,26 +115,34 @@ message HonestMajorityShareShuffle { } Parameters parameters = 2; - // Seeds used by non-aggregators to generate noise and permutation. - message RandomSeeds { - // Seed to generates noise and permutation. - bytes common_random_seed = 1; - // Seed from the peer worker to generates noise and permutation. - bytes common_random_seed_from_peer = 2; - } - RandomSeeds seeds = 3; - // The list of ids of duchies participating in this computation. // The list is sorted by the duchy order by names, with the first element // being the first non-aggregator following by other non-aggregators, // and the last element being the aggregator. - repeated string participants = 4; + repeated string participants = 3; + + // Seeds used by non-aggregators to generate noise and permutation. + bytes common_random_seed = 4; + + // Encryption key pair for participant params. + EncryptionKeyPair encryption_key_pair = 5; } // The input message containing the random seed from the peer worker. message ShufflePhaseInput { // Random seed in bytes. bytes peer_random_seed = 1; + + // The requisition data in format of a secret seed. It is ciphertext + // containing a serialized `SignedMessage` message from the CMMS public API. + message SecretSeed { + string external_data_provider_id = 1; + + // Ciphertext containing a serialized `SignedMessage` message from the + // CMMS public API. + bytes secret_seed_ciphertext = 2; + } + repeated SecretSeed secret_seeds = 2; } // Details about a particular attempt of running a stage of the diff --git a/src/main/proto/wfa/measurement/system/v1alpha/computation_control_service.proto b/src/main/proto/wfa/measurement/system/v1alpha/computation_control_service.proto index f5db66733eb..6636b18a5cb 100644 --- a/src/main/proto/wfa/measurement/system/v1alpha/computation_control_service.proto +++ b/src/main/proto/wfa/measurement/system/v1alpha/computation_control_service.proto @@ -120,23 +120,15 @@ message HonestMajorityShareShuffle { enum Description { // The data type is unknown. This is never set intentionally. DESCRIPTION_UNSPECIFIED = 0; - // THe input for the shuffle phase. - SHUFFLE_PHASE_INPUT = 1; + // The input for the WAIT_ON_SHUFFLE_INPUT_PHASE_ONE. + SHUFFLE_PHASE_INPUT_ONE = 1; + // The input for the WAIT_ON_SHUFFLE_INPUT_PHASE_TWO. + SHUFFLE_PHASE_INPUT_TWO = 2; // The input for the aggregation phase. - AGGREGATION_PHASE_INPUT = 2; + AGGREGATION_PHASE_INPUT = 3; } // Payload data description Description description = 1; - - // The input for shuffle phase. - // Only set for description with SHUFFLE_PHASE_INPUT. - message ShufflePhaseInput { - // The random seed from the peer worker to generate noise share and - // permutation. - bytes peer_random_seed = 1; - } - // The input for shuffle phase. - ShufflePhaseInput shuffle_phase_input = 2; } // Response message for the `AdvanceComputation` method. diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/HonestMajorityShareShuffleProtocolEnumStagesTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/HonestMajorityShareShuffleProtocolEnumStagesTest.kt index acd8f122888..df68e8aa608 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/HonestMajorityShareShuffleProtocolEnumStagesTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/db/computation/HonestMajorityShareShuffleProtocolEnumStagesTest.kt @@ -77,11 +77,31 @@ class HonestMajorityShareShuffleProtocolEnumStagesTest { fun `verify transistions`() { assertTrue { HonestMajorityShareShuffleProtocol.EnumStages.validTransition( + HonestMajorityShareShuffle.Stage.INITIALIZED, + HonestMajorityShareShuffle.Stage.WAIT_TO_START, + ) + } + + assertTrue { + HonestMajorityShareShuffleProtocol.EnumStages.validTransition( + HonestMajorityShareShuffle.Stage.INITIALIZED, + HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE, + ) + } + + assertTrue { + HonestMajorityShareShuffleProtocol.EnumStages.validTransition( + HonestMajorityShareShuffle.Stage.WAIT_TO_START, HonestMajorityShareShuffle.Stage.SETUP_PHASE, - HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT, ) } + assertTrue { + HonestMajorityShareShuffleProtocol.EnumStages.validTransition( + HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE, + HonestMajorityShareShuffle.Stage.SETUP_PHASE, + ) + } assertTrue { HonestMajorityShareShuffleProtocol.EnumStages.validTransition( HonestMajorityShareShuffle.Stage.SETUP_PHASE, @@ -91,7 +111,7 @@ class HonestMajorityShareShuffleProtocolEnumStagesTest { assertTrue { HonestMajorityShareShuffleProtocol.EnumStages.validTransition( - HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT, + HonestMajorityShareShuffle.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_TWO, HonestMajorityShareShuffle.Stage.SHUFFLE_PHASE, ) } diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/AsyncComputationControlServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/AsyncComputationControlServiceTest.kt index 64efcb22222..9a18411c20c 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/AsyncComputationControlServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/AsyncComputationControlServiceTest.kt @@ -16,7 +16,6 @@ package org.wfanet.measurement.duchy.service.internal.computationcontrol import com.google.common.truth.Truth.assertThat import com.google.common.truth.extensions.proto.ProtoTruth.assertThat -import com.google.protobuf.kotlin.toByteStringUtf8 import io.grpc.Status import io.grpc.StatusRuntimeException import kotlin.test.assertFailsWith @@ -26,14 +25,10 @@ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.mockito.kotlin.any -import org.mockito.kotlin.never import org.mockito.kotlin.stub -import org.mockito.kotlin.verifyBlocking import org.mockito.kotlin.whenever -import org.wfanet.measurement.api.Version import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule import org.wfanet.measurement.common.grpc.testing.mockService -import org.wfanet.measurement.common.testing.verifyProtoArgument import org.wfanet.measurement.duchy.service.internal.computations.newEmptyOutputBlobMetadata import org.wfanet.measurement.duchy.service.internal.computations.newInputBlobMetadata import org.wfanet.measurement.duchy.service.internal.computations.newOutputBlobMetadata @@ -41,12 +36,10 @@ import org.wfanet.measurement.duchy.service.internal.computations.newPassThrough import org.wfanet.measurement.duchy.service.internal.computations.toAdvanceComputationStageResponse import org.wfanet.measurement.duchy.service.internal.computations.toGetComputationTokenResponse import org.wfanet.measurement.duchy.service.internal.computations.toRecordOutputBlobPathResponse -import org.wfanet.measurement.duchy.service.internal.computations.toUpdateComputationDetailsResponse import org.wfanet.measurement.duchy.toProtocolStage import org.wfanet.measurement.internal.duchy.AdvanceComputationRequest import org.wfanet.measurement.internal.duchy.AdvanceComputationStageRequest import org.wfanet.measurement.internal.duchy.AdvanceComputationStageRequest.AfterTransition -import org.wfanet.measurement.internal.duchy.AdvanceComputationStageResponse import org.wfanet.measurement.internal.duchy.ComputationBlobDependency import org.wfanet.measurement.internal.duchy.ComputationDetails import org.wfanet.measurement.internal.duchy.ComputationStageDetails @@ -59,26 +52,19 @@ import org.wfanet.measurement.internal.duchy.advanceComputationStageRequest import org.wfanet.measurement.internal.duchy.computationDetails import org.wfanet.measurement.internal.duchy.computationStageBlobMetadata import org.wfanet.measurement.internal.duchy.computationStageDetails -import org.wfanet.measurement.internal.duchy.computationStageInput import org.wfanet.measurement.internal.duchy.computationToken import org.wfanet.measurement.internal.duchy.config.RoleInComputation import org.wfanet.measurement.internal.duchy.copy -import org.wfanet.measurement.internal.duchy.getComputationTokenRequest import org.wfanet.measurement.internal.duchy.getOutputBlobMetadataRequest import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage as HmssStage import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage as Llv2Stage import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2Kt +import org.wfanet.measurement.internal.duchy.protocol.honestMajorityShareShuffle import org.wfanet.measurement.internal.duchy.recordOutputBlobPathRequest -import org.wfanet.measurement.internal.duchy.requisitionMetadata -import org.wfanet.measurement.internal.duchy.updateComputationDetailsRequest - -private val COMMON_SEED = "seed_1".toByteStringUtf8() -private val PEER_COMMON_SEED = "seed_2".toByteStringUtf8() -private val REQUISITION_PATH_1 = "path 1" -private val REQUISITION_PATH_2 = "path 2" -private val REQUISITION_SEED_1 = "encrypted seed 1".toByteStringUtf8() -private val REQUISITION_SEED_2 = "encrypted seed 2".toByteStringUtf8() + +private const val SHUFFLE_BLOB_ID = 1L +private val SHUFFLE_BLOB_PATH = "path" private const val AGGREGATION_BLOB_ID_1 = 1L private const val AGGREGATION_BLOB_ID_2 = 2L private val AGGREGATION_BLOB_PATH_1 = "path_1" @@ -409,145 +395,108 @@ class AsyncComputationControlServiceTest { assertThat(advanceComputationRequests).isEmpty() } - private suspend fun verifyHmssAdvanceComputationToShuffleStage( - initStage: HmssStage, - requisitionFulfilled: Boolean, - supposedToAdvance: Boolean, - ) { + @Test + fun `advanceComputation records blob and advance for HMSS WAIT_ON_SHUFFLE_INPUT_PHASE_ONE`(): + Unit = runBlocking { val token = computationToken { - computationStage = initStage.toProtocolStage() + computationStage = HmssStage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE.toProtocolStage() + blobs += newEmptyOutputBlobMetadata(1L) computationDetails = computationDetails { honestMajorityShareShuffle = HonestMajorityShareShuffleKt.computationDetails { role = RoleInComputation.NON_AGGREGATOR - seeds = - HonestMajorityShareShuffleKt.ComputationDetailsKt.randomSeeds { - this.commonRandomSeed = COMMON_SEED - } } } - if (requisitionFulfilled) { - requisitions += requisitionMetadata { - secretSeedCiphertext = REQUISITION_SEED_1 - path = REQUISITION_PATH_1 - publicApiVersion = Version.V2_ALPHA.string - } - requisitions += requisitionMetadata { - secretSeedCiphertext = REQUISITION_SEED_2 - path = REQUISITION_PATH_2 - publicApiVersion = Version.V2_ALPHA.string - } - } else { - requisitions += requisitionMetadata {} - requisitions += requisitionMetadata {} - } } - + val request = advanceComputationRequest { + globalComputationId = COMPUTATION_ID + computationStage = HmssStage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE.toProtocolStage() + blobId = 1L + blobPath = BLOB_KEY + } val updatedToken = token.copy { - computationDetails = computationDetails { - honestMajorityShareShuffle = - HonestMajorityShareShuffleKt.computationDetails { - role = RoleInComputation.NON_AGGREGATOR - seeds = - HonestMajorityShareShuffleKt.ComputationDetailsKt.randomSeeds { - this.commonRandomSeed = COMMON_SEED - this.commonRandomSeedFromPeer = PEER_COMMON_SEED - } - } - } + blobs.clear() + blobs += newOutputBlobMetadata(1L, BLOB_KEY) } + val (recordBlobRequests, advanceComputationRequests) = + mockComputationsServiceCalls(token, updatedToken) - mockComputationsService.stub { - onBlocking { getComputationToken(any()) }.thenReturn(token.toGetComputationTokenResponse()) - onBlocking { updateComputationDetails(any()) } - .thenReturn(updatedToken.toUpdateComputationDetailsResponse()) - onBlocking { advanceComputationStage(any()) } - .thenReturn(AdvanceComputationStageResponse.getDefaultInstance()) - } + service.advanceComputation(request) - service.advanceComputation( - advanceComputationRequest { - globalComputationId = COMPUTATION_ID - computationStage = HmssStage.WAIT_ON_SHUFFLE_INPUT.toProtocolStage() - computationStageInput = computationStageInput { - honestMajorityShareShuffleShufflePhaseInput = - HonestMajorityShareShuffleKt.shufflePhaseInput { peerRandomSeed = PEER_COMMON_SEED } + assertThat(recordBlobRequests) + .containsExactly( + recordOutputBlobPathRequest { + this.token = token + outputBlobId = 1 + blobPath = BLOB_KEY } - } - ) - - verifyProtoArgument(mockComputationsService, ComputationsCoroutineImplBase::getComputationToken) - .isEqualTo(getComputationTokenRequest { globalComputationId = COMPUTATION_ID }) - verifyProtoArgument( - mockComputationsService, - ComputationsCoroutineImplBase::updateComputationDetails, ) - .isEqualTo( - updateComputationDetailsRequest { - this.token = token - details = updatedToken.computationDetails + assertThat(advanceComputationRequests) + .containsExactly( + advanceComputationStageRequest { + this.token = updatedToken + nextComputationStage = HmssStage.SETUP_PHASE.toProtocolStage() + inputBlobs += BLOB_KEY + outputBlobs = 0 + stageDetails = ComputationStageDetails.getDefaultInstance() + afterTransition = AfterTransition.ADD_UNCLAIMED_TO_QUEUE } ) - if (supposedToAdvance) { - verifyProtoArgument( - mockComputationsService, - ComputationsCoroutineImplBase::advanceComputationStage, - ) - .isEqualTo( - advanceComputationStageRequest { - this.token = updatedToken - nextComputationStage = HmssStage.SHUFFLE_PHASE.toProtocolStage() - afterTransition = AfterTransition.ADD_UNCLAIMED_TO_QUEUE - outputBlobs = 0 - stageDetails = ComputationStageDetails.getDefaultInstance() - } - ) - } else { - verifyBlocking(mockComputationsService, never()) { advanceComputationStage(any()) } - } } @Test - fun `advanceComputation records seed and advance for HMSS WAIT_ON_SHUFFLE_INPUT`(): Unit = - runBlocking { - verifyHmssAdvanceComputationToShuffleStage( - initStage = HmssStage.WAIT_ON_SHUFFLE_INPUT, - requisitionFulfilled = true, - supposedToAdvance = true, - ) + fun `advanceComputation records blob and advance for HMSS WAIT_ON_SHUFFLE_INPUT_PHASE_TWO`(): + Unit = runBlocking { + val token = computationToken { + computationStage = HmssStage.WAIT_ON_SHUFFLE_INPUT_PHASE_TWO.toProtocolStage() + blobs += newEmptyOutputBlobMetadata(1L) + computationDetails = computationDetails { + honestMajorityShareShuffle = + HonestMajorityShareShuffleKt.computationDetails { + role = RoleInComputation.NON_AGGREGATOR + } + } } + val request = advanceComputationRequest { + globalComputationId = COMPUTATION_ID + computationStage = HmssStage.WAIT_ON_SHUFFLE_INPUT_PHASE_TWO.toProtocolStage() + blobId = 1L + blobPath = BLOB_KEY + } + val updatedToken = + token.copy { + blobs.clear() + blobs += newOutputBlobMetadata(1L, BLOB_KEY) + } + val (recordBlobRequests, advanceComputationRequests) = + mockComputationsServiceCalls(token, updatedToken) - @Test - fun `advanceComputation records seed but doee not advance for HMSS INITIALIZED`() = runBlocking { - verifyHmssAdvanceComputationToShuffleStage( - initStage = HmssStage.INITIALIZED, - requisitionFulfilled = false, - supposedToAdvance = false, - ) - } - - @Test - fun `advanceComputation records seed but doee not advance for HMSS SETUP_PHASE`() = runBlocking { - verifyHmssAdvanceComputationToShuffleStage( - initStage = HmssStage.SETUP_PHASE, - requisitionFulfilled = true, - supposedToAdvance = false, - ) - } + service.advanceComputation(request) - @Test - fun `advanceComputation records seed but doee not advance for HMSS WAIT_ON_SHUFFLE_INPUT`() = - runBlocking { - verifyHmssAdvanceComputationToShuffleStage( - initStage = HmssStage.WAIT_ON_SHUFFLE_INPUT, - requisitionFulfilled = false, - supposedToAdvance = false, + assertThat(recordBlobRequests) + .containsExactly( + recordOutputBlobPathRequest { + this.token = token + outputBlobId = 1 + blobPath = BLOB_KEY + } ) - } + assertThat(advanceComputationRequests) + .containsExactly( + advanceComputationStageRequest { + this.token = updatedToken + nextComputationStage = HmssStage.SHUFFLE_PHASE.toProtocolStage() + inputBlobs += BLOB_KEY + outputBlobs = 1 + stageDetails = ComputationStageDetails.getDefaultInstance() + afterTransition = AfterTransition.ADD_UNCLAIMED_TO_QUEUE + } + ) + } @Test - fun `advanceComputation records the last blob and advance for HMSS WAIT_ON_AGGREGATION`() = + fun `advanceComputation records the last blob and advance for HMSS WAIT_ON_AGGREGATION`(): Unit = runBlocking { val token = computationToken { computationStage = HmssStage.WAIT_ON_AGGREGATION_INPUT.toProtocolStage() @@ -558,61 +507,41 @@ class AsyncComputationControlServiceTest { blobs += newOutputBlobMetadata(AGGREGATION_BLOB_ID_1, AGGREGATION_BLOB_PATH_1) blobs += newEmptyOutputBlobMetadata(AGGREGATION_BLOB_ID_2) } - + val request = advanceComputationRequest { + globalComputationId = COMPUTATION_ID + computationStage = HmssStage.WAIT_ON_AGGREGATION_INPUT.toProtocolStage() + blobId = AGGREGATION_BLOB_ID_2 + blobPath = AGGREGATION_BLOB_PATH_2 + } val updatedToken = token.copy { blobs.clear() blobs += newOutputBlobMetadata(AGGREGATION_BLOB_ID_1, AGGREGATION_BLOB_PATH_1) blobs += newOutputBlobMetadata(AGGREGATION_BLOB_ID_2, AGGREGATION_BLOB_PATH_2) } + val (recordBlobRequests, advanceComputationRequests) = + mockComputationsServiceCalls(token, updatedToken) - mockComputationsService.stub { - onBlocking { getComputationToken(any()) }.thenReturn(token.toGetComputationTokenResponse()) - onBlocking { recordOutputBlobPath(any()) } - .thenReturn(updatedToken.toRecordOutputBlobPathResponse()) - onBlocking { advanceComputationStage(any()) } - .thenReturn(AdvanceComputationStageResponse.getDefaultInstance()) - } - - service.advanceComputation( - advanceComputationRequest { - globalComputationId = COMPUTATION_ID - computationStage = HmssStage.WAIT_ON_AGGREGATION_INPUT.toProtocolStage() - blobId = AGGREGATION_BLOB_ID_2 - blobPath = AGGREGATION_BLOB_PATH_2 - } - ) + service.advanceComputation(request) - verifyProtoArgument( - mockComputationsService, - ComputationsCoroutineImplBase::getComputationToken, - ) - .isEqualTo(getComputationTokenRequest { globalComputationId = COMPUTATION_ID }) - verifyBlocking(mockComputationsService, never()) { updateComputationDetails(any()) } - verifyProtoArgument( - mockComputationsService, - ComputationsCoroutineImplBase::recordOutputBlobPath, - ) - .isEqualTo( + assertThat(recordBlobRequests) + .containsExactly( recordOutputBlobPathRequest { this.token = token outputBlobId = AGGREGATION_BLOB_ID_2 blobPath = AGGREGATION_BLOB_PATH_2 } ) - verifyProtoArgument( - mockComputationsService, - ComputationsCoroutineImplBase::advanceComputationStage, - ) - .isEqualTo( + assertThat(advanceComputationRequests) + .containsExactly( advanceComputationStageRequest { this.token = updatedToken nextComputationStage = HmssStage.AGGREGATION_PHASE.toProtocolStage() - afterTransition = AfterTransition.ADD_UNCLAIMED_TO_QUEUE - outputBlobs = 1 - stageDetails = ComputationStageDetails.getDefaultInstance() inputBlobs += AGGREGATION_BLOB_PATH_1 inputBlobs += AGGREGATION_BLOB_PATH_2 + outputBlobs = 1 + stageDetails = ComputationStageDetails.getDefaultInstance() + afterTransition = AfterTransition.ADD_UNCLAIMED_TO_QUEUE } ) } @@ -629,49 +558,32 @@ class AsyncComputationControlServiceTest { blobs += newEmptyOutputBlobMetadata(AGGREGATION_BLOB_ID_1) blobs += newEmptyOutputBlobMetadata(AGGREGATION_BLOB_ID_2) } - + val request = advanceComputationRequest { + globalComputationId = COMPUTATION_ID + computationStage = HmssStage.WAIT_ON_AGGREGATION_INPUT.toProtocolStage() + blobId = AGGREGATION_BLOB_ID_2 + blobPath = AGGREGATION_BLOB_PATH_2 + } val updatedToken = token.copy { blobs.clear() - blobs += newOutputBlobMetadata(AGGREGATION_BLOB_ID_1, AGGREGATION_BLOB_PATH_1) - blobs += newEmptyOutputBlobMetadata(AGGREGATION_BLOB_ID_2) + blobs += newEmptyOutputBlobMetadata(AGGREGATION_BLOB_ID_1) + blobs += newOutputBlobMetadata(AGGREGATION_BLOB_ID_2, AGGREGATION_BLOB_PATH_2) } + val (recordBlobRequests, advanceComputationRequests) = + mockComputationsServiceCalls(token, updatedToken) - mockComputationsService.stub { - onBlocking { getComputationToken(any()) }.thenReturn(token.toGetComputationTokenResponse()) - onBlocking { recordOutputBlobPath(any()) } - .thenReturn(updatedToken.toRecordOutputBlobPathResponse()) - onBlocking { advanceComputationStage(any()) } - .thenReturn(AdvanceComputationStageResponse.getDefaultInstance()) - } - - service.advanceComputation( - advanceComputationRequest { - globalComputationId = COMPUTATION_ID - computationStage = HmssStage.WAIT_ON_AGGREGATION_INPUT.toProtocolStage() - blobId = AGGREGATION_BLOB_ID_1 - blobPath = AGGREGATION_BLOB_PATH_1 - } - ) + service.advanceComputation(request) - verifyProtoArgument( - mockComputationsService, - ComputationsCoroutineImplBase::getComputationToken, - ) - .isEqualTo(getComputationTokenRequest { globalComputationId = COMPUTATION_ID }) - verifyBlocking(mockComputationsService, never()) { updateComputationDetails(any()) } - verifyProtoArgument( - mockComputationsService, - ComputationsCoroutineImplBase::recordOutputBlobPath, - ) - .isEqualTo( + assertThat(recordBlobRequests) + .containsExactly( recordOutputBlobPathRequest { this.token = token - outputBlobId = AGGREGATION_BLOB_ID_1 - blobPath = AGGREGATION_BLOB_PATH_1 + outputBlobId = AGGREGATION_BLOB_ID_2 + blobPath = AGGREGATION_BLOB_PATH_2 } ) - verifyBlocking(mockComputationsService, never()) { advanceComputationStage(any()) } + assertThat(advanceComputationRequests).isEmpty() } @Test diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/HonestMajorityShareShuffleStagesTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/HonestMajorityShareShuffleStagesTest.kt index 84214a57774..897d2b17f68 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/HonestMajorityShareShuffleStagesTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/service/internal/computationcontrol/HonestMajorityShareShuffleStagesTest.kt @@ -15,7 +15,6 @@ package org.wfanet.measurement.duchy.service.internal.computationcontrol import com.google.common.truth.Truth.assertThat -import com.google.protobuf.kotlin.toByteStringUtf8 import kotlin.test.assertFailsWith import org.junit.Test import org.junit.runner.RunWith @@ -23,107 +22,12 @@ import org.junit.runners.JUnit4 import org.wfanet.measurement.duchy.db.computation.HonestMajorityShareShuffleProtocol import org.wfanet.measurement.duchy.service.internal.computations.newEmptyOutputBlobMetadata import org.wfanet.measurement.duchy.toProtocolStage -import org.wfanet.measurement.internal.duchy.ComputationBlobDependency -import org.wfanet.measurement.internal.duchy.computationDetails import org.wfanet.measurement.internal.duchy.computationStage -import org.wfanet.measurement.internal.duchy.computationStageBlobMetadata import org.wfanet.measurement.internal.duchy.computationStageDetails -import org.wfanet.measurement.internal.duchy.computationStageInput import org.wfanet.measurement.internal.duchy.computationToken -import org.wfanet.measurement.internal.duchy.copy import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle.Stage -import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt -import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.ComputationDetailsKt.randomSeeds import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.stageDetails import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.waitOnAggregationInputDetails -import org.wfanet.measurement.internal.duchy.requisitionMetadata - -private val SEED = "my seed".toByteStringUtf8() -private val PEER_SEED = "peer seed".toByteStringUtf8() -private val SHUFFLE_PHASE_INPUT = computationStageInput { - honestMajorityShareShuffleShufflePhaseInput = - HonestMajorityShareShuffleKt.shufflePhaseInput { this.peerRandomSeed = PEER_SEED } -} - -private val WAIT_ON_SHUFFLE_INPUT_TOKEN = computationToken { - computationStage = computationStage { honestMajorityShareShuffle = Stage.WAIT_ON_SHUFFLE_INPUT } - computationDetails = computationDetails { - honestMajorityShareShuffle = - HonestMajorityShareShuffleKt.computationDetails { - seeds = randomSeeds { this.commonRandomSeed = SEED } - } - } - requisitions += requisitionMetadata {} - requisitions += requisitionMetadata {} -} - -private val UPDATED_WAIT_ON_SHUFFLE_INPUT_TOKEN = - WAIT_ON_SHUFFLE_INPUT_TOKEN.copy { - computationDetails = computationDetails { - honestMajorityShareShuffle = - HonestMajorityShareShuffleKt.computationDetails { - seeds = randomSeeds { - this.commonRandomSeed = SEED - this.commonRandomSeedFromPeer = PEER_SEED - } - } - } - } - -private val FULFILLED_WAIT_ON_SHUFFLE_INPUT_TOKEN = - WAIT_ON_SHUFFLE_INPUT_TOKEN.copy { - requisitions.clear() - requisitions += requisitionMetadata { - secretSeedCiphertext = "seed_1".toByteStringUtf8() - path = "path_1" - } - requisitions += requisitionMetadata { - secretSeedCiphertext = "seed_2".toByteStringUtf8() - path = "path_2" - } - } - -private val READY_WAIT_ON_SHUFFLE_INPUT_TOKEN = - UPDATED_WAIT_ON_SHUFFLE_INPUT_TOKEN.copy { - requisitions.clear() - requisitions += requisitionMetadata { - secretSeedCiphertext = "seed_1".toByteStringUtf8() - path = "path_1" - } - requisitions += requisitionMetadata { - secretSeedCiphertext = "seed_2".toByteStringUtf8() - path = "path_2" - } - } - -private val WAIT_ON_AGGREGATION_INPUT_TOKEN = computationToken { - computationStage = computationStage { - honestMajorityShareShuffle = Stage.WAIT_ON_AGGREGATION_INPUT - } - blobs += computationStageBlobMetadata { - dependencyType = ComputationBlobDependency.OUTPUT - blobId = 1 - } - blobs += computationStageBlobMetadata { - dependencyType = ComputationBlobDependency.OUTPUT - blobId = 2 - } -} - -private val READY_WAIT_ON_AGGREGATION_INPUT_TOKEN = - WAIT_ON_AGGREGATION_INPUT_TOKEN.copy { - blobs.clear() - blobs += computationStageBlobMetadata { - dependencyType = ComputationBlobDependency.OUTPUT - path = "path_1" - blobId = 1 - } - blobs += computationStageBlobMetadata { - dependencyType = ComputationBlobDependency.OUTPUT - path = "path_2" - blobId = 2 - } - } @RunWith(JUnit4::class) class HonestMajorityShareShuffleStagesTest { @@ -141,11 +45,12 @@ class HonestMajorityShareShuffleStagesTest { } @Test - fun `next stages are valid for WAIT_ON_AGGREGATION_INPUT and WAIT_ON_SHUFFLE_INPUT`() { + fun `next stages are valid for waiting stages`() { for (stage in Stage.values()) { when (stage) { - Stage.WAIT_ON_AGGREGATION_INPUT, - Stage.WAIT_ON_SHUFFLE_INPUT -> { + Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE, + Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_TWO, + Stage.WAIT_ON_AGGREGATION_INPUT -> { val next = stages.nextStage(stage.toProtocolStage()).honestMajorityShareShuffle assertThat(HonestMajorityShareShuffleProtocol.EnumStages.validTransition(stage, next)) .isTrue() @@ -158,7 +63,7 @@ class HonestMajorityShareShuffleStagesTest { } @Test - fun `outputBlob returns BlobMeta for WAIT_ON_AGGREGATION_INPUT`() { + fun `outputBlob returns BlobMetadata for WAIT_ON_AGGREGATION_INPUT`() { val token = computationToken { computationStage = Stage.WAIT_ON_AGGREGATION_INPUT.toProtocolStage() blobs += newEmptyOutputBlobMetadata(1L) @@ -178,99 +83,4 @@ class HonestMajorityShareShuffleStagesTest { assertThat(stages.outputBlob(token, "bob")).isEqualTo(newEmptyOutputBlobMetadata(1L)) assertFailsWith { stages.outputBlob(token, "unknown-sender") } } - - @Test - fun `isValidStage returns correct boolean`() { - for (currentStage in Stage.values()) { - for (requestStage in Stage.values()) { - if (currentStage == Stage.UNRECOGNIZED || requestStage == Stage.UNRECOGNIZED) { - continue - } - when (requestStage) { - Stage.WAIT_ON_SHUFFLE_INPUT -> { - if ( - currentStage == Stage.INITIALIZED || - currentStage == Stage.SETUP_PHASE || - currentStage == Stage.WAIT_ON_SHUFFLE_INPUT - ) { - assertThat( - stages.isValidStage( - currentStage.toProtocolStage(), - requestStage.toProtocolStage(), - ) - ) - .isTrue() - } else { - assertThat( - stages.isValidStage( - currentStage.toProtocolStage(), - requestStage.toProtocolStage(), - ) - ) - .isFalse() - } - } - else -> { - if (currentStage == requestStage) { - assertThat( - stages.isValidStage( - currentStage.toProtocolStage(), - requestStage.toProtocolStage(), - ) - ) - .isTrue() - } else { - assertThat( - stages.isValidStage( - currentStage.toProtocolStage(), - requestStage.toProtocolStage(), - ) - ) - .isFalse() - } - } - } - } - } - } - - @Test - fun `expectStageInput returns correct booleans`() { - assertThat(stages.expectStageInput(WAIT_ON_SHUFFLE_INPUT_TOKEN)).isTrue() - assertThat(stages.expectStageInput(UPDATED_WAIT_ON_SHUFFLE_INPUT_TOKEN)).isFalse() - } - - @Test - fun `updateComputationDetails returns updated computation details`() { - val updatedDetails = - stages.updateComputationDetails( - WAIT_ON_SHUFFLE_INPUT_TOKEN.computationDetails, - SHUFFLE_PHASE_INPUT, - ) - - assertThat(updatedDetails).isEqualTo(UPDATED_WAIT_ON_SHUFFLE_INPUT_TOKEN.computationDetails) - } - - @Test - fun `readyForNextStage returns correct boolean`() { - assertThat(stages.readyForNextStage(WAIT_ON_SHUFFLE_INPUT_TOKEN)).isFalse() - assertThat(stages.readyForNextStage(UPDATED_WAIT_ON_SHUFFLE_INPUT_TOKEN)).isFalse() - assertThat(stages.readyForNextStage(FULFILLED_WAIT_ON_SHUFFLE_INPUT_TOKEN)).isFalse() - assertThat(stages.readyForNextStage(READY_WAIT_ON_SHUFFLE_INPUT_TOKEN)).isTrue() - - val shufflePhaseToken = - READY_WAIT_ON_SHUFFLE_INPUT_TOKEN.copy { - computationStage = computationStage { honestMajorityShareShuffle = Stage.SHUFFLE_PHASE } - } - assertThat(stages.readyForNextStage(shufflePhaseToken)).isFalse() - - assertThat(stages.readyForNextStage(WAIT_ON_AGGREGATION_INPUT_TOKEN)).isFalse() - assertThat(stages.readyForNextStage(READY_WAIT_ON_AGGREGATION_INPUT_TOKEN)).isTrue() - - val aggregationPhaseToken = - READY_WAIT_ON_SHUFFLE_INPUT_TOKEN.copy { - computationStage = computationStage { honestMajorityShareShuffle = Stage.AGGREGATION_PHASE } - } - assertThat(stages.readyForNextStage(aggregationPhaseToken)).isFalse() - } } diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/ComputationControlServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/ComputationControlServiceTest.kt index 4f8ed649524..a53918b4cdb 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/ComputationControlServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/service/system/v1alpha/ComputationControlServiceTest.kt @@ -20,7 +20,6 @@ import com.google.protobuf.kotlin.toByteStringUtf8 import io.grpc.StatusRuntimeException import kotlin.test.assertFailsWith import kotlin.test.assertNotNull -import kotlin.test.assertNull import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.flowOf @@ -33,9 +32,7 @@ import org.junit.rules.TemporaryFolder import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.mockito.kotlin.any -import org.mockito.kotlin.never import org.mockito.kotlin.stub -import org.mockito.kotlin.verifyBlocking import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule import org.wfanet.measurement.common.grpc.testing.mockService import org.wfanet.measurement.common.identity.DuchyIdentity @@ -52,10 +49,8 @@ import org.wfanet.measurement.internal.duchy.AsyncComputationControlGrpcKt.Async import org.wfanet.measurement.internal.duchy.ComputationBlobDependency import org.wfanet.measurement.internal.duchy.advanceComputationRequest as asyncAdvanceComputationRequest import org.wfanet.measurement.internal.duchy.computationStageBlobMetadata -import org.wfanet.measurement.internal.duchy.computationStageInput import org.wfanet.measurement.internal.duchy.getOutputBlobMetadataRequest import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffle as HonestMajorityShareShuffleProtocol -import org.wfanet.measurement.internal.duchy.protocol.HonestMajorityShareShuffleKt.shufflePhaseInput import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 import org.wfanet.measurement.internal.duchy.protocol.ReachOnlyLiquidLegionsSketchAggregationV2 import org.wfanet.measurement.storage.filesystem.FileSystemStorageClient @@ -64,7 +59,6 @@ import org.wfanet.measurement.system.v1alpha.AdvanceComputationRequest import org.wfanet.measurement.system.v1alpha.HonestMajorityShareShuffle import org.wfanet.measurement.system.v1alpha.LiquidLegionsV2 import org.wfanet.measurement.system.v1alpha.ReachOnlyLiquidLegionsV2 -import org.wfanet.measurement.system.v1alpha.advanceComputationRequest private const val RUNNING_DUCHY_NAME = "Alsace" private const val BAVARIA = "Bavaria" @@ -397,7 +391,42 @@ class ComputationControlServiceTest { } @Test - fun `honest majority share shuffle sends blob as input`() = runBlocking { + fun `honest majority share shuffle sends input for WAIT_ON_SHUFFLE_INPUT_PHASE_ONE`() = + runBlocking { + val id = "444444" + val blobKey = "$id/WAIT_ON_SHUFFLE_INPUT_PHASE_ONE/$BLOB_ID" + val carinthiaHeader = + advanceComputationHeader(HonestMajorityShareShuffle.Description.SHUFFLE_PHASE_INPUT_ONE, id) + + withSender(carinthia) { advanceComputation(carinthiaHeader.withContent(BLOB_CONTENT)) } + + verifyProtoArgument( + mockAsyncControlService, + AsyncComputationControlCoroutineImplBase::getOutputBlobMetadata, + ) + .isEqualTo( + getOutputBlobMetadataRequest { + globalComputationId = id + dataOrigin = CARINTHIA + } + ) + assertThat(advanceAsyncComputationRequests) + .containsExactly( + asyncAdvanceComputationRequest { + globalComputationId = id + computationStage = + HonestMajorityShareShuffleProtocol.Stage.WAIT_ON_SHUFFLE_INPUT_PHASE_ONE + .toProtocolStage() + blobId = BLOB_ID + blobPath = blobKey + } + ) + val data = assertNotNull(computationStore.get(blobKey)) + assertThat(data).contentEqualTo(BLOB_CONTENT) + } + + @Test + fun `honest majority share shuffle sends aggregation phase input`() = runBlocking { val id = "444444" val blobKey = "$id/WAIT_ON_AGGREGATION_INPUT/$BLOB_ID" val carinthiaHeader = @@ -428,32 +457,6 @@ class ComputationControlServiceTest { val data = assertNotNull(computationStore.get(blobKey)) assertThat(data).contentEqualTo(BLOB_CONTENT) } - - @Test - fun `honest majority share shuffle sends seed as input`() = runBlocking { - val id = "444444" - val blobKey = "$id/WAIT_ON_SHUFFLE_INPUT/$BLOB_ID" - val carinthiaHeader = - advanceComputationHeader(HonestMajorityShareShuffle.Description.SHUFFLE_PHASE_INPUT, id, SEED) - - withSender(carinthia) { advanceComputation(carinthiaHeader.withSeed()) } - - verifyBlocking(mockAsyncControlService, never()) { getOutputBlobMetadata(any()) } - assertThat(advanceAsyncComputationRequests) - .containsExactly( - asyncAdvanceComputationRequest { - globalComputationId = id - computationStage = - HonestMajorityShareShuffleProtocol.Stage.WAIT_ON_SHUFFLE_INPUT.toProtocolStage() - computationStageInput = computationStageInput { - honestMajorityShareShuffleShufflePhaseInput = shufflePhaseInput { - peerRandomSeed = SEED - } - } - } - ) - val data = assertNull(computationStore.get(blobKey)) - } } private fun AdvanceComputationRequest.Header.withContent( @@ -469,7 +472,3 @@ private fun AdvanceComputationRequest.Header.withContent( .asFlow() .onStart { emit(AdvanceComputationRequest.newBuilder().setHeader(this@withContent).build()) } } - -private fun AdvanceComputationRequest.Header.withSeed(): Flow { - return flowOf(advanceComputationRequest { header = this@withSeed }) -}