Skip to content

Commit

Permalink
Update integration tests to verify direct measurement results (#1212)
Browse files Browse the repository at this point in the history
  • Loading branch information
riemanli authored Sep 14, 2023
1 parent 5d31013 commit 247d18b
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCorou
import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub
import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub
import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub
import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism
import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCoroutineStub
import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams
import org.wfanet.measurement.common.testing.ProviderRule
Expand Down Expand Up @@ -101,7 +102,8 @@ abstract class InProcessLifeOfAMeasurementIntegrationTest(
publicCertificatesClient,
RESULT_POLLING_DELAY,
InProcessCmmsComponents.TRUSTED_CERTIFICATES,
eventQuery
eventQuery,
NoiseMechanism.CONTINUOUS_GAUSSIAN
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.impression
import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reachAndFrequency
import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.vidSamplingInterval
import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub
import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism
import org.wfanet.measurement.api.v2alpha.RequisitionSpec
import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt
import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.eventFilter
Expand Down Expand Up @@ -118,6 +119,7 @@ class MeasurementConsumerSimulator(
private val resultPollingDelay: Duration,
private val trustedCertificates: Map<ByteString, X509Certificate>,
private val eventQuery: EventQuery<Message>,
private val expectedDirectNoiseMechanism: NoiseMechanism,
) {
/** Cache of resource name to [Certificate]. */
private val certificateCache = mutableMapOf<String, Certificate>()
Expand Down Expand Up @@ -231,14 +233,44 @@ class MeasurementConsumerSimulator(
.reachValue()
.isWithinPercent(0.5)
.of(expectedResult.reach.value)
assertThat(reachAndFrequencyResult.reach.hasDeterministicCountDistinct()).isTrue()
assertThat(reachAndFrequencyResult.reach.noiseMechanism).isEqualTo(expectedDirectNoiseMechanism)

assertThat(reachAndFrequencyResult)
.frequencyDistribution()
.isWithin(0.01)
.of(expectedResult.frequency.relativeFrequencyDistributionMap)
assertThat(reachAndFrequencyResult.frequency.hasDeterministicDistribution()).isTrue()
assertThat(reachAndFrequencyResult.frequency.noiseMechanism)
.isEqualTo(expectedDirectNoiseMechanism)

logger.info("Direct reach and frequency result is equal to the expected result")
}

/** A sequence of operations done in the simulator involving a direct reach measurement. */
suspend fun executeDirectReach(runId: String) {
// Create a new measurement on behalf of the measurement consumer.
val measurementConsumer = getMeasurementConsumer(measurementConsumerData.name)
val measurementInfo =
createMeasurement(measurementConsumer, runId, ::newReachMeasurementSpec, 1)
val measurementName = measurementInfo.measurement.name
logger.info("Created direct reach measurement $measurementName.")

// Get the CMMS computed result and compare it with the expected result.
val reachResult = pollForResult { getReachResult(measurementName) }
logger.info("Got direct reach result from Kingdom: $reachResult")

val expectedResult = getExpectedResult(measurementInfo)
logger.info("Expected result: $expectedResult")

// TODO(@riemanli): Use variance rather than fixed tolerance values.
assertThat(reachResult).reachValue().isWithinPercent(0.5).of(expectedResult.reach.value)
assertThat(reachResult.reach.hasDeterministicCountDistinct()).isTrue()
assertThat(reachResult.reach.noiseMechanism).isEqualTo(expectedDirectNoiseMechanism)

logger.info("Direct reach result is equal to the expected result")
}

/** A sequence of operations done in the simulator involving a reach-only measurement. */
suspend fun executeReachOnly(runId: String) {
// Create a new measurement on behalf of the measurement consumer.
Expand Down Expand Up @@ -286,6 +318,9 @@ class MeasurementConsumerSimulator(
// EdpSimulator sets it to this value.
apiIdToExternalId(DataProviderCertificateKey.fromName(it.certificate)!!.dataProviderId)
)
// EdpSimulator hasn't had an implementation for impression.
assertThat(!result.impression.hasDeterministicCount()).isTrue()
assertThat(result.impression.noiseMechanism).isEqualTo(expectedDirectNoiseMechanism)
}
logger.info("Impression result is equal to the expected result")
}
Expand All @@ -311,6 +346,9 @@ class MeasurementConsumerSimulator(
// EdpSimulator sets it to this value.
log2(externalDataProviderId.toDouble()).toLong()
)
// EdpSimulator hasn't had an implementation for watch duration.
assertThat(!result.watchDuration.hasDeterministicSum()).isTrue()
assertThat(result.watchDuration.noiseMechanism).isEqualTo(expectedDirectNoiseMechanism)
}
logger.info("Duration result is equal to the expected result")
}
Expand Down Expand Up @@ -540,6 +578,21 @@ class MeasurementConsumerSimulator(
}
}

