Skip to content

Commit

Permalink
Add InProcessAccuracyTest for Reach (#1296)
Browse files Browse the repository at this point in the history
  • Loading branch information
renjiezh authored Nov 7, 2023
1 parent 7640019 commit 0316c23
Show file tree
Hide file tree
Showing 11 changed files with 439 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,22 @@ kt_jvm_library(
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing",
],
)

kt_jvm_library(
name = "in_process_reach_measurement_accuracy_test",
srcs = [
"InProcessReachMeasurementAccuracyTest.kt",
],
deps = [
":all_kingdom_services",
":in_process_cmms_components",
"//src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/server:duchy_data_server",
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/service:data_services",
"//src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer:simulator",
"//src/main/kotlin/org/wfanet/measurement/loadtest/measurementconsumer:synthetic_generator_event_query",
"//src/main/kotlin/org/wfanet/measurement/measurementconsumer/stats:variances",
"@wfa_common_jvm//imports/java/org/junit",
"@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.wfanet.measurement.api.v2alpha.AccountsGrpcKt
import org.wfanet.measurement.api.v2alpha.ApiKeysGrpcKt
import org.wfanet.measurement.api.v2alpha.EventGroup
import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt
import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SyntheticEventGroupSpec
import org.wfanet.measurement.common.crypto.subjectKeyIdentifier
import org.wfanet.measurement.common.crypto.tink.TinkPrivateKeyHandle
import org.wfanet.measurement.common.identity.DuchyInfo
Expand All @@ -46,6 +47,8 @@ class InProcessCmmsComponents(
private val kingdomDataServicesRule: ProviderRule<DataServices>,
private val duchyDependenciesRule:
ProviderRule<(String, ComputationLogEntriesCoroutineStub) -> InProcessDuchy.DuchyDependencies>,
private val syntheticEventGroupSpecs: List<SyntheticEventGroupSpec> =
SyntheticGenerationSpecs.SYNTHETIC_DATA_SPECS
) : TestRule {
private val kingdomDataServices: DataServices
get() = kingdomDataServicesRule.value
Expand All @@ -71,15 +74,15 @@ class InProcessCmmsComponents(

private val edpSimulators: List<InProcessEdpSimulator> by lazy {
edpDisplayNameToResourceNameMap.entries.mapIndexed { index, (displayName, resourceName) ->
val specIndex = index % SyntheticGenerationSpecs.SYNTHETIC_DATA_SPECS.size
val specIndex = index % syntheticEventGroupSpecs.size
InProcessEdpSimulator(
displayName = displayName,
resourceName = resourceName,
mcResourceName = mcResourceName,
kingdomPublicApiChannel = kingdom.publicApiChannel,
duchyPublicApiChannel = duchies[1].publicApiChannel,
trustedCertificates = TRUSTED_CERTIFICATES,
SyntheticGenerationSpecs.SYNTHETIC_DATA_SPECS[specIndex],
syntheticEventGroupSpecs[specIndex],
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,44 +121,44 @@ abstract class InProcessLifeOfAMeasurementIntegrationTest(
fun `create a RF measurement and check the result is equal to the expected result`() =
runBlocking {
// Use frontend simulator to create a reach and frequency measurement and verify its result.
mcSimulator.executeReachAndFrequency("1234")
mcSimulator.testReachAndFrequency("1234")
}

@Test
fun `create a direct RF measurement and check the result is equal to the expected result`() =
runBlocking {
// Use frontend simulator to create a direct reach and frequency measurement and verify its
// result.
mcSimulator.executeDirectReachAndFrequency("1234")
mcSimulator.testDirectReachAndFrequency("1234")
}

@Test
fun `create a reach-only measurement and check the result is equal to the expected result`() =
runBlocking {
// Use frontend simulator to create a reach and frequency measurement and verify its result.
mcSimulator.executeReachOnly("1234")
mcSimulator.testReachOnly("1234")
}

@Test
fun `create an impression measurement and check the result is equal to the expected result`() =
runBlocking {
// Use frontend simulator to create an impression measurement and verify its result.
mcSimulator.executeImpression("1234")
mcSimulator.testImpression("1234")
}

@Test
fun `create a duration measurement and check the result is equal to the expected result`() =
runBlocking {
// Use frontend simulator to create a duration measurement and verify its result.
mcSimulator.executeDuration("1234")
mcSimulator.testDuration("1234")
}

@Test
fun `create a RF measurement of invalid params and check the result contains error info`() =
runBlocking {
// Use frontend simulator to create an invalid reach and frequency measurement and verify
// its error info.
mcSimulator.executeInvalidReachAndFrequency("1234")
mcSimulator.testInvalidReachAndFrequency("1234")
}

// TODO(@renjiez): Add Multi-round test given the same input to verify correctness.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
// Copyright 2023 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.wfanet.measurement.integration.common

import com.google.common.truth.Truth.assertThat
import java.time.Duration
import java.util.logging.Level
import java.util.logging.Logger
import kotlin.math.abs
import kotlin.math.pow
import kotlin.math.sqrt
import kotlinx.coroutines.runBlocking
import org.junit.After
import org.junit.Before
import org.junit.BeforeClass
import org.junit.Rule
import org.junit.Test
import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt
import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt
import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt
import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt
import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt
import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism
import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt
import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams
import org.wfanet.measurement.common.testing.ProviderRule
import org.wfanet.measurement.eventdataprovider.noiser.DpParams
import org.wfanet.measurement.kingdom.deploy.common.RoLlv2ProtocolConfig
import org.wfanet.measurement.kingdom.deploy.common.service.DataServices
import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerData
import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerSimulator
import org.wfanet.measurement.loadtest.measurementconsumer.MeasurementConsumerSimulator.MeasurementInfo
import org.wfanet.measurement.loadtest.measurementconsumer.MetadataSyntheticGeneratorEventQuery
import org.wfanet.measurement.measurementconsumer.stats.LiquidLegionsV2Methodology
import org.wfanet.measurement.measurementconsumer.stats.NoiseMechanism as StatsNoiseMechanism
import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementParams
import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementVarianceParams
import org.wfanet.measurement.measurementconsumer.stats.VariancesImpl.computeMeasurementVariance
import org.wfanet.measurement.measurementconsumer.stats.VidSamplingInterval as StatsVidSamplingInterval
import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt

/**
* Test the Measurement results are accurate w.r.t to the variance.
*
* This is abstract so that different implementations of dependencies can all run the same tests
* easily.
*/
abstract class InProcessReachMeasurementAccuracyTest(
kingdomDataServicesRule: ProviderRule<DataServices>,
duchyDependenciesRule:
ProviderRule<
(
String,
ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub,
) -> InProcessDuchy.DuchyDependencies
>,
) {

@get:Rule
val inProcessCmmsComponents =
InProcessCmmsComponents(
kingdomDataServicesRule,
duchyDependenciesRule,
SYNTHETIC_EVENT_GROUP_SPECS
)

private lateinit var mcSimulator: MeasurementConsumerSimulator

private val publicMeasurementsClient by lazy {
MeasurementsGrpcKt.MeasurementsCoroutineStub(inProcessCmmsComponents.kingdom.publicApiChannel)
}
private val publicMeasurementConsumersClient by lazy {
MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub(
inProcessCmmsComponents.kingdom.publicApiChannel
)
}
private val publicCertificatesClient by lazy {
CertificatesGrpcKt.CertificatesCoroutineStub(inProcessCmmsComponents.kingdom.publicApiChannel)
}
private val publicEventGroupsClient by lazy {
EventGroupsGrpcKt.EventGroupsCoroutineStub(inProcessCmmsComponents.kingdom.publicApiChannel)
}
private val publicDataProvidersClient by lazy {
DataProvidersGrpcKt.DataProvidersCoroutineStub(inProcessCmmsComponents.kingdom.publicApiChannel)
}
private val publicRequisitionsClient by lazy {
RequisitionsGrpcKt.RequisitionsCoroutineStub(inProcessCmmsComponents.kingdom.publicApiChannel)
}

@Before
fun startDaemons() {
inProcessCmmsComponents.startDaemons()
initMcSimulator()
}

private fun initMcSimulator() {
val measurementConsumerData = inProcessCmmsComponents.getMeasurementConsumerData()
val eventQuery =
MetadataSyntheticGeneratorEventQuery(
SyntheticGenerationSpecs.POPULATION_SPEC,
InProcessCmmsComponents.MC_ENCRYPTION_PRIVATE_KEY
)
mcSimulator =
MeasurementConsumerSimulator(
MeasurementConsumerData(
measurementConsumerData.name,
InProcessCmmsComponents.MC_ENTITY_CONTENT.signingKey,
InProcessCmmsComponents.MC_ENCRYPTION_PRIVATE_KEY,
measurementConsumerData.apiAuthenticationKey
),
OUTPUT_DP_PARAMS,
publicDataProvidersClient,
publicEventGroupsClient,
publicMeasurementsClient,
publicMeasurementConsumersClient,
publicCertificatesClient,
RESULT_POLLING_DELAY,
InProcessCmmsComponents.TRUSTED_CERTIFICATES,
eventQuery,
NoiseMechanism.CONTINUOUS_GAUSSIAN
)
}

@After
fun stopEdpSimulators() {
inProcessCmmsComponents.stopEdpSimulators()
}

@After
fun stopDuchyDaemons() {
inProcessCmmsComponents.stopDuchyDaemons()
}

private fun getReachVariance(measurementInfo: MeasurementInfo, reach: Long): Double {
val liquidLegionsMethodology =
LiquidLegionsV2Methodology(
RoLlv2ProtocolConfig.protocolConfig.sketchParams.decayRate,
RoLlv2ProtocolConfig.protocolConfig.sketchParams.maxSize,
RoLlv2ProtocolConfig.protocolConfig.sketchParams.samplingIndicatorSize
)
val reachMeasurementParams =
ReachMeasurementParams(
StatsVidSamplingInterval(
measurementInfo.measurementSpec.vidSamplingInterval.start.toDouble(),
measurementInfo.measurementSpec.vidSamplingInterval.width.toDouble()
),
DpParams(OUTPUT_DP_PARAMS.epsilon, OUTPUT_DP_PARAMS.delta),
StatsNoiseMechanism.GAUSSIAN
)
val reachMeasurementVarianceParams =
ReachMeasurementVarianceParams(reach, reachMeasurementParams)
return computeMeasurementVariance(liquidLegionsMethodology, reachMeasurementVarianceParams)
}

private fun getStandardDeviation(nums: List<Double>): Double {
val mean = nums.average()
val standardDeviation = nums.fold(0.0) { acc, num -> acc + (num - mean).pow(2.0) }

return sqrt(standardDeviation / nums.size)
}

data class ReachResult(
val actualReach: Long,
val expectedReach: Long,
val lowerBound: Double,
val upperBound: Double,
val withinInterval: Boolean,
)

@Test
fun `reach-only llv2 results should be accurate with respect to the variance`() = runBlocking {
val reachResults = mutableListOf<ReachResult>()
var expectedReach = -1L
var expectedStandardDeviation = 0.0

var summary = ""
for (round in 1..DEFAULT_TEST_ROUND_NUMBER) {
val executionResult = mcSimulator.executeReachOnly(round.toString())

if (expectedReach == -1L) {
expectedReach = executionResult.expectedResult.reach.value
val expectedVariance = getReachVariance(executionResult.measurementInfo, expectedReach)
expectedStandardDeviation = sqrt(expectedVariance)
} else if (expectedReach != executionResult.expectedResult.reach.value) {
logger.log(
Level.WARNING,
"expected result not consistent. round=$round, prev_expected_result=$expectedReach, " +
"current_expected_result=${executionResult.expectedResult.reach.value}"
)
}

// The general formula for confidence interval is result +/- multiplier * sqrt(variance).
// The multiplier for 95% confidence interval is 1.96.
val reach = executionResult.actualResult.reach.value
val reachVariance = getReachVariance(executionResult.measurementInfo, reach)
val intervalLowerBound = reach - sqrt(reachVariance) * MULTIPLIER
val intervalUpperBound = reach + sqrt(reachVariance) * MULTIPLIER
val withinInterval = reach >= intervalLowerBound && reach <= intervalUpperBound

val reachResult =
ReachResult(reach, expectedReach, intervalLowerBound, intervalUpperBound, withinInterval)
reachResults += reachResult

val message =
"round=$round, actual_result=${reachResult.actualReach}, " +
"expected_result=${reachResult.expectedReach}, " +
"interval=(${"%.2f".format(reachResult.lowerBound)}, " +
"${"%.2f".format(reachResult.upperBound)}), accurate=${reachResult.withinInterval}"
summary += message + "\n"
logger.log(Level.INFO, message)
}

logger.log(Level.INFO, "Accuracy Test Complete.\n$summary")

val averageReach = reachResults.map { it.actualReach }.average()
val withinIntervalNumber = reachResults.map { if (it.withinInterval) 1 else 0 }.sum()
val withinIntervalPercentage = withinIntervalNumber.toDouble() / reachResults.size * 100
val offsetPercentage = (averageReach - expectedReach) / expectedReach * 100
val averageDispersionRatio =
abs(averageReach - expectedReach) * sqrt(DEFAULT_TEST_ROUND_NUMBER.toDouble()) /
expectedStandardDeviation

logger.log(
Level.INFO,
"average_reach=$averageReach, offset_percentage=${"%.2f".format(offsetPercentage)}%, " +
"number_of_rounds_within_interval=$withinIntervalNumber out of $DEFAULT_TEST_ROUND_NUMBER " +
"(${"%.2f".format(withinIntervalPercentage)}%) "
)

val standardDeviation = getStandardDeviation(reachResults.map { it.actualReach.toDouble() })
logger.log(
Level.INFO,
"std=${"%.2f".format(standardDeviation)}, " +
"expected_std=${"%.2f".format(expectedStandardDeviation)}, " +
"ratio=${"%.2f".format(standardDeviation / expectedStandardDeviation)}"
)

assertThat(withinIntervalPercentage).isAtLeast(COVERAGE_TEST_THRESHOLD)
assertThat(averageDispersionRatio).isLessThan(AVERAGE_TEST_THRESHOLD)
assertThat(standardDeviation)
.isGreaterThan(expectedStandardDeviation * STANDARD_DEVIATION_TEST_LOWER_THRESHOLD)
assertThat(standardDeviation)
.isLessThan(expectedStandardDeviation * STANDARD_DEVIATION_TEST_UPPER_THRESHOLD)
}

companion object {
private val logger: Logger = Logger.getLogger(this::class.java.name)

private val SYNTHETIC_EVENT_GROUP_SPECS = SyntheticGenerationSpecs.SYNTHETIC_DATA_SPECS_2M

private const val DEFAULT_TEST_ROUND_NUMBER = 30
// Multiplier for 95% confidence interval
private const val MULTIPLIER = 1.96

private const val COVERAGE_TEST_THRESHOLD = 80
private const val AVERAGE_TEST_THRESHOLD = 2.58
private const val STANDARD_DEVIATION_TEST_LOWER_THRESHOLD = 0.67
private const val STANDARD_DEVIATION_TEST_UPPER_THRESHOLD = 1.35
private val OUTPUT_DP_PARAMS = differentialPrivacyParams {
epsilon = 0.0033
delta = 0.00001
}
private val RESULT_POLLING_DELAY = Duration.ofSeconds(10)

@BeforeClass
@JvmStatic
fun initConfig() {
InProcessCmmsComponents.initConfig()
}
}
}
Loading

0 comments on commit 0316c23

Please sign in to comment.