diff --git a/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/Variances.kt b/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/Variances.kt index 1e79b28ee9d..c5d4b8576db 100644 --- a/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/Variances.kt +++ b/src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats/Variances.kt @@ -161,7 +161,7 @@ object VariancesImpl : Variances { multiplier: Int) = relativeFrequencyMeasurementVarianceParams - require(totalReach > 0) { "Total reach must be positive, but got $totalReach." } + require(totalReach >= 0) { "Total reach must be non-negative, but got $totalReach." } require(reachRatio >= 0.0 && reachRatio <= 1.0) { "Reach ratio must be greater than or equal to 0 and less than or equal to 1, but got " + "$reachRatio." @@ -199,7 +199,7 @@ object VariancesImpl : Variances { reachRatio: Double, reachRatioVariance: Double, ): Double { - require(totalReach > 0) { "Total reach must be positive, but got $totalReach." } + require(totalReach >= 0) { "Total reach must be non-negative, but got $totalReach." } require(totalReachVariance >= 0) { "Total reach variance must not be negative, but got $totalReachVariance." } diff --git a/src/test/kotlin/org/wfanet/measurement/measurementconsumer/stats/VariancesTest.kt b/src/test/kotlin/org/wfanet/measurement/measurementconsumer/stats/VariancesTest.kt index e7ef79b7aa8..e202451d63f 100644 --- a/src/test/kotlin/org/wfanet/measurement/measurementconsumer/stats/VariancesTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/measurementconsumer/stats/VariancesTest.kt @@ -2548,6 +2548,211 @@ class VariancesTest { .of(expectedNKPlus) } + @Test + fun `computeMeasurementVariance returns for deterministic reach-frequency when total reach is zero and zero relative frequencies `() { + val vidSamplingIntervalWidth = 5e-2 + val totalReach = 0L + val reachDpParams = DpParams(0.5, 1e-15) + val reachMeasurementParams = + ReachMeasurementParams( + VidSamplingInterval(0.0, vidSamplingIntervalWidth), + reachDpParams, + NoiseMechanism.GAUSSIAN, + ) + val reachMeasurementVarianceParams = + ReachMeasurementVarianceParams(totalReach, reachMeasurementParams) + val reachMeasurementVariance = + VariancesImpl.computeMeasurementVariance( + DeterministicMethodology, + reachMeasurementVarianceParams, + ) + + val maximumFrequency = 5 + val relativeFrequencyDistribution = (1..maximumFrequency).associateWith { 0.0 } + val frequencyDpParams = DpParams(0.2, 1e-15) + val frequencyMeasurementParams = + FrequencyMeasurementParams( + VidSamplingInterval(0.0, vidSamplingIntervalWidth), + frequencyDpParams, + NoiseMechanism.GAUSSIAN, + maximumFrequency, + ) + val frequencyMeasurementVarianceParams = + FrequencyMeasurementVarianceParams( + totalReach, + reachMeasurementVariance, + relativeFrequencyDistribution, + frequencyMeasurementParams, + ) + + val (rKVars, rKPlusVars, nKVars, nKPlusVars) = + VariancesImpl.computeMeasurementVariance( + DeterministicMethodology, + frequencyMeasurementVarianceParams, + ) + + val expectedRK = + listOf( + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + ) + val expectedRKPlus = + listOf( + 0.0, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + ) + val expectedNK = + listOf( + 7244.988593451251, + 7244.988593451251, + 7244.988593451251, + 7244.988593451251, + 7244.988593451251, + ) + val expectedNKPlus = + listOf(0.0, 7244.988593451251, 7244.988593451251, 7244.988593451251, 7244.988593451251) + + for (frequency in 1..maximumFrequency) { + assertThat(rKVars.getValue(frequency)) + .isWithin(computeErrorTolerance(rKVars.getValue(frequency), expectedRK[frequency - 1])) + .of(expectedRK[frequency - 1]) + } + for (frequency in 1..maximumFrequency) { + assertThat(rKPlusVars.getValue(frequency)) + .isWithin( + computeErrorTolerance(rKPlusVars.getValue(frequency), expectedRKPlus[frequency - 1]) + ) + .of(expectedRKPlus[frequency - 1]) + } + for (frequency in 1..maximumFrequency) { + assertThat(nKVars.getValue(frequency)) + .isWithin(computeErrorTolerance(nKVars.getValue(frequency), expectedNK[frequency - 1])) + .of(expectedNK[frequency - 1]) + } + for (frequency in 1..maximumFrequency) { + assertThat(nKPlusVars.getValue(frequency)) + .isWithin( + computeErrorTolerance(nKPlusVars.getValue(frequency), expectedNKPlus[frequency - 1]) + ) + .of(expectedNKPlus[frequency - 1]) + } + } + + @Test + fun `computeMeasurementVariance returns for deterministic reach-frequency when total reach is zero and regular relative frequencies `() { + val vidSamplingIntervalWidth = 5e-2 + val totalReach = 0L + val reachDpParams = DpParams(0.5, 1e-15) + val reachMeasurementParams = + ReachMeasurementParams( + VidSamplingInterval(0.0, vidSamplingIntervalWidth), + reachDpParams, + NoiseMechanism.GAUSSIAN, + ) + val reachMeasurementVarianceParams = + ReachMeasurementVarianceParams(totalReach, reachMeasurementParams) + val reachMeasurementVariance = + VariancesImpl.computeMeasurementVariance( + DeterministicMethodology, + reachMeasurementVarianceParams, + ) + + val maximumFrequency = 5 + val relativeFrequencyDistribution = + mapOf( + 1 to 0.4351145038167939, + 2 to 0.0, + 3 to 0.3816793893129771, + 4 to 0.1832061068702290, + 5 to 0.0, + ) + val frequencyDpParams = DpParams(0.2, 1e-15) + val frequencyMeasurementParams = + FrequencyMeasurementParams( + VidSamplingInterval(0.0, vidSamplingIntervalWidth), + frequencyDpParams, + NoiseMechanism.GAUSSIAN, + maximumFrequency, + ) + val frequencyMeasurementVarianceParams = + FrequencyMeasurementVarianceParams( + totalReach, + reachMeasurementVariance, + relativeFrequencyDistribution, + frequencyMeasurementParams, + ) + + val (rKVars, rKPlusVars, nKVars, nKPlusVars) = + VariancesImpl.computeMeasurementVariance( + DeterministicMethodology, + frequencyMeasurementVarianceParams, + ) + + val expectedRK = + listOf( + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + ) + val expectedRKPlus = + listOf( + 0.0, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + 0.08333333333333336, + ) + val expectedNK = + listOf( + 23699.39297634332, + 7243.321926784585, + 19905.73424562018, + 10160.74172504431, + 7243.321926784585, + ) + val expectedNKPlus = + listOf( + 86919.86312141502, + 34979.06986996207, + 34979.06986996207, + 10160.74172504431, + 7243.321926784585, + ) + + for (frequency in 1..maximumFrequency) { + assertThat(rKVars.getValue(frequency)) + .isWithin(computeErrorTolerance(rKVars.getValue(frequency), expectedRK[frequency - 1])) + .of(expectedRK[frequency - 1]) + } + for (frequency in 1..maximumFrequency) { + assertThat(rKPlusVars.getValue(frequency)) + .isWithin( + computeErrorTolerance(rKPlusVars.getValue(frequency), expectedRKPlus[frequency - 1]) + ) + .of(expectedRKPlus[frequency - 1]) + } + for (frequency in 1..maximumFrequency) { + assertThat(nKVars.getValue(frequency)) + .isWithin(computeErrorTolerance(nKVars.getValue(frequency), expectedNK[frequency - 1])) + .of(expectedNK[frequency - 1]) + } + for (frequency in 1..maximumFrequency) { + assertThat(nKPlusVars.getValue(frequency)) + .isWithin( + computeErrorTolerance(nKPlusVars.getValue(frequency), expectedNKPlus[frequency - 1]) + ) + .of(expectedNKPlus[frequency - 1]) + } + } + @Test fun `computeMeasurementVariance returns for LiquidLegionsSketch reach-frequency when reach is too small`() { val decayRate = 100.0