private fun newReachMeasurementSpec(
serializedMeasurementPublicKey: ByteString,
nonceHashes: List<ByteString>
): MeasurementSpec {
return measurementSpec {
measurementPublicKey = serializedMeasurementPublicKey
reach = MeasurementSpecKt.reach { privacyParams = outputDpParams }
vidSamplingInterval = vidSamplingInterval {
start = 0.0f
width = 1.0f
}
this.nonceHashes += nonceHashes
}
}

private fun newReachAndFrequencyMeasurementSpec(
serializedMeasurementPublicKey: ByteString,
nonceHashes: List<ByteString>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt
import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt
import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt
import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt
import org.wfanet.measurement.api.v2alpha.ProtocolConfig
import org.wfanet.measurement.common.grpc.buildMutualTlsChannel
import org.wfanet.measurement.common.grpc.withDefaultDeadline
import org.wfanet.measurement.common.parseTextProto
Expand All @@ -45,7 +46,7 @@ import org.wfanet.measurement.loadtest.measurementconsumer.MetadataBigQueryEvent

/**
* Test for correctness of an existing CMMS on Kubernetes where the EDP simulators use
* [BigQueryEventQuery].
* [BigQueryEventQuery]. The computation composition is using ACDP by assumption.
*
* This currently assumes that the CMMS instance is using the certificates and keys from this Bazel
* workspace.
Expand Down Expand Up @@ -111,7 +112,8 @@ class BigQueryCorrectnessTest : AbstractCorrectnessTest(measurementSystem) {
CertificatesGrpcKt.CertificatesCoroutineStub(publicApiChannel),
RESULT_POLLING_DELAY,
MEASUREMENT_CONSUMER_SIGNING_CERTS.trustedCertificates,
eventQuery
eventQuery,
ProtocolConfig.NoiseMechanism.CONTINUOUS_GAUSSIAN
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt
import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequestKt
import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt
import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt
import org.wfanet.measurement.api.v2alpha.ProtocolConfig
import org.wfanet.measurement.api.v2alpha.listEventGroupsRequest
import org.wfanet.measurement.api.withAuthenticationKey
import org.wfanet.measurement.common.crypto.jceProvider
Expand Down Expand Up @@ -265,6 +266,7 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) {
SyntheticGenerationSpecs.POPULATION_SPEC,
MC_ENCRYPTION_PRIVATE_KEY
),
ProtocolConfig.NoiseMechanism.CONTINUOUS_GAUSSIAN
)
.also {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt
import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt
import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt
import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt
import org.wfanet.measurement.api.v2alpha.ProtocolConfig
import org.wfanet.measurement.common.grpc.buildMutualTlsChannel
import org.wfanet.measurement.common.grpc.withDefaultDeadline
import org.wfanet.measurement.common.parseTextProto
Expand All @@ -43,7 +44,8 @@ import org.wfanet.measurement.loadtest.measurementconsumer.MetadataSyntheticGene

/**
* Test for correctness of an existing CMMS on Kubernetes where the EDP simulators use
* [SyntheticGeneratorEventQuery] with [SyntheticGenerationSpecs.POPULATION_SPEC].
* [SyntheticGeneratorEventQuery] with [SyntheticGenerationSpecs.POPULATION_SPEC]. The computation
* composition is using ACDP by assumption.
*
* This currently assumes that the CMMS instance is using the certificates and keys from this Bazel
* workspace.
Expand Down Expand Up @@ -105,6 +107,7 @@ class SyntheticGeneratorCorrectnessTest : AbstractCorrectnessTest(measurementSys
RESULT_POLLING_DELAY,
MEASUREMENT_CONSUMER_SIGNING_CERTS.trustedCertificates,
eventQuery,
ProtocolConfig.NoiseMechanism.CONTINUOUS_GAUSSIAN
)
}

Expand Down

0 comments on commit 247d18b

Please sign in to comment.