diff --git a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/BUILD.bazel index c3139be5e09..7b99b0129e4 100644 --- a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/BUILD.bazel @@ -24,10 +24,25 @@ java_binary( runtime_deps = [":encryption_public_keys"], ) +kt_jvm_library( + name = "create_measurement_flags", + srcs = [ + "CreateMeasurementFlags.kt", + ], + deps = [ + "//src/main/proto/wfa/measurement/api/v2alpha:date_interval_kt_jvm_proto", + "//src/main/proto/wfa/measurement/api/v2alpha:measurement_spec_kt_jvm_proto", + "@wfa_common_jvm//imports/java/picocli", + ], +) + kt_jvm_library( name = "measurement_system", - srcs = ["MeasurementSystem.kt"], + srcs = [ + "MeasurementSystem.kt", + ], deps = [ + ":create_measurement_flags", "//src/main/kotlin/org/wfanet/measurement/api:account_constants", "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", @@ -82,8 +97,11 @@ java_binary( kt_jvm_library( name = "benchmark_library", - srcs = ["Benchmark.kt"], + srcs = [ + "Benchmark.kt", + ], deps = [ + ":create_measurement_flags", "//src/main/kotlin/org/wfanet/measurement/api:api_key_constants", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:flags", @@ -95,6 +113,7 @@ kt_jvm_library( "//src/main/proto/wfa/measurement/api/v2alpha:measurement_consumers_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/api/v2alpha:measurement_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:measurements_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/api/v2alpha:model_line_kt_jvm_proto", "@wfa_common_jvm//imports/java/picocli", "@wfa_common_jvm//imports/kotlin/com/google/protobuf/kotlin", "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", diff --git a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/Benchmark.kt b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/Benchmark.kt index 6878809a101..735eab1d32d 100644 --- a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/Benchmark.kt +++ b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/Benchmark.kt @@ -69,12 +69,6 @@ import org.wfanet.measurement.api.v2alpha.Measurement import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub import org.wfanet.measurement.api.v2alpha.MeasurementKt.DataProviderEntryKt.value as dataProviderEntryValue import org.wfanet.measurement.api.v2alpha.MeasurementKt.dataProviderEntry -import org.wfanet.measurement.api.v2alpha.MeasurementSpec.Duration -import org.wfanet.measurement.api.v2alpha.MeasurementSpec.Impression -import org.wfanet.measurement.api.v2alpha.MeasurementSpec.ReachAndFrequency -import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.duration -import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.impression -import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reachAndFrequency import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.vidSamplingInterval import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt @@ -82,7 +76,6 @@ import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.EventGroupEntryKt.va import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.eventFilter import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.eventGroupEntry import org.wfanet.measurement.api.v2alpha.createMeasurementRequest -import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams import org.wfanet.measurement.api.v2alpha.getDataProviderRequest import org.wfanet.measurement.api.v2alpha.getMeasurementConsumerRequest import org.wfanet.measurement.api.v2alpha.getMeasurementRequest @@ -132,12 +125,6 @@ private class ApiFlags { } class BaseFlags { - @CommandLine.Option( - names = ["--measurement-consumer"], - description = ["API resource name of the MeasurementConsumer"], - required = true - ) - lateinit var measurementConsumer: String @CommandLine.Option( names = ["--encryption-private-key-file"], @@ -148,29 +135,6 @@ class BaseFlags { val privateKeyHandle: PrivateKeyHandle by lazy { loadPrivateKey(encryptionPrivateKeyFile) } - @CommandLine.Option( - names = ["--private-key-der-file"], - description = ["Private key for MeasurementConsumer"], - required = true - ) - lateinit var privateKeyDerFile: File - - @set:CommandLine.Option( - names = ["--vid-sampling-start"], - description = ["Start point of vid sampling interval for first VID bucket"], - required = true, - ) - var vidSamplingStart by Delegates.notNull() - private set - - @set:CommandLine.Option( - names = ["--vid-sampling-width"], - description = ["Width of vid sampling interval"], - required = true, - ) - var vidSamplingWidth by Delegates.notNull() - private set - @set:CommandLine.Option( names = ["--vid-bucket-count"], description = ["Number of VID buckets to sample from"], @@ -202,206 +166,57 @@ class BaseFlags { ) var timeout by Delegates.notNull() private set - - @CommandLine.ArgGroup( - exclusive = true, - multiplicity = "1", - heading = "Specify one of the measurement types with its params\n" - ) - lateinit var measurementTypeParams: MeasurementTypeParams - - @CommandLine.ArgGroup(exclusive = false, multiplicity = "1..*", heading = "Add DataProviders\n") - lateinit var dataProviderInputs: List -} - -class ReachAndFrequencyParams { - @CommandLine.Option( - names = ["--reach-and-frequency"], - description = ["Measurement Type of ReachAndFrequency"], - required = true, - ) - var selected = false - private set - - @set:CommandLine.Option( - names = ["--reach-privacy-epsilon"], - description = ["Epsilon value of reach privacy params"], - required = true, - ) - var reachPrivacyEpsilon by Delegates.notNull() - private set - - @set:CommandLine.Option( - names = ["--reach-privacy-delta"], - description = ["Delta value of reach privacy params"], - required = true, - ) - var reachPrivacyDelta by Delegates.notNull() - private set - - @set:CommandLine.Option( - names = ["--frequency-privacy-epsilon"], - description = ["Epsilon value of frequency privacy params"], - required = true, - ) - var frequencyPrivacyEpsilon by Delegates.notNull() - private set - - @set:CommandLine.Option( - names = ["--frequency-privacy-delta"], - description = ["Epsilon value of frequency privacy params"], - required = true, - ) - var frequencyPrivacyDelta by Delegates.notNull() - private set - - @set:CommandLine.Option( - names = ["--max-frequency-for-reach"], - description = ["Maximum frequency per user when estimating reach"], - required = false, - defaultValue = "10", - ) - var maximumFrequencyPerUser by Delegates.notNull() - private set -} - -class ImpressionParams { - @CommandLine.Option( - names = ["--impression"], - description = ["Measurement Type of Impression"], - required = true, - ) - var selected = false - private set - - @set:CommandLine.Option( - names = ["--impression-privacy-epsilon"], - description = ["Epsilon value of impression privacy params"], - required = true, - ) - var privacyEpsilon by Delegates.notNull() - private set - - @set:CommandLine.Option( - names = ["--impression-privacy-delta"], - description = ["Epsilon value of impression privacy params"], - required = true, - ) - var privacyDelta by Delegates.notNull() - private set - - @set:CommandLine.Option( - names = ["--max-frequency"], - description = ["Maximum frequency per user"], - required = true, - ) - var maximumFrequencyPerUser by Delegates.notNull() - private set -} - -class DurationParams { - @CommandLine.Option( - names = ["--duration"], - description = ["Measurement Type of Duration"], - required = true, - ) - var selected = false - private set - - @set:CommandLine.Option( - names = ["--duration-privacy-epsilon"], - description = ["Epsilon value of duration privacy params"], - required = true, - ) - var privacyEpsilon by Delegates.notNull() - private set - - @set:CommandLine.Option( - names = ["--duration-privacy-delta"], - description = ["Epsilon value of duration privacy params"], - required = true, - ) - var privacyDelta by Delegates.notNull() - private set - - @set:CommandLine.Option( - names = ["--max-duration"], - description = ["Maximum watch duration per user"], - required = true, - ) - var maximumWatchDurationPerUser by Delegates.notNull() - private set -} - -class DataProviderInput { - @CommandLine.Option( - names = ["--data-provider"], - description = ["API resource name of the DataProvider"], - required = true, - ) - lateinit var name: String - private set - - @CommandLine.ArgGroup( - exclusive = false, - multiplicity = "1..*", - heading = "Add EventGroups for a DataProvider\n" - ) - lateinit var eventGroupInputs: List - private set -} - -class EventGroupInput { - @CommandLine.Option( - names = ["--event-group"], - description = ["API resource name of the EventGroup"], - required = true, - ) - lateinit var name: String - private set - - @CommandLine.Option( - names = ["--event-filter"], - description = ["Raw CEL expression of EventFilter"], - required = false, - defaultValue = "" - ) - lateinit var eventFilter: String - private set - - @CommandLine.Option( - names = ["--event-start-time"], - description = ["Start time of Event range in ISO 8601 format of UTC"], - required = true, - ) - lateinit var eventStartTime: Instant - private set - - @CommandLine.Option( - names = ["--event-end-time"], - description = ["End time of Event range in ISO 8601 format of UTC"], - required = true, - ) - lateinit var eventEndTime: Instant - private set } -class MeasurementTypeParams { +private fun getPopulationDataProviderEntry( + dataProviderStub: DataProvidersCoroutineStub, + dataProviderInput: + CreateMeasurementFlags.MeasurementParams.PopulationMeasurementParams.PopulationDataProviderInput, + measurementParams: CreateMeasurementFlags.MeasurementParams.PopulationMeasurementParams, + measurementConsumerSigningKey: SigningKeyHandle, + measurementEncryptionPublicKey: ByteString, + secureRandom: SecureRandom, + apiAuthenticationKey: String +): Measurement.DataProviderEntry { + return dataProviderEntry { + val requisitionSpec = requisitionSpec { + population = + RequisitionSpecKt.population { + interval = interval { + startTime = measurementParams.populationInputs.startTime.toProtoTime() + endTime = measurementParams.populationInputs.endTime.toProtoTime() + } + if (measurementParams.populationInputs.filter.isNotEmpty()) + filter = eventFilter { expression = measurementParams.populationInputs.filter } + } + this.measurementPublicKey = measurementEncryptionPublicKey + nonce = secureRandom.nextLong() + } - @CommandLine.ArgGroup( - exclusive = false, - heading = "Measurement type ReachAndFrequency and params\n" - ) - var reachAndFrequency = ReachAndFrequencyParams() - @CommandLine.ArgGroup(exclusive = false, heading = "Measurement type Impression and params\n") - var impression = ImpressionParams() - @CommandLine.ArgGroup(exclusive = false, heading = "Measurement type Duration and params\n") - var duration = DurationParams() + key = dataProviderInput.name + val dataProvider = + runBlocking(Dispatchers.IO) { + dataProviderStub + .withAuthenticationKey(apiAuthenticationKey) + .getDataProvider(getDataProviderRequest { name = dataProviderInput.name }) + } + value = dataProviderEntryValue { + dataProviderCertificate = dataProvider.certificate + dataProviderPublicKey = dataProvider.publicKey + encryptedRequisitionSpec = + encryptRequisitionSpec( + signRequisitionSpec(requisitionSpec, measurementConsumerSigningKey), + EncryptionPublicKey.parseFrom(dataProvider.publicKey.data) + ) + nonceHash = Hashing.hashSha256(requisitionSpec.nonce) + } + } } -private fun getDataProviderEntry( +private fun getEventDataProviderEntry( dataProviderStub: DataProvidersCoroutineStub, - dataProviderInput: DataProviderInput, + dataProviderInput: + CreateMeasurementFlags.MeasurementParams.EventMeasurementParams.EventDataProviderInput, measurementConsumerSigningKey: SigningKeyHandle, measurementEncryptionPublicKey: ByteString, secureRandom: SecureRandom, @@ -450,40 +265,6 @@ private fun getDataProviderEntry( } } -private fun getReachAndFrequency(measurementTypeParams: MeasurementTypeParams): ReachAndFrequency { - return reachAndFrequency { - reachPrivacyParams = differentialPrivacyParams { - epsilon = measurementTypeParams.reachAndFrequency.reachPrivacyEpsilon - delta = measurementTypeParams.reachAndFrequency.reachPrivacyDelta - } - frequencyPrivacyParams = differentialPrivacyParams { - epsilon = measurementTypeParams.reachAndFrequency.frequencyPrivacyEpsilon - delta = measurementTypeParams.reachAndFrequency.frequencyPrivacyDelta - } - maximumFrequencyPerUser = measurementTypeParams.reachAndFrequency.maximumFrequencyPerUser - } -} - -private fun getImpression(measurementTypeParams: MeasurementTypeParams): Impression { - return impression { - privacyParams = differentialPrivacyParams { - epsilon = measurementTypeParams.impression.privacyEpsilon - delta = measurementTypeParams.impression.privacyDelta - } - maximumFrequencyPerUser = measurementTypeParams.impression.maximumFrequencyPerUser - } -} - -private fun getDuration(measurementTypeParams: MeasurementTypeParams): Duration { - return duration { - privacyParams = differentialPrivacyParams { - epsilon = measurementTypeParams.duration.privacyEpsilon - delta = measurementTypeParams.duration.privacyDelta - } - maximumWatchDurationPerUser = measurementTypeParams.duration.maximumWatchDurationPerUser - } -} - private fun getMeasurementResult( resultPair: Measurement.ResultPair, privateKeyHandle: PrivateKeyHandle @@ -494,6 +275,7 @@ private fun getMeasurementResult( class Benchmark( val flags: BaseFlags, + private val createMeasurementFlags: CreateMeasurementFlags, val channel: ManagedChannel, val apiAuthenticationKey: String, val clock: Clock @@ -501,6 +283,9 @@ class Benchmark( private val secureRandom = SecureRandom.getInstance("SHA1PRNG") + private val eventMeasurementParams = + createMeasurementFlags.measurementParams.eventMeasurementParams + /** * The following data structure is used to track the status of each request and to store the * results that were received. @@ -551,13 +336,13 @@ class Benchmark( measurementConsumerStub .withAuthenticationKey(apiAuthenticationKey) .getMeasurementConsumer( - getMeasurementConsumerRequest { name = flags.measurementConsumer } + getMeasurementConsumerRequest { name = createMeasurementFlags.measurementConsumer } ) } val measurementConsumerCertificate = readCertificate(measurementConsumer.certificateDer) val measurementConsumerPrivateKey = readPrivateKey( - flags.privateKeyDerFile.readByteString(), + createMeasurementFlags.privateKeyDerFile.readByteString(), measurementConsumerCertificate.publicKey.algorithm ) val measurementConsumerSigningKey = @@ -572,42 +357,77 @@ class Benchmark( for (replica in 1..flags.repetitionCount) { val referenceId = "$referenceIdBase-$replica" - val vidSamplingStartForMeasurement = - flags.vidSamplingStart + - kotlin.random.Random.nextInt(0, flags.vidBucketCount).toFloat() * flags.vidSamplingWidth - - val measurement = measurement { - this.measurementConsumerCertificate = measurementConsumer.certificate - dataProviders += - flags.dataProviderInputs.map { - getDataProviderEntry( - dataProviderStub, - it, - measurementConsumerSigningKey, - measurementEncryptionPublicKey, - secureRandom, - apiAuthenticationKey - ) + + val measurement = + if (createMeasurementFlags.measurementParams.populationMeasurementParams.selected) { + measurement { + this.measurementConsumerCertificate = measurementConsumer.certificate + dataProviders += + getPopulationDataProviderEntry( + dataProviderStub, + createMeasurementFlags.measurementParams.populationMeasurementParams + .populationDataProviderInput, + createMeasurementFlags.measurementParams.populationMeasurementParams, + measurementConsumerSigningKey, + measurementEncryptionPublicKey, + secureRandom, + apiAuthenticationKey + ) + + val unsignedMeasurementSpec = measurementSpec { + measurementPublicKey = measurementEncryptionPublicKey + nonceHashes += this@measurement.dataProviders.map { it.value.nonceHash } + population = createMeasurementFlags.getPopulation() + if (createMeasurementFlags.modelLine.isNotEmpty()) + modelLine = createMeasurementFlags.modelLine + } + + this.measurementSpec = + signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) + measurementReferenceId = referenceId } - val unsignedMeasurementSpec = measurementSpec { - measurementPublicKey = measurementEncryptionPublicKey - nonceHashes += this@measurement.dataProviders.map { it.value.nonceHash } - vidSamplingInterval = vidSamplingInterval { - start = vidSamplingStartForMeasurement - width = flags.vidSamplingWidth + } else { + val vidSamplingStartForMeasurement = + eventMeasurementParams.vidSamplingStart + + kotlin.random.Random.nextInt(0, flags.vidBucketCount).toFloat() * + eventMeasurementParams.vidSamplingWidth + measurement { + this.measurementConsumerCertificate = measurementConsumer.certificate + dataProviders += + eventMeasurementParams.eventDataProviderInputs.map { + getEventDataProviderEntry( + dataProviderStub, + it, + measurementConsumerSigningKey, + measurementEncryptionPublicKey, + secureRandom, + apiAuthenticationKey + ) + } + val unsignedMeasurementSpec = measurementSpec { + measurementPublicKey = measurementEncryptionPublicKey + nonceHashes += this@measurement.dataProviders.map { it.value.nonceHash } + vidSamplingInterval = vidSamplingInterval { + start = vidSamplingStartForMeasurement + width = eventMeasurementParams.vidSamplingWidth + } + if (eventMeasurementParams.eventMeasurementTypeParams.reachAndFrequency.selected) { + reachAndFrequency = createMeasurementFlags.getReachAndFrequency() + } else if (eventMeasurementParams.eventMeasurementTypeParams.impression.selected) { + impression = createMeasurementFlags.getImpression() + } else if (eventMeasurementParams.eventMeasurementTypeParams.duration.selected) { + duration = createMeasurementFlags.getDuration() + } + if (createMeasurementFlags.modelLine.isNotEmpty()) + modelLine = createMeasurementFlags.modelLine + } + + this.measurementSpec = + signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) + measurementReferenceId = referenceId } - if (flags.measurementTypeParams.reachAndFrequency.selected) { - reachAndFrequency = getReachAndFrequency(flags.measurementTypeParams) - } else if (flags.measurementTypeParams.impression.selected) { - impression = getImpression(flags.measurementTypeParams) - } else duration = getDuration(flags.measurementTypeParams) } - this.measurementSpec = - signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) - measurementReferenceId = referenceId - } - val task = MeasurementTask(replica, Instant.now(clock)) task.referenceId = referenceId @@ -664,7 +484,7 @@ class Benchmark( Measurement.State.SUCCEEDED -> { val result = getMeasurementResult(measurement.resultsList[0], flags.privateKeyHandle) task.result = result - // println ("Got result for task $iTask\n$measurement\n-----\n$result") + println("Got result for task $iTask\n$measurement\n-----\n$result") task.status = "success" } Measurement.State.FAILED -> { @@ -690,11 +510,13 @@ class Benchmark( private fun generateOutput(firstInstant: Instant) { File(flags.outputFile).printWriter().use { out -> out.print("replica,startTime,ackTime,computeTime,endTime,status,msg,") - if (flags.measurementTypeParams.reachAndFrequency.selected) { + if (createMeasurementFlags.measurementParams.populationMeasurementParams.selected) { + out.println("population") + } else if (eventMeasurementParams.eventMeasurementTypeParams.reachAndFrequency.selected) { out.println("reach,freq1,freq2,freq3,freq4,freq5") - } else if (flags.measurementTypeParams.impression.selected) { + } else if (eventMeasurementParams.eventMeasurementTypeParams.impression.selected) { out.println("impressions") - } else { + } else if (eventMeasurementParams.eventMeasurementTypeParams.duration.selected) { out.println("duration") } for (task in completedTasks) { @@ -710,7 +532,9 @@ class Benchmark( } out.print(",${task.elapsedTimeMillis / 1000.0},") out.print("${task.status},${task.errorMessage},") - if (flags.measurementTypeParams.reachAndFrequency.selected) { + if (createMeasurementFlags.measurementParams.populationMeasurementParams.selected) { + out.println("${task.result.population.value}") + } else if (eventMeasurementParams.eventMeasurementTypeParams.reachAndFrequency.selected) { var reach = 0L if (task.status == "success" && task.result.hasReach()) { reach = task.result.reach.value @@ -728,9 +552,9 @@ class Benchmark( out.print(",${frequencies[i - 1]}") } out.println() - } else if (flags.measurementTypeParams.impression.selected) { + } else if (eventMeasurementParams.eventMeasurementTypeParams.impression.selected) { out.println("${task.result.impression.value}") - } else { + } else if (eventMeasurementParams.eventMeasurementTypeParams.duration.selected) { out.println("${task.result.watchDuration.value.seconds}") } } @@ -771,6 +595,7 @@ class BenchmarkReport(val clock: Clock = Clock.systemUTC()) : Runnable { @CommandLine.Mixin private lateinit var tlsFlags: TlsFlags @CommandLine.Mixin private lateinit var apiFlags: ApiFlags @CommandLine.Mixin private lateinit var baseFlags: BaseFlags + @CommandLine.Mixin private lateinit var createMeasurementFlags: CreateMeasurementFlags @CommandLine.Option( names = ["--api-key"], @@ -791,7 +616,8 @@ class BenchmarkReport(val clock: Clock = Clock.systemUTC()) : Runnable { .withShutdownTimeout(JavaDuration.ofSeconds(1)) } override fun run() { - val benchmark = Benchmark(baseFlags, channel, apiAuthenticationKey, clock) + val benchmark = + Benchmark(baseFlags, createMeasurementFlags, channel, apiAuthenticationKey, clock) benchmark.generateBenchmarkReport() } } diff --git a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/CreateMeasurementFlags.kt b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/CreateMeasurementFlags.kt new file mode 100644 index 00000000000..f34de69c74b --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/CreateMeasurementFlags.kt @@ -0,0 +1,401 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.api.v2alpha.tools + +import java.io.File +import java.time.Instant +import kotlin.properties.Delegates +import org.wfanet.measurement.api.v2alpha.MeasurementSpec +import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt +import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams +import picocli.CommandLine.ArgGroup +import picocli.CommandLine.Option + +class CreateMeasurementFlags { + + @Option( + names = ["--measurement-consumer"], + description = ["API resource name of the MeasurementConsumer"], + required = true + ) + lateinit var measurementConsumer: String + + @Option( + names = ["--request-id"], + description = ["ID of API request for idempotency"], + required = false, + defaultValue = "", + ) + lateinit var requestId: String + + @Option( + names = ["--private-key-der-file"], + description = ["Private key for MeasurementConsumer"], + required = true + ) + lateinit var privateKeyDerFile: File + + @Option( + names = ["--measurement-ref-id"], + description = ["Measurement reference id"], + required = false, + defaultValue = "" + ) + lateinit var measurementReferenceId: String + + @ArgGroup( + exclusive = true, + multiplicity = "1", + heading = "Specify either Event or Population Measurement with its params\n" + ) + lateinit var measurementParams: MeasurementParams + + @Option( + names = ["--model-line"], + description = ["API resource name of the ModelLine"], + required = false, + defaultValue = "", + ) + lateinit var modelLine: String + + class MeasurementParams { + @ArgGroup(exclusive = false, multiplicity = "1", heading = "Event Measurement and params\n") + var eventMeasurementParams = EventMeasurementParams() + + @ArgGroup( + exclusive = false, + multiplicity = "1", + heading = "Population Measurement and params\n" + ) + var populationMeasurementParams = PopulationMeasurementParams() + + class EventMeasurementParams { + class EventDataProviderInput { + @Option( + names = ["--event-data-provider"], + description = ["API resource name of the Event Data Provider"], + required = true, + ) + lateinit var name: String + private set + + @ArgGroup( + exclusive = false, + multiplicity = "1..*", + heading = "Add EventGroups for an Event Data Provider\n" + ) + lateinit var eventGroupInputs: List + private set + } + + class EventGroupInput { + @Option( + names = ["--event-group"], + description = ["API resource name of the EventGroup"], + required = true, + ) + lateinit var name: String + private set + + @Option( + names = ["--event-filter"], + description = ["Raw CEL expression of EventFilter"], + required = false, + defaultValue = "" + ) + lateinit var eventFilter: String + private set + + @Option( + names = ["--event-start-time"], + description = ["Start time of Event range in ISO 8601 format of UTC"], + required = true, + ) + lateinit var eventStartTime: Instant + private set + + @Option( + names = ["--event-end-time"], + description = ["End time of Event range in ISO 8601 format of UTC"], + required = true, + ) + lateinit var eventEndTime: Instant + private set + } + + @ArgGroup(exclusive = false, multiplicity = "1..*", heading = "Add Event Data Providers\n") + lateinit var eventDataProviderInputs: List + private set + + @set:Option( + names = ["--vid-sampling-start"], + description = ["Start point of vid sampling interval"], + required = true, + ) + var vidSamplingStart by Delegates.notNull() + private set + + @set:Option( + names = ["--vid-sampling-width"], + description = ["Width of vid sampling interval"], + required = true, + ) + var vidSamplingWidth by Delegates.notNull() + private set + + class EventMeasurementTypeParams { + class ReachAndFrequencyParams { + @Option( + names = ["--reach-and-frequency"], + description = ["Measurement Type of ReachAndFrequency"], + required = true, + ) + var selected = false + private set + + @set:Option( + names = ["--reach-privacy-epsilon"], + description = ["Epsilon value of reach privacy params"], + required = true, + ) + var reachPrivacyEpsilon by Delegates.notNull() + private set + + @set:Option( + names = ["--reach-privacy-delta"], + description = ["Delta value of reach privacy params"], + required = true, + ) + var reachPrivacyDelta by Delegates.notNull() + private set + + @set:Option( + names = ["--frequency-privacy-epsilon"], + description = ["Epsilon value of frequency privacy params"], + required = true, + ) + var frequencyPrivacyEpsilon by Delegates.notNull() + private set + + @set:Option( + names = ["--frequency-privacy-delta"], + description = ["Epsilon value of frequency privacy params"], + required = true, + ) + var frequencyPrivacyDelta by Delegates.notNull() + private set + + @set:Option( + names = ["--reach-max-frequency"], + description = ["Maximum frequency per user"], + required = false, + defaultValue = "10", + ) + var maximumFrequencyPerUser by Delegates.notNull() + private set + } + + class ImpressionParams { + @Option( + names = ["--impression"], + description = ["Measurement Type of Impression"], + required = true, + ) + var selected = false + private set + + @set:Option( + names = ["--impression-privacy-epsilon"], + description = ["Epsilon value of impression privacy params"], + required = true, + ) + var privacyEpsilon by Delegates.notNull() + private set + + @set:Option( + names = ["--impression-privacy-delta"], + description = ["Epsilon value of impression privacy params"], + required = true, + ) + var privacyDelta by Delegates.notNull() + private set + + @set:Option( + names = ["--impression-max-frequency"], + description = ["Maximum frequency per user"], + required = true, + ) + var maximumFrequencyPerUser by Delegates.notNull() + private set + } + + class DurationParams { + @Option( + names = ["--duration"], + description = ["Measurement Type of Duration"], + required = true, + ) + var selected = false + private set + + @set:Option( + names = ["--duration-privacy-epsilon"], + description = ["Epsilon value of duration privacy params"], + required = true, + ) + var privacyEpsilon by Delegates.notNull() + private set + + @set:Option( + names = ["--duration-privacy-delta"], + description = ["Epsilon value of duration privacy params"], + required = true, + ) + var privacyDelta by Delegates.notNull() + private set + + @set:Option( + names = ["--max-duration"], + description = ["Maximum watch duration per user"], + required = true, + ) + var maximumWatchDurationPerUser by Delegates.notNull() + private set + } + + @ArgGroup(exclusive = false, heading = "Measurement type ReachAndFrequency and params\n") + var reachAndFrequency = ReachAndFrequencyParams() + @ArgGroup(exclusive = false, heading = "Measurement type Impression and params\n") + var impression = ImpressionParams() + @ArgGroup(exclusive = false, heading = "Measurement type Duration and params\n") + var duration = DurationParams() + } + + @ArgGroup(exclusive = true, multiplicity = "1", heading = "Event Measurement and params\n") + var eventMeasurementTypeParams = EventMeasurementTypeParams() + } + class PopulationMeasurementParams { + class PopulationInput { + @Option( + names = ["--population-filter"], + description = ["Raw CEL expression of Population Filter"], + required = false, + defaultValue = "" + ) + lateinit var filter: String + private set + + @Option( + names = ["--population-start-time"], + description = ["Start time of Population range in ISO 8601 format of UTC"], + required = true, + ) + lateinit var startTime: Instant + private set + + @Option( + names = ["--population-end-time"], + description = ["End time of Population range in ISO 8601 format of UTC"], + required = true, + ) + lateinit var endTime: Instant + private set + } + + class PopulationDataProviderInput { + @Option( + names = ["--population-data-provider"], + description = ["API resource name of the Population Data Provider"], + required = true, + ) + lateinit var name: String + private set + } + + @ArgGroup(exclusive = false, heading = "Population Params\n") + lateinit var populationInputs: PopulationInput + private set + @ArgGroup(exclusive = false, heading = "Set Population Data Provider\n") + lateinit var populationDataProviderInput: PopulationDataProviderInput + + @Option( + names = ["--population"], + description = ["Population Measurement"], + required = true, + ) + var selected = false + private set + } + } + + fun getReachAndFrequency(): MeasurementSpec.ReachAndFrequency { + return MeasurementSpecKt.reachAndFrequency { + reachPrivacyParams = differentialPrivacyParams { + epsilon = + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.reachAndFrequency + .reachPrivacyEpsilon + delta = + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.reachAndFrequency + .reachPrivacyDelta + } + frequencyPrivacyParams = differentialPrivacyParams { + epsilon = + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.reachAndFrequency + .frequencyPrivacyEpsilon + delta = + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.reachAndFrequency + .frequencyPrivacyDelta + } + maximumFrequencyPerUser = + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.reachAndFrequency + .maximumFrequencyPerUser + } + } + + fun getImpression(): MeasurementSpec.Impression { + return MeasurementSpecKt.impression { + privacyParams = differentialPrivacyParams { + epsilon = + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.impression + .privacyEpsilon + delta = + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.impression + .privacyDelta + } + maximumFrequencyPerUser = + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.impression + .maximumFrequencyPerUser + } + } + + fun getDuration(): MeasurementSpec.Duration { + return MeasurementSpecKt.duration { + privacyParams = differentialPrivacyParams { + epsilon = + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.duration + .privacyEpsilon + delta = + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.duration.privacyDelta + } + maximumWatchDurationPerUser = + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.duration + .maximumWatchDurationPerUser + } + } + + fun getPopulation(): MeasurementSpec.Population { + return MeasurementSpecKt.population {} + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystem.kt b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystem.kt index 8ae04a08e58..ce0bebc8dc3 100644 --- a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystem.kt +++ b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystem.kt @@ -31,7 +31,6 @@ import java.time.Clock import java.time.Duration as systemDuration import java.time.Instant import java.time.LocalDate -import kotlin.properties.Delegates import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.runBlocking @@ -53,12 +52,6 @@ import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub import org.wfanet.measurement.api.v2alpha.MeasurementKt.DataProviderEntryKt as DataProviderEntries import org.wfanet.measurement.api.v2alpha.MeasurementKt.dataProviderEntry -import org.wfanet.measurement.api.v2alpha.MeasurementSpec.Duration -import org.wfanet.measurement.api.v2alpha.MeasurementSpec.Impression -import org.wfanet.measurement.api.v2alpha.MeasurementSpec.ReachAndFrequency -import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.duration -import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.impression -import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reachAndFrequency import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.vidSamplingInterval import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub import org.wfanet.measurement.api.v2alpha.ModelLine @@ -99,7 +92,6 @@ import org.wfanet.measurement.api.v2alpha.dateInterval import org.wfanet.measurement.api.v2alpha.deleteModelOutageRequest import org.wfanet.measurement.api.v2alpha.deleteModelRolloutRequest import org.wfanet.measurement.api.v2alpha.deleteModelShardRequest -import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams import org.wfanet.measurement.api.v2alpha.getCertificateRequest import org.wfanet.measurement.api.v2alpha.getDataProviderRequest import org.wfanet.measurement.api.v2alpha.getMeasurementConsumerRequest @@ -592,288 +584,66 @@ private class Measurements { class CreateMeasurement : Runnable { @ParentCommand private lateinit var parentCommand: Measurements - @Option( - names = ["--measurement-consumer"], - description = ["API resource name of the MeasurementConsumer"], - required = true - ) - private lateinit var measurementConsumer: String - - @Option( - names = ["--request-id"], - description = ["ID of API request for idempotency"], - required = false, - defaultValue = "", - ) - private lateinit var requestId: String - - @Option( - names = ["--private-key-der-file"], - description = ["Private key for MeasurementConsumer"], - required = true - ) - private lateinit var privateKeyDerFile: File - - @Option( - names = ["--measurement-ref-id"], - description = ["Measurement reference id"], - required = false, - defaultValue = "" - ) - private lateinit var measurementReferenceId: String - - @set:Option( - names = ["--vid-sampling-start"], - description = ["Start point of vid sampling interval"], - required = true, - ) - var vidSamplingStart by Delegates.notNull() - private set - - @set:Option( - names = ["--vid-sampling-width"], - description = ["Width of vid sampling interval"], - required = true, - ) - var vidSamplingWidth by Delegates.notNull() - private set - - @ArgGroup( - exclusive = true, - multiplicity = "1", - heading = "Specify one of the measurement types with its params\n" - ) - lateinit var measurementTypeParams: MeasurementTypeParams - - class MeasurementTypeParams { - class ReachAndFrequencyParams { - @Option( - names = ["--reach-and-frequency"], - description = ["Measurement Type of ReachAndFrequency"], - required = true, - ) - var selected = false - private set - - @set:Option( - names = ["--reach-privacy-epsilon"], - description = ["Epsilon value of reach privacy params"], - required = true, - ) - var reachPrivacyEpsilon by Delegates.notNull() - private set - - @set:Option( - names = ["--reach-privacy-delta"], - description = ["Delta value of reach privacy params"], - required = true, - ) - var reachPrivacyDelta by Delegates.notNull() - private set - - @set:Option( - names = ["--frequency-privacy-epsilon"], - description = ["Epsilon value of frequency privacy params"], - required = true, - ) - var frequencyPrivacyEpsilon by Delegates.notNull() - private set - - @set:Option( - names = ["--frequency-privacy-delta"], - description = ["Epsilon value of frequency privacy params"], - required = true, - ) - var frequencyPrivacyDelta by Delegates.notNull() - private set - - @set:Option( - names = ["--reach-max-frequency"], - description = ["Maximum frequency per user"], - required = false, - defaultValue = "10", - ) - var maximumFrequencyPerUser by Delegates.notNull() - private set - } - - class ImpressionParams { - @Option( - names = ["--impression"], - description = ["Measurement Type of Impression"], - required = true, - ) - var selected = false - private set - - @set:Option( - names = ["--impression-privacy-epsilon"], - description = ["Epsilon value of impression privacy params"], - required = true, - ) - var privacyEpsilon by Delegates.notNull() - private set + @ArgGroup(exclusive = false, multiplicity = "1", heading = "Create Measurement Flags\n") + lateinit var createMeasurementFlags: CreateMeasurementFlags - @set:Option( - names = ["--impression-privacy-delta"], - description = ["Epsilon value of impression privacy params"], - required = true, - ) - var privacyDelta by Delegates.notNull() - private set - - @set:Option( - names = ["--impression-max-frequency"], - description = ["Maximum frequency per user"], - required = true, - ) - var maximumFrequencyPerUser by Delegates.notNull() - private set - } - - class DurationParams { - @Option( - names = ["--duration"], - description = ["Measurement Type of Duration"], - required = true, - ) - var selected = false - private set - - @set:Option( - names = ["--duration-privacy-epsilon"], - description = ["Epsilon value of duration privacy params"], - required = true, - ) - var privacyEpsilon by Delegates.notNull() - private set - - @set:Option( - names = ["--duration-privacy-delta"], - description = ["Epsilon value of duration privacy params"], - required = true, - ) - var privacyDelta by Delegates.notNull() - private set - - @set:Option( - names = ["--max-duration"], - description = ["Maximum watch duration per user"], - required = true, - ) - var maximumWatchDurationPerUser by Delegates.notNull() - private set - } - - @ArgGroup(exclusive = false, heading = "Measurement type ReachAndFrequency and params\n") - var reachAndFrequency = ReachAndFrequencyParams() - @ArgGroup(exclusive = false, heading = "Measurement type Impression and params\n") - var impression = ImpressionParams() - @ArgGroup(exclusive = false, heading = "Measurement type Duration and params\n") - var duration = DurationParams() - } - - private fun getReachAndFrequency(): ReachAndFrequency { - return reachAndFrequency { - reachPrivacyParams = differentialPrivacyParams { - epsilon = measurementTypeParams.reachAndFrequency.reachPrivacyEpsilon - delta = measurementTypeParams.reachAndFrequency.reachPrivacyDelta - } - frequencyPrivacyParams = differentialPrivacyParams { - epsilon = measurementTypeParams.reachAndFrequency.frequencyPrivacyEpsilon - delta = measurementTypeParams.reachAndFrequency.frequencyPrivacyDelta - } - maximumFrequencyPerUser = measurementTypeParams.reachAndFrequency.maximumFrequencyPerUser - } - } + private val secureRandom = SecureRandom.getInstance("SHA1PRNG") - private fun getImpression(): Impression { - return impression { - privacyParams = differentialPrivacyParams { - epsilon = measurementTypeParams.impression.privacyEpsilon - delta = measurementTypeParams.impression.privacyDelta + private fun getPopulationDataProviderEntry( + populationDataProviderInput: + CreateMeasurementFlags.MeasurementParams.PopulationMeasurementParams.PopulationDataProviderInput, + populationMeasurementParams: + CreateMeasurementFlags.MeasurementParams.PopulationMeasurementParams, + measurementConsumerSigningKey: SigningKeyHandle, + measurementEncryptionPublicKey: ByteString + ): Measurement.DataProviderEntry { + return dataProviderEntry { + val requisitionSpec = requisitionSpec { + population = + RequisitionSpecKt.population { + interval = interval { + startTime = populationMeasurementParams.populationInputs.startTime.toProtoTime() + endTime = populationMeasurementParams.populationInputs.endTime.toProtoTime() + } + if (populationMeasurementParams.populationInputs.filter.isNotEmpty()) + filter = eventFilter { + expression = populationMeasurementParams.populationInputs.filter + } + } + this.measurementPublicKey = measurementEncryptionPublicKey + nonce = secureRandom.nextLong() } - maximumFrequencyPerUser = measurementTypeParams.impression.maximumFrequencyPerUser - } - } - private fun getDuration(): Duration { - return duration { - privacyParams = differentialPrivacyParams { - epsilon = measurementTypeParams.duration.privacyEpsilon - delta = measurementTypeParams.duration.privacyDelta - } - maximumWatchDurationPerUser = measurementTypeParams.duration.maximumWatchDurationPerUser + key = populationDataProviderInput.name + val dataProvider = + runBlocking(parentCommand.parentCommand.rpcDispatcher) { + parentCommand.dataProviderStub + .withAuthenticationKey(parentCommand.apiAuthenticationKey) + .getDataProvider(getDataProviderRequest { name = populationDataProviderInput.name }) + } + value = + DataProviderEntries.value { + dataProviderCertificate = dataProvider.certificate + dataProviderPublicKey = dataProvider.publicKey + encryptedRequisitionSpec = + encryptRequisitionSpec( + signRequisitionSpec(requisitionSpec, measurementConsumerSigningKey), + EncryptionPublicKey.parseFrom(dataProvider.publicKey.data) + ) + nonceHash = Hashing.hashSha256(requisitionSpec.nonce) + } } } - - @ArgGroup(exclusive = false, multiplicity = "1..*", heading = "Add DataProviders\n") - private lateinit var dataProviderInputs: List - - class DataProviderInput { - @Option( - names = ["--data-provider"], - description = ["API resource name of the DataProvider"], - required = true, - ) - lateinit var name: String - private set - - @ArgGroup( - exclusive = false, - multiplicity = "1..*", - heading = "Add EventGroups for a DataProvider\n" - ) - lateinit var eventGroupInputs: List - private set - } - - class EventGroupInput { - @Option( - names = ["--event-group"], - description = ["API resource name of the EventGroup"], - required = true, - ) - lateinit var name: String - private set - - @Option( - names = ["--event-filter"], - description = ["Raw CEL expression of EventFilter"], - required = false, - defaultValue = "" - ) - lateinit var eventFilter: String - private set - - @Option( - names = ["--event-start-time"], - description = ["Start time of Event range in ISO 8601 format of UTC"], - required = true, - ) - lateinit var eventStartTime: Instant - private set - - @Option( - names = ["--event-end-time"], - description = ["End time of Event range in ISO 8601 format of UTC"], - required = true, - ) - lateinit var eventEndTime: Instant - private set - } - - private val secureRandom = SecureRandom.getInstance("SHA1PRNG") - - private fun getDataProviderEntry( - dataProviderInput: DataProviderInput, + private fun getEventDataProviderEntry( + eventDataProviderInput: + CreateMeasurementFlags.MeasurementParams.EventMeasurementParams.EventDataProviderInput, measurementConsumerSigningKey: SigningKeyHandle, measurementEncryptionPublicKey: ByteString ): Measurement.DataProviderEntry { return dataProviderEntry { val requisitionSpec = requisitionSpec { val eventGroups = - dataProviderInput.eventGroupInputs.map { + eventDataProviderInput.eventGroupInputs.map { eventGroupEntry { key = it.name value = @@ -887,18 +657,18 @@ class CreateMeasurement : Runnable { } } } - this.eventGroups += eventGroups events = RequisitionSpecKt.events { this.eventGroups += eventGroups } + this.eventGroups += eventGroups this.measurementPublicKey = measurementEncryptionPublicKey nonce = secureRandom.nextLong() } - key = dataProviderInput.name + key = eventDataProviderInput.name val dataProvider = runBlocking(parentCommand.parentCommand.rpcDispatcher) { parentCommand.dataProviderStub .withAuthenticationKey(parentCommand.apiAuthenticationKey) - .getDataProvider(getDataProviderRequest { name = dataProviderInput.name }) + .getDataProvider(getDataProviderRequest { name = eventDataProviderInput.name }) } value = DataProviderEntries.value { @@ -915,16 +685,19 @@ class CreateMeasurement : Runnable { } override fun run() { + val measurementParams = createMeasurementFlags.measurementParams val measurementConsumer = runBlocking(parentCommand.parentCommand.rpcDispatcher) { parentCommand.measurementConsumerStub .withAuthenticationKey(parentCommand.apiAuthenticationKey) - .getMeasurementConsumer(getMeasurementConsumerRequest { name = measurementConsumer }) + .getMeasurementConsumer( + getMeasurementConsumerRequest { name = createMeasurementFlags.measurementConsumer } + ) } val measurementConsumerCertificate = readCertificate(measurementConsumer.certificateDer) val measurementConsumerPrivateKey = readPrivateKey( - privateKeyDerFile.readByteString(), + createMeasurementFlags.privateKeyDerFile.readByteString(), measurementConsumerCertificate.publicKey.algorithm ) val measurementConsumerSigningKey = @@ -934,28 +707,56 @@ class CreateMeasurement : Runnable { val measurement = measurement { this.measurementConsumerCertificate = measurementConsumer.certificate dataProviders += - dataProviderInputs.map { - getDataProviderEntry(it, measurementConsumerSigningKey, measurementEncryptionPublicKey) + if (measurementParams.populationMeasurementParams.selected) { + listOf( + getPopulationDataProviderEntry( + measurementParams.populationMeasurementParams.populationDataProviderInput, + measurementParams.populationMeasurementParams, + measurementConsumerSigningKey, + measurementEncryptionPublicKey + ) + ) + } else { + measurementParams.eventMeasurementParams.eventDataProviderInputs.map { + getEventDataProviderEntry( + it, + measurementConsumerSigningKey, + measurementEncryptionPublicKey + ) + } } val unsignedMeasurementSpec = measurementSpec { measurementPublicKey = measurementEncryptionPublicKey nonceHashes += this@measurement.dataProviders.map { it.value.nonceHash } - vidSamplingInterval = vidSamplingInterval { - start = vidSamplingStart - width = vidSamplingWidth - } - if (measurementTypeParams.reachAndFrequency.selected) { - reachAndFrequency = getReachAndFrequency() - } else if (measurementTypeParams.impression.selected) { - impression = getImpression() - } else if (measurementTypeParams.duration.selected) { - duration = getDuration() + if (!measurementParams.populationMeasurementParams.selected) { + vidSamplingInterval = vidSamplingInterval { + start = measurementParams.eventMeasurementParams.vidSamplingStart + width = measurementParams.eventMeasurementParams.vidSamplingWidth + } + if ( + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.reachAndFrequency + .selected + ) { + reachAndFrequency = createMeasurementFlags.getReachAndFrequency() + } else if ( + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.impression.selected + ) { + impression = createMeasurementFlags.getImpression() + } else if ( + measurementParams.eventMeasurementParams.eventMeasurementTypeParams.duration.selected + ) { + duration = createMeasurementFlags.getDuration() + } + } else if (measurementParams.populationMeasurementParams.selected) { + population = createMeasurementFlags.getPopulation() } + if (createMeasurementFlags.modelLine.isNotEmpty()) + this.modelLine = createMeasurementFlags.modelLine } this.measurementSpec = signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) - measurementReferenceId = this@CreateMeasurement.measurementReferenceId + measurementReferenceId = createMeasurementFlags.measurementReferenceId } val response = @@ -966,7 +767,7 @@ class CreateMeasurement : Runnable { createMeasurementRequest { parent = measurementConsumer.name this.measurement = measurement - requestId = this@CreateMeasurement.requestId + requestId = createMeasurementFlags.requestId } ) } @@ -1067,6 +868,9 @@ class GetMeasurement : Runnable { "${result.watchDuration.value.seconds} seconds ${result.watchDuration.value.nanos} nanos" ) } + if (result.hasPopulation()) { + println("Population - ${result.population.value}") + } } override fun run() { @@ -1663,7 +1467,7 @@ private class ModelShards { fun create( @Option( names = ["--parent"], - description = ["API resource name of the parent DataProvider."], + description = ["API resource name of the parent Event Data Provider."], required = true, ) dataProviderName: String, @@ -1698,7 +1502,7 @@ private class ModelShards { fun list( @Option( names = ["--parent"], - description = ["API resource name of the parent DataProvider."], + description = ["API resource name of the parent Event Data Provider."], required = true, ) dataProviderName: String, diff --git a/src/test/kotlin/org/wfanet/measurement/api/v2alpha/tools/BenchmarkTest.kt b/src/test/kotlin/org/wfanet/measurement/api/v2alpha/tools/BenchmarkTest.kt index 93890b05d53..bbc9e539d14 100644 --- a/src/test/kotlin/org/wfanet/measurement/api/v2alpha/tools/BenchmarkTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/api/v2alpha/tools/BenchmarkTest.kt @@ -33,7 +33,6 @@ import java.time.Instant import java.time.ZoneId import kotlinx.coroutines.runBlocking import org.junit.After -import org.junit.Before import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 @@ -53,6 +52,7 @@ import org.wfanet.measurement.api.v2alpha.MeasurementSpec import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.duration import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.impression +import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.population import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reachAndFrequency import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineImplBase import org.wfanet.measurement.api.v2alpha.certificate @@ -140,7 +140,7 @@ private val AGGREGATOR_CERTIFICATE = certificate { x509Der = AGGREGATOR_CERTIFIC private const val MEASUREMENT_NAME = "$MEASUREMENT_CONSUMER_NAME/measurements/100" private val MEASUREMENT = measurement { name = MEASUREMENT_NAME } -private val SUCCEEDED_MEASUREMENT = measurement { +private val SUCCEEDED_REACH_AND_FREQUENCY_MEASUREMENT = measurement { name = MEASUREMENT_NAME state = Measurement.State.SUCCEEDED @@ -148,13 +148,9 @@ private val SUCCEEDED_MEASUREMENT = measurement { format = EncryptionPublicKey.Format.TINK_KEYSET data = MEASUREMENT_PUBLIC_KEY } - results += resultPair { - val result = result { reach = MeasurementKt.ResultKt.reach { value = 4096 } } - encryptedResult = getEncryptedResult(result, measurementPublicKey) - certificate = DATA_PROVIDER_CERTIFICATE_NAME - } results += resultPair { val result = result { + reach = ResultKt.reach { value = 4096 } frequency = MeasurementKt.ResultKt.frequency { relativeFrequencyDistribution.put(1, 1.0 / 6) @@ -165,11 +161,29 @@ private val SUCCEEDED_MEASUREMENT = measurement { encryptedResult = getEncryptedResult(result, measurementPublicKey) certificate = DATA_PROVIDER_CERTIFICATE_NAME } +} +private val SUCCEEDED_IMPRESSION_MEASUREMENT = measurement { + name = MEASUREMENT_NAME + state = Measurement.State.SUCCEEDED + + val measurementPublicKey = encryptionPublicKey { + format = EncryptionPublicKey.Format.TINK_KEYSET + data = MEASUREMENT_PUBLIC_KEY + } results += resultPair { val result = result { impression = ResultKt.impression { value = 4096 } } encryptedResult = getEncryptedResult(result, measurementPublicKey) certificate = DATA_PROVIDER_CERTIFICATE_NAME } +} +private val SUCCEEDED_DURATION_MEASUREMENT = measurement { + name = MEASUREMENT_NAME + state = Measurement.State.SUCCEEDED + + val measurementPublicKey = encryptionPublicKey { + format = EncryptionPublicKey.Format.TINK_KEYSET + data = MEASUREMENT_PUBLIC_KEY + } results += resultPair { val result = result { watchDuration = @@ -184,6 +198,20 @@ private val SUCCEEDED_MEASUREMENT = measurement { certificate = DATA_PROVIDER_CERTIFICATE_NAME } } +private val SUCCEEDED_POPULATION_MEASUREMENT = measurement { + name = MEASUREMENT_NAME + state = Measurement.State.SUCCEEDED + + val measurementPublicKey = encryptionPublicKey { + format = EncryptionPublicKey.Format.TINK_KEYSET + data = MEASUREMENT_PUBLIC_KEY + } + results += resultPair { + val result = result { population = ResultKt.population { value = 100 } } + encryptedResult = getEncryptedResult(result, measurementPublicKey) + certificate = DATA_PROVIDER_CERTIFICATE_NAME + } +} private fun getEncryptedResult( result: Measurement.Result, @@ -214,11 +242,7 @@ class BenchmarkTest { mockService() { onBlocking { getMeasurementConsumer(any()) }.thenReturn(MEASUREMENT_CONSUMER) } private val dataProvidersServiceMock: DataProvidersCoroutineImplBase = mockService() { onBlocking { getDataProvider(any()) }.thenReturn(DATA_PROVIDER) } - private val measurementsServiceMock: MeasurementsCoroutineImplBase = - mockService() { - onBlocking { createMeasurement(any()) }.thenReturn(MEASUREMENT) - onBlocking { getMeasurement(any()) }.thenReturn(SUCCEEDED_MEASUREMENT) - } + private lateinit var measurementsServiceMock: MeasurementsCoroutineImplBase private val certificatesServiceMock: CertificatesGrpcKt.CertificatesCoroutineImplBase = mockService() { onBlocking { getCertificate(any()) }.thenReturn(AGGREGATOR_CERTIFICATE) } @@ -228,7 +252,7 @@ class BenchmarkTest { get() = server.port private lateinit var server: CommonServer - @Before + fun initServer() { val services: List = listOf( @@ -264,6 +288,12 @@ class BenchmarkTest { @Test fun `Benchmark reach and frequency`() { + measurementsServiceMock = + mockService() { + onBlocking { createMeasurement(any()) }.thenReturn(MEASUREMENT) + onBlocking { getMeasurement(any()) }.thenReturn(SUCCEEDED_REACH_AND_FREQUENCY_MEASUREMENT) + } + initServer() val clock = Clock.fixed(Instant.parse(TIME_STRING_1), ZoneId.of("UTC")) val tempFile = Files.createTempFile("benchmarks-reach", ".csv") @@ -284,7 +314,7 @@ class BenchmarkTest { "--vid-sampling-width=0.2", "--private-key-der-file=$SECRETS_DIR/mc_cs_private.der", "--encryption-private-key-file=$SECRETS_DIR/mc_enc_private.tink", - "--data-provider=dataProviders/1", + "--event-data-provider=dataProviders/1", "--event-group=dataProviders/1/eventGroups/1", "--event-filter=abcd", "--event-start-time=$TIME_STRING_1", @@ -330,11 +360,20 @@ class BenchmarkTest { .isEqualTo( "replica,startTime,ackTime,computeTime,endTime,status,msg,reach,freq1,freq2,freq3,freq4,freq5" ) - assertThat(result[1]).isEqualTo("1,0.0,0.0,0.0,0.0,success,,4096,0.0,0.0,0.0,0.0,0.0") + assertThat(result[1]) + .isEqualTo( + "1,0.0,0.0,0.0,0.0,success,,4096,682.6666666666666,2048.0,1365.3333333333333,0.0,0.0" + ) } @Test fun `Benchmark impressions`() { + measurementsServiceMock = + mockService() { + onBlocking { createMeasurement(any()) }.thenReturn(MEASUREMENT) + onBlocking { getMeasurement(any()) }.thenReturn(SUCCEEDED_IMPRESSION_MEASUREMENT) + } + initServer() val clock = Clock.fixed(Instant.parse(TIME_STRING_1), ZoneId.of("UTC")) val tempFile = Files.createTempFile("benchmarks-impressions", ".csv") @@ -349,12 +388,12 @@ class BenchmarkTest { "--impression", "--impression-privacy-epsilon=0.015", "--impression-privacy-delta=0.0", - "--max-frequency=1000", + "--impression-max-frequency=1000", "--vid-sampling-start=0.1", "--vid-sampling-width=0.2", "--private-key-der-file=$SECRETS_DIR/mc_cs_private.der", "--encryption-private-key-file=$SECRETS_DIR/mc_enc_private.tink", - "--data-provider=dataProviders/1", + "--event-data-provider=dataProviders/1", "--event-group=dataProviders/1/eventGroups/1", "--event-filter=abcd", "--event-start-time=$TIME_STRING_1", @@ -394,11 +433,17 @@ class BenchmarkTest { assertThat(result.size).isEqualTo(2) assertThat(result[0]) .isEqualTo("replica,startTime,ackTime,computeTime,endTime,status,msg,impressions") - assertThat(result[1]).isEqualTo("1,0.0,0.0,0.0,0.0,success,,0") + assertThat(result[1]).isEqualTo("1,0.0,0.0,0.0,0.0,success,,4096") } @Test fun `Benchmark duration`() { + measurementsServiceMock = + mockService() { + onBlocking { createMeasurement(any()) }.thenReturn(MEASUREMENT) + onBlocking { getMeasurement(any()) }.thenReturn(SUCCEEDED_DURATION_MEASUREMENT) + } + initServer() val clock = Clock.fixed(Instant.parse(TIME_STRING_1), ZoneId.of("UTC")) val tempFile = Files.createTempFile("benchmarks-duration", ".csv") @@ -418,7 +463,7 @@ class BenchmarkTest { "--vid-sampling-width=0.2", "--private-key-der-file=$SECRETS_DIR/mc_cs_private.der", "--encryption-private-key-file=$SECRETS_DIR/mc_enc_private.tink", - "--data-provider=dataProviders/1", + "--event-data-provider=dataProviders/1", "--event-group=dataProviders/1/eventGroups/1", "--event-filter=abcd", "--event-start-time=$TIME_STRING_1", @@ -458,6 +503,56 @@ class BenchmarkTest { assertThat(result.size).isEqualTo(2) assertThat(result[0]) .isEqualTo("replica,startTime,ackTime,computeTime,endTime,status,msg,duration") - assertThat(result[1]).isEqualTo("1,0.0,0.0,0.0,0.0,success,,0") + assertThat(result[1]).isEqualTo("1,0.0,0.0,0.0,0.0,success,,100") + } + + @Test + fun `Benchmark population`() { + measurementsServiceMock = + mockService() { + onBlocking { createMeasurement(any()) }.thenReturn(MEASUREMENT) + onBlocking { getMeasurement(any()) }.thenReturn(SUCCEEDED_POPULATION_MEASUREMENT) + } + initServer() + val clock = Clock.fixed(Instant.parse(TIME_STRING_1), ZoneId.of("UTC")) + val tempFile = Files.createTempFile("benchmarks-population", ".csv") + + val args = + arrayOf( + "--tls-cert-file=$SECRETS_DIR/mc_tls.pem", + "--tls-key-file=$SECRETS_DIR/mc_tls.key", + "--cert-collection-file=$SECRETS_DIR/kingdom_root.pem", + "--kingdom-public-api-target=$HOST:$port", + "--api-key=$API_KEY", + "--measurement-consumer=measurementConsumers/777", + "--population", + "--private-key-der-file=$SECRETS_DIR/mc_cs_private.der", + "--encryption-private-key-file=$SECRETS_DIR/mc_enc_private.tink", + "--population-data-provider=dataProviders/1", + "--model-line=modelProviders/1/modelSuites/2/modelLines/3", + "--population-filter=abcd", + "--population-start-time=$TIME_STRING_1", + "--population-end-time=$TIME_STRING_2", + "--output-file=$tempFile", + ) + CommandLine(BenchmarkReport(clock)).execute(*args) + + val request = + captureFirst { + runBlocking { verify(measurementsServiceMock).createMeasurement(capture()) } + } + + val measurement = request.measurement + val measurementSpec = MeasurementSpec.parseFrom(measurement.measurementSpec.data) + assertThat(measurementSpec) + .comparingExpectedFieldsOnly() + .isEqualTo(measurementSpec { population = population {} }) + + val result = Files.readAllLines(tempFile) + + assertThat(result.size).isEqualTo(2) + assertThat(result[0]) + .isEqualTo("replica,startTime,ackTime,computeTime,endTime,status,msg,population") + assertThat(result[1]).isEqualTo("1,0.0,0.0,0.0,0.0,success,,100") } } diff --git a/src/test/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystemTest.kt b/src/test/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystemTest.kt index f4b443caa46..30deae4c12e 100644 --- a/src/test/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystemTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystemTest.kt @@ -86,6 +86,7 @@ import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.Measurement import org.wfanet.measurement.api.v2alpha.MeasurementKt import org.wfanet.measurement.api.v2alpha.MeasurementKt.ResultKt.frequency import org.wfanet.measurement.api.v2alpha.MeasurementKt.ResultKt.impression +import org.wfanet.measurement.api.v2alpha.MeasurementKt.ResultKt.population import org.wfanet.measurement.api.v2alpha.MeasurementKt.ResultKt.reach import org.wfanet.measurement.api.v2alpha.MeasurementKt.ResultKt.watchDuration import org.wfanet.measurement.api.v2alpha.MeasurementKt.dataProviderEntry @@ -798,7 +799,7 @@ class MeasurementSystemTest { "--private-key-der-file=$SECRETS_DIR/mc_cs_private.der", "--measurement-ref-id=$measurementReferenceId", "--request-id=$requestId", - "--data-provider=dataProviders/1", + "--event-data-provider=dataProviders/1", "--event-group=dataProviders/1/eventGroups/1", "--event-filter=abcd", "--event-start-time=$TIME_STRING_1", @@ -806,7 +807,7 @@ class MeasurementSystemTest { "--event-group=dataProviders/1/eventGroups/2", "--event-start-time=$TIME_STRING_3", "--event-end-time=$TIME_STRING_4", - "--data-provider=dataProviders/2", + "--event-data-provider=dataProviders/2", "--event-group=dataProviders/2/eventGroups/1", "--event-filter=ijk", "--event-start-time=$TIME_STRING_5", @@ -999,7 +1000,7 @@ class MeasurementSystemTest { "--vid-sampling-start=0.1", "--vid-sampling-width=0.2", "--private-key-der-file=$SECRETS_DIR/mc_cs_private.der", - "--data-provider=dataProviders/1", + "--event-data-provider=dataProviders/1", "--event-group=dataProviders/1/eventGroups/1", "--event-filter=abcd", "--event-start-time=$TIME_STRING_1", @@ -1051,7 +1052,7 @@ class MeasurementSystemTest { "--vid-sampling-start=0.1", "--vid-sampling-width=0.2", "--private-key-der-file=$SECRETS_DIR/mc_cs_private.der", - "--data-provider=dataProviders/1", + "--event-data-provider=dataProviders/1", "--event-group=dataProviders/1/eventGroups/1", "--event-filter=abcd", "--event-start-time=$TIME_STRING_1", @@ -1087,6 +1088,72 @@ class MeasurementSystemTest { ) } + @Test + fun `measurements create calls CreateMeasurement with correct population params`() { + val args = + commonArgs + + arrayOf( + "measurements", + "--api-key=$AUTHENTICATION_KEY", + "create", + "--measurement-consumer=measurementConsumers/777", + "--model-line=modelProviders/1/modelSuites/2/modelLines/3", + "--population", + "--private-key-der-file=$SECRETS_DIR/mc_cs_private.der", + "--population-data-provider=dataProviders/1", + "--population-filter=abcd", + "--population-start-time=$TIME_STRING_1", + "--population-end-time=$TIME_STRING_2", + ) + callCli(args) + + val request = + captureFirst { + runBlocking { verify(measurementsServiceMock).createMeasurement(capture()) } + } + + val measurement = request.measurement + val measurementSpec = MeasurementSpec.parseFrom(measurement.measurementSpec.data) + assertThat(measurementSpec) + .comparingExpectedFieldsOnly() + .isEqualTo( + measurementSpec { + population = MeasurementSpecKt.population {} + modelLine = "modelProviders/1/modelSuites/2/modelLines/3" + } + ) + + // Verify first RequisitionSpec. + val signedRequisitionSpec = + decryptRequisitionSpec( + request.measurement.dataProvidersList.single().value.encryptedRequisitionSpec, + DATA_PROVIDER_PRIVATE_KEY_HANDLE + ) + val requisitionSpec = RequisitionSpec.parseFrom(signedRequisitionSpec.data) + verifyRequisitionSpec( + signedRequisitionSpec, + requisitionSpec, + measurementSpec, + MEASUREMENT_CONSUMER_CERTIFICATE, + TRUSTED_MEASUREMENT_CONSUMER_ISSUER + ) + assertThat(requisitionSpec) + .ignoringFields(RequisitionSpec.NONCE_FIELD_NUMBER) + .isEqualTo( + requisitionSpec { + measurementPublicKey = MEASUREMENT_CONSUMER.publicKey.data + population = + RequisitionSpecKt.population { + filter = RequisitionSpecKt.eventFilter { expression = "abcd" } + interval = interval { + startTime = Instant.parse(TIME_STRING_1).toProtoTime() + endTime = Instant.parse(TIME_STRING_2).toProtoTime() + } + } + } + ) + } + @Test fun `measurements list calls ListMeasurements with valid request`() { val args = @@ -1908,6 +1975,10 @@ class MeasurementSystemTest { EncryptionPublicKey.parseFrom(MEASUREMENT_CONSUMER.publicKey.data) } + /* + * Contains a measurement result for each measurement type + * TODO(@renjiezh) Have separate successful measurements for each measurement type + */ private val SUCCEEDED_MEASUREMENT: Measurement by lazy { val measurementPublicKey = MEASUREMENT_CONSUMER_ENCRYPTION_PUBLIC_KEY measurement { @@ -1947,6 +2018,11 @@ class MeasurementSystemTest { encryptedResult = getEncryptedResult(result, measurementPublicKey) certificate = DATA_PROVIDER_CERTIFICATE_NAME } + results += resultPair { + val result = result { population = population { value = 100 } } + encryptedResult = getEncryptedResult(result, measurementPublicKey) + certificate = DATA_PROVIDER_CERTIFICATE_NAME + } } } }