From 8324a9d657a26b579e979e0e2797d34e3f885ac0 Mon Sep 17 00:00:00 2001 From: Sanjay Vasandani Date: Thu, 25 May 2023 12:41:56 -0700 Subject: [PATCH] Specify request_id in CreateMeasurement and CreateEventGroup requests. --- .../api/v2alpha/tools/MeasurementSystem.kt | 19 +- .../loadtest/dataprovider/EdpSimulator.kt | 6 +- .../service/api/v1alpha/ReportsService.kt | 136 +++---- .../service/api/v2alpha/MetricsService.kt | 9 +- .../v2alpha/tools/MeasurementSystemTest.kt | 3 + .../service/api/v1alpha/ReportsServiceTest.kt | 348 +++++++----------- .../service/api/v2alpha/MetricsServiceTest.kt | 140 ++++--- 7 files changed, 309 insertions(+), 352 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystem.kt b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystem.kt index c00e2e3242f..d77f7743ebf 100644 --- a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystem.kt +++ b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystem.kt @@ -533,6 +533,14 @@ class CreateMeasurement : Runnable { ) private lateinit var measurementConsumer: String + @Option( + names = ["--request-id"], + description = ["ID of API request for idempotency"], + required = false, + defaultValue = "", + ) + private lateinit var requestId: String + @Option( names = ["--private-key-der-file"], description = ["Private key for MeasurementConsumer"], @@ -546,7 +554,7 @@ class CreateMeasurement : Runnable { required = false, defaultValue = "" ) - private lateinit var measurementIdempotencyKey: String + private lateinit var measurementReferenceId: String @set:Option( names = ["--vid-sampling-start"], @@ -879,14 +887,19 @@ class CreateMeasurement : Runnable { this.measurementSpec = signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) - measurementReferenceId = measurementIdempotencyKey + measurementReferenceId = this@CreateMeasurement.measurementReferenceId } val response = runBlocking(parentCommand.parentCommand.rpcDispatcher) { parentCommand.measurementStub .withAuthenticationKey(parentCommand.apiAuthenticationKey) - .createMeasurement(createMeasurementRequest { this.measurement = measurement }) + .createMeasurement( + createMeasurementRequest { + this.measurement = measurement + requestId = this@CreateMeasurement.requestId + } + ) } println("Measurement Name: ${response.name}") } diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt index 33d58652596..f59fc02f69f 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt @@ -206,12 +206,13 @@ class EdpSimulator( throw Exception("Error creating EventGroupMetadataDescriptor", e) } + val eventGroupReferenceId = + "${TestIdentifiers.EVENT_GROUP_REFERENCE_ID_PREFIX}-${edpData.displayName}" val request = createEventGroupRequest { parent = edpData.name eventGroup = eventGroup { this.measurementConsumer = measurementConsumerName - eventGroupReferenceId = - "${TestIdentifiers.EVENT_GROUP_REFERENCE_ID_PREFIX}-${edpData.displayName}" + this.eventGroupReferenceId = eventGroupReferenceId eventTemplates += eventTemplateNames.map { eventTemplate { type = it } } measurementConsumerCertificate = measurementConsumer.certificate measurementConsumerPublicKey = measurementConsumer.publicKey @@ -227,6 +228,7 @@ class EdpSimulator( EncryptionPublicKey.parseFrom(measurementConsumer.publicKey.data) ) } + requestId = eventGroupReferenceId } val eventGroup = try { diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsService.kt index 3a33025ee82..aa6bb2b3add 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsService.kt @@ -466,7 +466,7 @@ class ReportsService( signingConfig: SigningConfig, internalReportingSetMap: Map ) = coroutineScope { - val deferred = mutableListOf>() + val deferredMeasurements = mutableListOf>() for (metric in request.report.metricsList) { val internalMetricDetails = buildInternalMetricDetails(metric) @@ -481,34 +481,35 @@ class ReportsService( val setOperationResult: SetOperationResult = namedSetOperationResults[setOperationId] ?: continue - setOperationResult.weightedMeasurementInfoList.forEach { weightedMeasurementInfo -> - deferred.add( + for (weightedMeasurementInfo in setOperationResult.weightedMeasurementInfoList) { + deferredMeasurements.add( async { createMeasurement( - weightedMeasurementInfo, - reportInfo, - setOperationResult.internalMetricDetails, - measurementConsumer, - apiAuthenticationKey, - signingConfig, - internalReportingSetMap - ) + weightedMeasurementInfo, + reportInfo, + setOperationResult.internalMetricDetails, + measurementConsumer, + apiAuthenticationKey, + signingConfig, + internalReportingSetMap + ) + .also { + weightedMeasurementInfo.kingdomMeasurementId = + checkNotNull(MeasurementKey.fromName(it.name)).measurementId + } } ) } } } - // map of kingdom measurementReferenceId to kingdom apiId - val measurementMap = mutableMapOf() - deferred.awaitAll().forEach { measurement -> + for (measurement in deferredMeasurements.awaitAll()) { try { - val apiId = MeasurementKey.fromName(measurement.name)!!.measurementId - measurementMap[measurement.measurementReferenceId] = apiId + val measurementKey = checkNotNull(MeasurementKey.fromName(measurement.name)) internalMeasurementsStub.createMeasurement( internalMeasurement { - this.measurementConsumerReferenceId = reportInfo.measurementConsumerReferenceId - this.measurementReferenceId = apiId + this.measurementConsumerReferenceId = measurementKey.measurementConsumerId + this.measurementReferenceId = measurementKey.measurementId state = InternalMeasurement.State.PENDING } ) @@ -522,27 +523,6 @@ class ReportsService( } } } - - for (metric in request.report.metricsList) { - val internalMetricDetails = buildInternalMetricDetails(metric) - - for (namedSetOperation in metric.setOperationsList) { - val setOperationId = - buildSetOperationId( - reportInfo.reportIdempotencyKey, - internalMetricDetails, - namedSetOperation.uniqueName, - ) - - val setOperationResult: SetOperationResult = - namedSetOperationResults[setOperationId] ?: continue - - setOperationResult.weightedMeasurementInfoList.forEach { weightedMeasurementInfo -> - weightedMeasurementInfo.kingdomMeasurementId = - measurementMap[weightedMeasurementInfo.reportingMeasurementId] - } - } - } } /** Creates a kingdom measurement for a [WeightedMeasurement]. */ @@ -579,7 +559,7 @@ class ReportsService( .createMeasurement(createMeasurementRequest) } catch (e: StatusException) { throw Exception( - "Unable to create the measurement [${createMeasurementRequest.measurement.name}].", + "Unable to create Measurement with request ID ${createMeasurementRequest.requestId}", e ) } @@ -1050,13 +1030,14 @@ class ReportsService( .forEach { (reportTimeInterval, weightedMeasurementInfos) -> measurementCalculations.add( InternalMetricKt.measurementCalculation { - this.timeInterval = reportTimeInterval.toInternal() + timeInterval = reportTimeInterval.toInternal() - weightedMeasurementInfos.forEach { + for (weightedMeasurementInfo in weightedMeasurementInfos) { weightedMeasurements += MeasurementCalculationKt.weightedMeasurement { - this.measurementReferenceId = it.kingdomMeasurementId!! - coefficient = it.weightedMeasurement.coefficient + measurementReferenceId = + checkNotNull(weightedMeasurementInfo.kingdomMeasurementId) + coefficient = weightedMeasurementInfo.weightedMeasurement.coefficient } } } @@ -1070,50 +1051,44 @@ class ReportsService( measurementConsumer: MeasurementConsumer, eventGroupEntriesByDataProvider: Map>, internalMetricDetails: InternalMetricDetails, - measurementReferenceId: String, + requestId: String, apiAuthenticationKey: String, signingConfig: SigningConfig, ): CreateMeasurementRequest { - val measurementConsumerReferenceId = - grpcRequireNotNull(MeasurementConsumerKey.fromName(measurementConsumer.name)) { - "Invalid measurement consumer name [${measurementConsumer.name}]" - } - .measurementConsumerId + grpcRequireNotNull(MeasurementConsumerKey.fromName(measurementConsumer.name)) { + "Invalid measurement consumer name [${measurementConsumer.name}]" + } - val measurementConsumerCertificate = readCertificate(signingConfig.signingCertificateDer) + val measurementConsumerCertificate: X509Certificate = + readCertificate(signingConfig.signingCertificateDer) val measurementConsumerSigningKey = SigningKeyHandle(measurementConsumerCertificate, signingConfig.signingPrivateKey) - val measurementEncryptionPublicKey = measurementConsumer.publicKey.data - - val measurementResourceName = - MeasurementKey(measurementConsumerReferenceId, measurementReferenceId).toName() - - val measurement = measurement { - name = measurementResourceName - this.measurementConsumerCertificate = signingConfig.signingCertificateName + val measurementEncryptionPublicKey: ByteString = measurementConsumer.publicKey.data - dataProviders += - buildDataProviderEntries( - eventGroupEntriesByDataProvider, - measurementEncryptionPublicKey, - measurementConsumerSigningKey, - apiAuthenticationKey, - ) + return createMeasurementRequest { + measurement = measurement { + this.measurementConsumerCertificate = signingConfig.signingCertificateName - val unsignedMeasurementSpec: MeasurementSpec = - buildUnsignedMeasurementSpec( - measurementEncryptionPublicKey, - dataProviders.map { it.value.nonceHash }, - internalMetricDetails, - ) + dataProviders += + buildDataProviderEntries( + eventGroupEntriesByDataProvider, + measurementEncryptionPublicKey, + measurementConsumerSigningKey, + apiAuthenticationKey, + ) - this.measurementSpec = - signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) + val unsignedMeasurementSpec: MeasurementSpec = + buildUnsignedMeasurementSpec( + measurementEncryptionPublicKey, + dataProviders.map { it.value.nonceHash }, + internalMetricDetails, + ) - this.measurementReferenceId = measurementReferenceId + measurementSpec = + signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) + } + this.requestId = requestId } - - return createMeasurementRequest { this.measurement = measurement } } /** @@ -1733,10 +1708,7 @@ private fun buildDurationMeasurementSpec( /** Converts a public [SetOperation.Type] to an [InternalSetOperation.Type]. */ private fun SetOperation.Type.toInternal(): InternalSetOperation.Type { - val source = this - - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - return when (source) { + return when (this) { SetOperation.Type.UNION -> InternalSetOperation.Type.UNION SetOperation.Type.INTERSECTION -> InternalSetOperation.Type.INTERSECTION SetOperation.Type.DIFFERENCE -> InternalSetOperation.Type.DIFFERENCE @@ -1984,7 +1956,6 @@ private fun InternalReport.toReport(): Report { /** Converts an [InternalReport.State] to a public [Report.State]. */ private fun InternalReport.State.toState(): Report.State { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. return when (this) { InternalReport.State.RUNNING -> Report.State.RUNNING InternalReport.State.SUCCEEDED -> Report.State.SUCCEEDED @@ -2096,7 +2067,6 @@ private fun InternalOperand.toOperand(): Operand { /** Converts an internal [InternalSetOperation.Type] to a public [SetOperation.Type]. */ private fun InternalSetOperation.Type.toType(): SetOperation.Type { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. return when (this) { InternalSetOperation.Type.UNION -> SetOperation.Type.UNION InternalSetOperation.Type.INTERSECTION -> SetOperation.Type.INTERSECTION diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index a96d12d019a..08846a44870 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -345,13 +345,14 @@ class MetricsService( metricSpec ) - this.measurementSpec = + measurementSpec = signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) - - this.measurementReferenceId = internalMeasurement.cmmsCreateMeasurementRequestId } - return createMeasurementRequest { this.measurement = measurement } + return createMeasurementRequest { + this.measurement = measurement + requestId = internalMeasurement.cmmsCreateMeasurementRequestId + } } /** Gets a [SigningKeyHandle] for a [MeasurementConsumerPrincipal]. */ diff --git a/src/test/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystemTest.kt b/src/test/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystemTest.kt index 1ffaa06fea4..5b106f9dd64 100644 --- a/src/test/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystemTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/api/v2alpha/tools/MeasurementSystemTest.kt @@ -559,6 +559,7 @@ class MeasurementSystemTest { @Test fun `measurements create calls CreateMeasurement with valid request`() { + val requestId = "foo" val args = commonArgs + arrayOf( @@ -574,6 +575,7 @@ class MeasurementSystemTest { "--measurement-consumer=measurementConsumers/777", "--private-key-der-file=$SECRETS_DIR/mc_cs_private.der", "--measurement-ref-id=9999", + "--request-id=$requestId", "--data-provider=dataProviders/1", "--event-group=dataProviders/1/eventGroups/1", "--event-filter=abcd", @@ -606,6 +608,7 @@ class MeasurementSystemTest { captureFirst { runBlocking { verify(measurementsServiceMock).createMeasurement(capture()) } } + assertThat(request.requestId).isEqualTo(requestId) val measurement = request.measurement // measurementSpec matches verifyMeasurementSpec( diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsServiceTest.kt index 1fb73fcfdfb..c38de599c78 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/ReportsServiceTest.kt @@ -24,6 +24,7 @@ import com.google.protobuf.kotlin.toByteStringUtf8 import com.google.protobuf.timestamp import com.google.protobuf.util.Timestamps import io.grpc.Status +import io.grpc.StatusException import io.grpc.StatusRuntimeException import java.nio.file.Paths import java.security.SecureRandom @@ -84,6 +85,7 @@ import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.eventFilter import org.wfanet.measurement.api.v2alpha.RequisitionSpecKt.eventGroupEntry import org.wfanet.measurement.api.v2alpha.certificate import org.wfanet.measurement.api.v2alpha.copy +import org.wfanet.measurement.api.v2alpha.createMeasurementRequest import org.wfanet.measurement.api.v2alpha.dataProvider import org.wfanet.measurement.api.v2alpha.differentialPrivacyParams import org.wfanet.measurement.api.v2alpha.encryptionPublicKey @@ -465,52 +467,23 @@ private const val IMPRESSION_SET_OPERATION_UNIQUE_NAME = "Impression Set Operati private const val WATCH_DURATION_SET_OPERATION_UNIQUE_NAME = "Watch Duration Set Operation" // Measurement IDs and names -private val REACH_MEASUREMENT_REFERENCE_ID = +private val REACH_MEASUREMENT_CREATE_REQUEST_ID = "$REACH_REPORT_IDEMPOTENCY_KEY-Reach-$REACH_SET_OPERATION_UNIQUE_NAME-$START_INSTANT-" + "$END_INSTANT-measurement-0" -private val REACH_MEASUREMENT_REFERENCE_ID_2 = - "$REACH_REPORT_IDEMPOTENCY_KEY-Reach-$REACH_SET_OPERATION_UNIQUE_NAME-$START_INSTANT-" + - "$END_INSTANT-measurement-1" -private val FREQUENCY_HISTOGRAM_MEASUREMENT_REFERENCE_ID = - "$FREQUENCY_HISTOGRAM_REPORT_IDEMPOTENCY_KEY-FrequencyHistogram-" + - "$FREQUENCY_HISTOGRAM_SET_OPERATION_UNIQUE_NAME-$START_INSTANT-$END_INSTANT-measurement-0" -private val IMPRESSION_MEASUREMENT_REFERENCE_ID = - "$IMPRESSION_REPORT_IDEMPOTENCY_KEY-ImpressionCount-$IMPRESSION_SET_OPERATION_UNIQUE_NAME" + - "-$START_INSTANT-$END_INSTANT-measurement-0" -private val WATCH_DURATION_MEASUREMENT_REFERENCE_ID = - "$WATCH_DURATION_REPORT_IDEMPOTENCY_KEY-WatchDuration-$WATCH_DURATION_SET_OPERATION_UNIQUE_NAME" + - "-$START_INSTANT-$END_INSTANT-measurement-0" - -private val REACH_MEASUREMENT_NAME = - MeasurementKey( - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, - REACH_MEASUREMENT_REFERENCE_ID - ) - .toName() -private val REACH_MEASUREMENT_NAME_2 = - MeasurementKey( - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, - REACH_MEASUREMENT_REFERENCE_ID_2 - ) - .toName() -private val FREQUENCY_HISTOGRAM_MEASUREMENT_NAME = - MeasurementKey( - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, - FREQUENCY_HISTOGRAM_MEASUREMENT_REFERENCE_ID - ) - .toName() -private val IMPRESSION_MEASUREMENT_NAME = - MeasurementKey( - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, - IMPRESSION_MEASUREMENT_REFERENCE_ID - ) - .toName() -private val WATCH_DURATION_MEASUREMENT_NAME = + +private val REACH_MEASUREMENT_KEY = MeasurementKey( - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, - WATCH_DURATION_MEASUREMENT_REFERENCE_ID - ) - .toName() + MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, + ExternalId(111).apiId.value + ) +private val REACH_MEASUREMENT_KEY_2 = + MeasurementKey(REACH_MEASUREMENT_KEY.measurementConsumerId, ExternalId(222).apiId.value) +private val FREQUENCY_HISTOGRAM_MEASUREMENT_KEY = + MeasurementKey(REACH_MEASUREMENT_KEY.measurementConsumerId, ExternalId(333).apiId.value) +private val IMPRESSION_MEASUREMENT_KEY = + MeasurementKey(REACH_MEASUREMENT_KEY.measurementConsumerId, ExternalId(444).apiId.value) +private val WATCH_DURATION_MEASUREMENT_KEY = + MeasurementKey(REACH_MEASUREMENT_KEY.measurementConsumerId, ExternalId(555).apiId.value) // Set operations private val INTERNAL_SET_OPERATION = @@ -619,16 +592,9 @@ private val WATCH_DURATION_LIST = WATCH_DURATION_SECOND_LIST.map { duration { se private val TOTAL_WATCH_DURATION = duration { seconds = WATCH_DURATION_SECOND_LIST.sum() } // Reach measurement -private val BASE_REACH_MEASUREMENT = - BASE_MEASUREMENT.copy { - name = REACH_MEASUREMENT_NAME - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID - } +private val BASE_REACH_MEASUREMENT = BASE_MEASUREMENT.copy { name = REACH_MEASUREMENT_KEY.toName() } private val BASE_REACH_MEASUREMENT_2 = - BASE_MEASUREMENT.copy { - name = REACH_MEASUREMENT_NAME_2 - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID_2 - } + BASE_MEASUREMENT.copy { name = REACH_MEASUREMENT_KEY_2.toName() } private val PENDING_REACH_MEASUREMENT = BASE_REACH_MEASUREMENT.copy { state = Measurement.State.COMPUTING } @@ -658,6 +624,17 @@ private val REACH_ONLY_MEASUREMENT_SPEC = measurementSpec { } } +private val REACH_MEASUREMENT_REQUEST = createMeasurementRequest { + measurement = + BASE_MEASUREMENT.copy { + dataProviders += + DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } + measurementSpec = + signMeasurementSpec(REACH_ONLY_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) + } + requestId = REACH_MEASUREMENT_CREATE_REQUEST_ID +} + private val SUCCEEDED_REACH_MEASUREMENT = BASE_REACH_MEASUREMENT.copy { dataProviders += DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } @@ -683,8 +660,8 @@ private val SUCCEEDED_REACH_MEASUREMENT = } private val INTERNAL_PENDING_REACH_MEASUREMENT = internalMeasurement { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId state = InternalMeasurement.State.PENDING } private val INTERNAL_SUCCEEDED_REACH_MEASUREMENT = @@ -702,10 +679,7 @@ private val INTERNAL_SUCCEEDED_REACH_MEASUREMENT = // Frequency histogram measurement private val BASE_REACH_FREQUENCY_HISTOGRAM_MEASUREMENT = - BASE_MEASUREMENT.copy { - name = FREQUENCY_HISTOGRAM_MEASUREMENT_NAME - measurementReferenceId = FREQUENCY_HISTOGRAM_MEASUREMENT_REFERENCE_ID - } + BASE_MEASUREMENT.copy { name = FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.toName() } private val REACH_FREQUENCY_MEASUREMENT_SPEC = measurementSpec { measurementPublicKey = MEASUREMENT_CONSUMER_PUBLIC_KEY.toByteString() @@ -756,8 +730,8 @@ private val SUCCEEDED_FREQUENCY_HISTOGRAM_MEASUREMENT = } private val INTERNAL_PENDING_FREQUENCY_HISTOGRAM_MEASUREMENT = internalMeasurement { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = FREQUENCY_HISTOGRAM_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementId state = InternalMeasurement.State.PENDING } @@ -776,10 +750,7 @@ private val INTERNAL_SUCCEEDED_FREQUENCY_HISTOGRAM_MEASUREMENT = // Impression measurement private val BASE_IMPRESSION_MEASUREMENT = - BASE_MEASUREMENT.copy { - name = IMPRESSION_MEASUREMENT_NAME - measurementReferenceId = IMPRESSION_MEASUREMENT_REFERENCE_ID - } + BASE_MEASUREMENT.copy { name = IMPRESSION_MEASUREMENT_KEY.toName() } private val IMPRESSION_MEASUREMENT_SPEC = measurementSpec { measurementPublicKey = MEASUREMENT_CONSUMER_PUBLIC_KEY.toByteString() @@ -831,8 +802,8 @@ private val SUCCEEDED_IMPRESSION_MEASUREMENT = } private val INTERNAL_PENDING_IMPRESSION_MEASUREMENT = internalMeasurement { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = IMPRESSION_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId state = InternalMeasurement.State.PENDING } @@ -847,10 +818,7 @@ private val INTERNAL_SUCCEEDED_IMPRESSION_MEASUREMENT = // Watch Duration measurement private val BASE_WATCH_DURATION_MEASUREMENT = - BASE_MEASUREMENT.copy { - name = WATCH_DURATION_MEASUREMENT_NAME - measurementReferenceId = WATCH_DURATION_MEASUREMENT_REFERENCE_ID - } + BASE_MEASUREMENT.copy { name = WATCH_DURATION_MEASUREMENT_KEY.toName() } private val PENDING_WATCH_DURATION_MEASUREMENT = BASE_WATCH_DURATION_MEASUREMENT.copy { state = Measurement.State.COMPUTING } @@ -906,8 +874,8 @@ private val SUCCEEDED_WATCH_DURATION_MEASUREMENT = } private val INTERNAL_PENDING_WATCH_DURATION_MEASUREMENT = internalMeasurement { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = WATCH_DURATION_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementId state = InternalMeasurement.State.PENDING } private val INTERNAL_SUCCEEDED_WATCH_DURATION_MEASUREMENT = @@ -921,22 +889,22 @@ private val INTERNAL_SUCCEEDED_WATCH_DURATION_MEASUREMENT = // Weighted measurements private val WEIGHTED_REACH_MEASUREMENT = weightedMeasurement { - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID + measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId coefficient = 1 } private val WEIGHTED_FREQUENCY_HISTOGRAM_MEASUREMENT = weightedMeasurement { - measurementReferenceId = FREQUENCY_HISTOGRAM_MEASUREMENT_REFERENCE_ID + measurementReferenceId = FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementId coefficient = 1 } private val WEIGHTED_IMPRESSION_MEASUREMENT = weightedMeasurement { - measurementReferenceId = IMPRESSION_MEASUREMENT_REFERENCE_ID + measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId coefficient = 1 } private val WEIGHTED_WATCH_DURATION_MEASUREMENT = weightedMeasurement { - measurementReferenceId = WATCH_DURATION_MEASUREMENT_REFERENCE_ID + measurementReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementId coefficient = 1 } @@ -1097,7 +1065,7 @@ private val INTERNAL_PENDING_REACH_REPORT = internalReport { periodicTimeInterval = INTERNAL_PERIODIC_TIME_INTERVAL metrics.add(INTERNAL_REACH_METRIC) state = InternalReport.State.RUNNING - measurements.put(REACH_MEASUREMENT_REFERENCE_ID, INTERNAL_PENDING_REACH_MEASUREMENT) + measurements.put(REACH_MEASUREMENT_KEY.measurementId, INTERNAL_PENDING_REACH_MEASUREMENT) details = InternalReportKt.details { eventGroupFilters.putAll(EVENT_GROUP_FILTERS_MAP) } createTime = timestamp { seconds = 1000 } reportIdempotencyKey = REACH_REPORT_IDEMPOTENCY_KEY @@ -1105,7 +1073,7 @@ private val INTERNAL_PENDING_REACH_REPORT = internalReport { private val INTERNAL_SUCCEEDED_REACH_REPORT = INTERNAL_PENDING_REACH_REPORT.copy { state = InternalReport.State.SUCCEEDED - measurements.put(REACH_MEASUREMENT_REFERENCE_ID, INTERNAL_SUCCEEDED_REACH_MEASUREMENT) + measurements.put(REACH_MEASUREMENT_KEY.measurementId, INTERNAL_SUCCEEDED_REACH_MEASUREMENT) } // Internal reports of impression @@ -1115,7 +1083,10 @@ private val INTERNAL_PENDING_IMPRESSION_REPORT = internalReport { periodicTimeInterval = INTERNAL_PERIODIC_TIME_INTERVAL metrics.add(INTERNAL_IMPRESSION_METRIC) state = InternalReport.State.RUNNING - measurements.put(IMPRESSION_MEASUREMENT_REFERENCE_ID, INTERNAL_PENDING_IMPRESSION_MEASUREMENT) + measurements.put( + IMPRESSION_MEASUREMENT_KEY.measurementId, + INTERNAL_PENDING_IMPRESSION_MEASUREMENT + ) details = InternalReportKt.details { eventGroupFilters.putAll(EVENT_GROUP_FILTERS_MAP) } createTime = timestamp { seconds = 2000 } reportIdempotencyKey = IMPRESSION_REPORT_IDEMPOTENCY_KEY @@ -1123,7 +1094,10 @@ private val INTERNAL_PENDING_IMPRESSION_REPORT = internalReport { private val INTERNAL_SUCCEEDED_IMPRESSION_REPORT = INTERNAL_PENDING_IMPRESSION_REPORT.copy { state = InternalReport.State.SUCCEEDED - measurements.put(IMPRESSION_MEASUREMENT_REFERENCE_ID, INTERNAL_SUCCEEDED_IMPRESSION_MEASUREMENT) + measurements.put( + IMPRESSION_MEASUREMENT_KEY.measurementId, + INTERNAL_SUCCEEDED_IMPRESSION_MEASUREMENT + ) } // Internal reports of watch duration @@ -1134,7 +1108,7 @@ private val INTERNAL_PENDING_WATCH_DURATION_REPORT = internalReport { metrics.add(INTERNAL_WATCH_DURATION_METRIC) state = InternalReport.State.RUNNING measurements.put( - WATCH_DURATION_MEASUREMENT_REFERENCE_ID, + WATCH_DURATION_MEASUREMENT_KEY.measurementId, INTERNAL_PENDING_WATCH_DURATION_MEASUREMENT ) details = InternalReportKt.details { eventGroupFilters.putAll(EVENT_GROUP_FILTERS_MAP) } @@ -1145,7 +1119,7 @@ private val INTERNAL_SUCCEEDED_WATCH_DURATION_REPORT = INTERNAL_PENDING_WATCH_DURATION_REPORT.copy { state = InternalReport.State.SUCCEEDED measurements.put( - WATCH_DURATION_MEASUREMENT_REFERENCE_ID, + WATCH_DURATION_MEASUREMENT_KEY.measurementId, INTERNAL_SUCCEEDED_WATCH_DURATION_MEASUREMENT ) } @@ -1158,7 +1132,7 @@ private val INTERNAL_PENDING_FREQUENCY_HISTOGRAM_REPORT = internalReport { metrics.add(INTERNAL_FREQUENCY_HISTOGRAM_METRIC) state = InternalReport.State.RUNNING measurements.put( - FREQUENCY_HISTOGRAM_MEASUREMENT_REFERENCE_ID, + FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementId, INTERNAL_PENDING_FREQUENCY_HISTOGRAM_MEASUREMENT ) details = InternalReportKt.details { eventGroupFilters.putAll(EVENT_GROUP_FILTERS_MAP) } @@ -1169,7 +1143,7 @@ private val INTERNAL_SUCCEEDED_FREQUENCY_HISTOGRAM_REPORT = INTERNAL_PENDING_FREQUENCY_HISTOGRAM_REPORT.copy { state = InternalReport.State.SUCCEEDED measurements.put( - FREQUENCY_HISTOGRAM_MEASUREMENT_REFERENCE_ID, + FREQUENCY_HISTOGRAM_MEASUREMENT_KEY.measurementId, INTERNAL_SUCCEEDED_FREQUENCY_HISTOGRAM_MEASUREMENT ) } @@ -1429,33 +1403,25 @@ class ReportsServiceTest { captureFirst { runBlocking { verify(measurementsMock).createMeasurement(capture()) } } - val capturedMeasurement = capturedMeasurementRequest.measurement - val expectedMeasurement = - BASE_REACH_MEASUREMENT.copy { - dataProviders += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - measurementSpec = - signMeasurementSpec(REACH_ONLY_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - } - - assertThat(capturedMeasurement) + assertThat(capturedMeasurementRequest) .ignoringRepeatedFieldOrder() .ignoringFieldDescriptors( - Measurement.getDescriptor().findFieldByNumber(Measurement.MEASUREMENT_SPEC_FIELD_NUMBER), - Measurement.DataProviderEntry.Value.getDescriptor() - .findFieldByNumber(ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER), + MEASUREMENT_SPEC_FIELD_DESCRIPTOR, + ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, ) - .isEqualTo(expectedMeasurement) + .isEqualTo(REACH_MEASUREMENT_REQUEST) verifyMeasurementSpec( - capturedMeasurement.measurementSpec, + capturedMeasurementRequest.measurement.measurementSpec, MEASUREMENT_CONSUMER_CERTIFICATE, TRUSTED_MEASUREMENT_CONSUMER_ISSUER ) - val measurementSpec = MeasurementSpec.parseFrom(capturedMeasurement.measurementSpec.data) + val measurementSpec = + MeasurementSpec.parseFrom(capturedMeasurementRequest.measurement.measurementSpec.data) assertThat(measurementSpec).isEqualTo(REACH_ONLY_MEASUREMENT_SPEC) - val dataProvidersList = capturedMeasurement.dataProvidersList.sortedBy { it.key } + val dataProvidersList = + capturedMeasurementRequest.measurement.dataProvidersList.sortedBy { it.key } dataProvidersList.map { dataProviderEntry -> val signedRequisitionSpec = @@ -1478,13 +1444,7 @@ class ReportsServiceTest { internalMeasurementsMock, InternalMeasurementsCoroutineImplBase::createMeasurement ) - .isEqualTo( - internalMeasurement { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID - state = InternalMeasurement.State.PENDING - } - ) + .isEqualTo(INTERNAL_PENDING_REACH_MEASUREMENT) // Verify proto argument of InternalReportsCoroutineImplBase::createReport verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::createReport) @@ -1500,9 +1460,8 @@ class ReportsServiceTest { } measurements += InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId } } ) @@ -1599,33 +1558,26 @@ class ReportsServiceTest { captureFirst { runBlocking { verify(measurementsMock).createMeasurement(capture()) } } - val capturedMeasurement = capturedMeasurementRequest.measurement - val expectedMeasurement = - BASE_REACH_MEASUREMENT.copy { - dataProviders += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - measurementSpec = - signMeasurementSpec(REACH_ONLY_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - } - assertThat(capturedMeasurement) + assertThat(capturedMeasurementRequest) .ignoringRepeatedFieldOrder() .ignoringFieldDescriptors( - Measurement.getDescriptor().findFieldByNumber(Measurement.MEASUREMENT_SPEC_FIELD_NUMBER), - Measurement.DataProviderEntry.Value.getDescriptor() - .findFieldByNumber(ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER), + MEASUREMENT_SPEC_FIELD_DESCRIPTOR, + ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, ) - .isEqualTo(expectedMeasurement) + .isEqualTo(REACH_MEASUREMENT_REQUEST) verifyMeasurementSpec( - capturedMeasurement.measurementSpec, + capturedMeasurementRequest.measurement.measurementSpec, MEASUREMENT_CONSUMER_CERTIFICATE, TRUSTED_MEASUREMENT_CONSUMER_ISSUER ) - val measurementSpec = MeasurementSpec.parseFrom(capturedMeasurement.measurementSpec.data) + val measurementSpec = + MeasurementSpec.parseFrom(capturedMeasurementRequest.measurement.measurementSpec.data) assertThat(measurementSpec).isEqualTo(REACH_ONLY_MEASUREMENT_SPEC) - val dataProvidersList = capturedMeasurement.dataProvidersList.sortedBy { it.key } + val dataProvidersList = + capturedMeasurementRequest.measurement.dataProvidersList.sortedBy { it.key } dataProvidersList.map { dataProviderEntry -> val signedRequisitionSpec = @@ -1648,13 +1600,7 @@ class ReportsServiceTest { internalMeasurementsMock, InternalMeasurementsCoroutineImplBase::createMeasurement ) - .isEqualTo( - internalMeasurement { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID - state = InternalMeasurement.State.PENDING - } - ) + .isEqualTo(INTERNAL_PENDING_REACH_MEASUREMENT) // Verify proto argument of InternalReportsCoroutineImplBase::createReport verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::createReport) @@ -1670,9 +1616,8 @@ class ReportsServiceTest { } measurements += InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId } } ) @@ -1755,33 +1700,25 @@ class ReportsServiceTest { captureFirst { runBlocking { verify(measurementsMock).createMeasurement(capture()) } } - val capturedMeasurement = capturedMeasurementRequest.measurement - val expectedMeasurement = - BASE_REACH_MEASUREMENT.copy { - dataProviders += - DATA_PROVIDER_KEYS_IN_SET_OPERATION.map { DATA_PROVIDER_ENTRIES.getValue(it) } - measurementSpec = - signMeasurementSpec(REACH_ONLY_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) - } - - assertThat(capturedMeasurement) + assertThat(capturedMeasurementRequest) .ignoringRepeatedFieldOrder() .ignoringFieldDescriptors( - Measurement.getDescriptor().findFieldByNumber(Measurement.MEASUREMENT_SPEC_FIELD_NUMBER), - Measurement.DataProviderEntry.Value.getDescriptor() - .findFieldByNumber(ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER), + MEASUREMENT_SPEC_FIELD_DESCRIPTOR, + ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR, ) - .isEqualTo(expectedMeasurement) + .isEqualTo(REACH_MEASUREMENT_REQUEST) verifyMeasurementSpec( - capturedMeasurement.measurementSpec, + capturedMeasurementRequest.measurement.measurementSpec, MEASUREMENT_CONSUMER_CERTIFICATE, TRUSTED_MEASUREMENT_CONSUMER_ISSUER ) - val measurementSpec = MeasurementSpec.parseFrom(capturedMeasurement.measurementSpec.data) + val measurementSpec = + MeasurementSpec.parseFrom(capturedMeasurementRequest.measurement.measurementSpec.data) assertThat(measurementSpec).isEqualTo(REACH_ONLY_MEASUREMENT_SPEC) - val dataProvidersList = capturedMeasurement.dataProvidersList.sortedBy { it.key } + val dataProvidersList = + capturedMeasurementRequest.measurement.dataProvidersList.sortedBy { it.key } dataProvidersList.map { dataProviderEntry -> val signedRequisitionSpec = @@ -1804,13 +1741,7 @@ class ReportsServiceTest { internalMeasurementsMock, InternalMeasurementsCoroutineImplBase::createMeasurement ) - .isEqualTo( - internalMeasurement { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID - state = InternalMeasurement.State.PENDING - } - ) + .isEqualTo(INTERNAL_PENDING_REACH_MEASUREMENT) // Verify proto argument of InternalReportsCoroutineImplBase::createReport verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::createReport) @@ -1826,9 +1757,8 @@ class ReportsServiceTest { } measurements += InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId } } ) @@ -1854,11 +1784,11 @@ class ReportsServiceTest { source.metrics[0].namedSetOperationsList[0].measurementCalculationsList[0].copy { weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID + measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId coefficient = -1 } weightedMeasurements += weightedMeasurement { - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID_2 + measurementReferenceId = REACH_MEASUREMENT_KEY_2.measurementId coefficient = 1 } } @@ -1968,15 +1898,13 @@ class ReportsServiceTest { } measurements += InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId } measurements += InternalCreateReportRequestKt.measurementKey { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID_2 + measurementConsumerReferenceId = REACH_MEASUREMENT_KEY_2.measurementConsumerId + measurementReferenceId = REACH_MEASUREMENT_KEY_2.measurementId } } ) @@ -2741,8 +2669,8 @@ class ReportsServiceTest { @Test fun `createReport throws exception when internal createReport throws exception`() = runBlocking { - whenever(internalReportsMock.createReport(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) + val status = Status.INVALID_ARGUMENT.withDescription("Bad CreateReport request") + whenever(internalReportsMock.createReport(any())).thenThrow(StatusRuntimeException(status)) val request = createReportRequest { parent = MEASUREMENT_CONSUMERS.values.first().name @@ -2750,13 +2678,15 @@ class ReportsServiceTest { } val exception = - assertFailsWith(Exception::class) { + assertFailsWith { withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { runBlocking { service.createReport(request) } } } - val expectedExceptionDescription = "Unable to create a report in the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) + assertThat(exception.cause).isInstanceOf(StatusException::class.java) + val actualStatus = (exception.cause as StatusException).status + assertThat(actualStatus.code).isEqualTo(status.code) + assertThat(actualStatus.description).isEqualTo(status.description) } @Test @@ -2776,9 +2706,7 @@ class ReportsServiceTest { runBlocking { service.createReport(request) } } } - val expectedExceptionDescription = - "Unable to create the measurement [$REACH_MEASUREMENT_NAME]." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) + assertThat(exception.message).contains(REACH_MEASUREMENT_CREATE_REQUEST_ID) } @Test @@ -2799,7 +2727,7 @@ class ReportsServiceTest { } } val expectedExceptionDescription = - "Unable to create the measurement [$REACH_MEASUREMENT_NAME] in the reporting database." + "Unable to create the measurement [${REACH_MEASUREMENT_KEY.toName()}] in the reporting database." assertThat(exception.message).isEqualTo(expectedExceptionDescription) } @@ -3293,7 +3221,7 @@ class ReportsServiceTest { } } val expectedExceptionDescription = - "Unable to retrieve the measurement [$REACH_MEASUREMENT_NAME]." + "Unable to retrieve the measurement [${REACH_MEASUREMENT_KEY.toName()}]." assertThat(exception.message).isEqualTo(expectedExceptionDescription) } @@ -3312,7 +3240,7 @@ class ReportsServiceTest { } } val expectedExceptionDescription = - "Unable to update the measurement [$REACH_MEASUREMENT_NAME] in the reporting database." + "Unable to update the measurement [${REACH_MEASUREMENT_KEY.toName()}] in the reporting database." assertThat(exception.message).isEqualTo(expectedExceptionDescription) } @@ -3348,7 +3276,7 @@ class ReportsServiceTest { } } val expectedExceptionDescription = - "Unable to update the measurement [$REACH_MEASUREMENT_NAME] in the reporting database." + "Unable to update the measurement [${REACH_MEASUREMENT_KEY.toName()}] in the reporting database." assertThat(exception.message).isEqualTo(expectedExceptionDescription) } @@ -3483,7 +3411,7 @@ class ReportsServiceTest { } ) verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = REACH_MEASUREMENT_NAME }) + .isEqualTo(getMeasurementRequest { name = REACH_MEASUREMENT_KEY.toName() }) verifyProtoArgument(internalReportsMock, ReportsCoroutineImplBase::getReport) .isEqualTo( getInternalReportRequest { @@ -3538,16 +3466,15 @@ class ReportsServiceTest { } ) verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = REACH_MEASUREMENT_NAME }) + .isEqualTo(getMeasurementRequest { name = REACH_MEASUREMENT_KEY.toName() }) verifyProtoArgument( internalMeasurementsMock, InternalMeasurementsCoroutineImplBase::setMeasurementFailure ) .isEqualTo( setMeasurementFailureRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId failure = InternalMeasurementKt.failure { reason = InternalMeasurement.Failure.Reason.REQUISITION_REFUSED @@ -3595,7 +3522,7 @@ class ReportsServiceTest { } ) verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = REACH_MEASUREMENT_NAME }) + .isEqualTo(getMeasurementRequest { name = REACH_MEASUREMENT_KEY.toName() }) verifyProtoArgument( internalMeasurementsMock, InternalMeasurementsCoroutineImplBase::setMeasurementResult @@ -3603,9 +3530,8 @@ class ReportsServiceTest { .usingDoubleTolerance(1e-12) .isEqualTo( setMeasurementResultRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = REACH_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = REACH_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = REACH_MEASUREMENT_KEY.measurementId this.result = InternalMeasurementKt.result { reach = InternalMeasurementResultKt.reach { value = REACH_VALUE } @@ -3655,15 +3581,15 @@ class ReportsServiceTest { } ) verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = IMPRESSION_MEASUREMENT_NAME }) + .isEqualTo(getMeasurementRequest { name = IMPRESSION_MEASUREMENT_KEY.toName() }) verifyProtoArgument( internalMeasurementsMock, InternalMeasurementsCoroutineImplBase::setMeasurementResult ) .isEqualTo( setMeasurementResultRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = IMPRESSION_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId this.result = InternalMeasurementKt.result { impression = InternalMeasurementResultKt.impression { value = TOTAL_IMPRESSION_VALUE } @@ -3710,15 +3636,15 @@ class ReportsServiceTest { } ) verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = WATCH_DURATION_MEASUREMENT_NAME }) + .isEqualTo(getMeasurementRequest { name = WATCH_DURATION_MEASUREMENT_KEY.toName() }) verifyProtoArgument( internalMeasurementsMock, InternalMeasurementsCoroutineImplBase::setMeasurementResult ) .isEqualTo( setMeasurementResultRequest { - measurementConsumerReferenceId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = WATCH_DURATION_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = WATCH_DURATION_MEASUREMENT_KEY.measurementId this.result = InternalMeasurementKt.result { watchDuration = @@ -3806,7 +3732,7 @@ class ReportsServiceTest { verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) .comparingExpectedFieldsOnly() - .isEqualTo(getMeasurementRequest { name = WATCH_DURATION_MEASUREMENT_NAME }) + .isEqualTo(getMeasurementRequest { name = WATCH_DURATION_MEASUREMENT_KEY.toName() }) val internalReportCaptor: KArgumentCaptor = argumentCaptor() verifyBlocking(internalReportsMock, times(2)) { getReport(internalReportCaptor.capture()) } @@ -3842,16 +3768,15 @@ class ReportsServiceTest { assertThat(report).isEqualTo(SUCCEEDED_IMPRESSION_REPORT) verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = IMPRESSION_MEASUREMENT_NAME }) + .isEqualTo(getMeasurementRequest { name = IMPRESSION_MEASUREMENT_KEY.toName() }) verifyProtoArgument( internalMeasurementsMock, InternalMeasurementsCoroutineImplBase::setMeasurementResult ) .isEqualTo( setMeasurementResultRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = IMPRESSION_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId this.result = InternalMeasurementKt.result { impression = @@ -3906,16 +3831,15 @@ class ReportsServiceTest { assertThat(report).isEqualTo(PENDING_IMPRESSION_REPORT.copy { state = Report.State.FAILED }) verifyProtoArgument(measurementsMock, MeasurementsCoroutineImplBase::getMeasurement) - .isEqualTo(getMeasurementRequest { name = IMPRESSION_MEASUREMENT_NAME }) + .isEqualTo(getMeasurementRequest { name = IMPRESSION_MEASUREMENT_KEY.toName() }) verifyProtoArgument( internalMeasurementsMock, InternalMeasurementsCoroutineImplBase::setMeasurementFailure ) .isEqualTo( setMeasurementFailureRequest { - measurementConsumerReferenceId = - MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - measurementReferenceId = IMPRESSION_MEASUREMENT_REFERENCE_ID + measurementConsumerReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementConsumerId + measurementReferenceId = IMPRESSION_MEASUREMENT_KEY.measurementId failure = InternalMeasurementKt.failure { reason = InternalMeasurement.Failure.Reason.REQUISITION_REFUSED @@ -4114,6 +4038,14 @@ class ReportsServiceTest { } ) } + + companion object { + private val MEASUREMENT_SPEC_FIELD_DESCRIPTOR = + Measurement.getDescriptor().findFieldByNumber(Measurement.MEASUREMENT_SPEC_FIELD_NUMBER) + private val ENCRYPTED_REQUISITION_SPEC_FIELD_DESCRIPTOR = + Measurement.DataProviderEntry.Value.getDescriptor() + .findFieldByNumber(ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER) + } } private fun EventGroupKey.toInternal(): InternalReportingSet.EventGroupKey { diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt index 464f0ef3f25..e0c7e588668 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt @@ -39,8 +39,10 @@ import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 +import org.mockito.ArgumentMatcher import org.mockito.kotlin.KArgumentCaptor import org.mockito.kotlin.any +import org.mockito.kotlin.argThat import org.mockito.kotlin.argumentCaptor import org.mockito.kotlin.doReturn import org.mockito.kotlin.eq @@ -164,6 +166,7 @@ import org.wfanet.measurement.internal.reporting.v2.reportingSet as internalRepo import org.wfanet.measurement.internal.reporting.v2.streamMetricsRequest import org.wfanet.measurement.internal.reporting.v2.timeInterval as internalTimeInterval import org.wfanet.measurement.reporting.service.api.InMemoryEncryptionKeyPairStore +import org.wfanet.measurement.reporting.service.api.v2alpha.RequestIdMatcher.Companion.requestIdEq import org.wfanet.measurement.reporting.v2alpha.ListMetricsPageTokenKt.previousPageEnd import org.wfanet.measurement.reporting.v2alpha.ListMetricsRequest import org.wfanet.measurement.reporting.v2alpha.Metric @@ -742,9 +745,6 @@ private val REQUESTING_UNION_ALL_REACH_MEASUREMENT = }, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE ) - - measurementReferenceId = - INTERNAL_PENDING_UNION_ALL_REACH_MEASUREMENT.cmmsCreateMeasurementRequestId } private val REQUESTING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT = BASE_MEASUREMENT.copy { @@ -755,9 +755,6 @@ private val REQUESTING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT = UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE ) - - measurementReferenceId = - INTERNAL_PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.cmmsCreateMeasurementRequestId } private val PENDING_UNION_ALL_REACH_MEASUREMENT = @@ -840,9 +837,6 @@ private val REQUESTING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT = SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT_SPEC, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE ) - - measurementReferenceId = - INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.cmmsCreateMeasurementRequestId } private val PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT = @@ -897,9 +891,6 @@ private val REQUESTING_UNION_ALL_WATCH_DURATION_MEASUREMENT = }, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE ) - - measurementReferenceId = - INTERNAL_PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT.cmmsCreateMeasurementRequestId } private val PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT = @@ -1367,19 +1358,29 @@ class MetricsServiceTest { .thenReturn(pendingMeasurement) } - onBlocking { createMeasurement(any()) } - .thenAnswer { - val request = it.arguments[0] as CreateMeasurementRequest - mapOf( - PENDING_UNION_ALL_REACH_MEASUREMENT.measurementReferenceId to - PENDING_UNION_ALL_REACH_MEASUREMENT, - PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.measurementReferenceId to - PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT, - PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.measurementReferenceId to - PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT, + onBlocking { + createMeasurement( + requestIdEq(INTERNAL_PENDING_UNION_ALL_REACH_MEASUREMENT.cmmsCreateMeasurementRequestId) + ) + } + .thenReturn(PENDING_UNION_ALL_REACH_MEASUREMENT) + onBlocking { + createMeasurement( + requestIdEq( + INTERNAL_PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT + .cmmsCreateMeasurementRequestId ) - .getValue(request.measurement.measurementReferenceId) + ) } + .thenReturn(PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT) + onBlocking { + createMeasurement( + requestIdEq( + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.cmmsCreateMeasurementRequestId + ) + ) + } + .thenReturn(PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT) } private val measurementConsumersMock: @@ -1494,16 +1495,19 @@ class MetricsServiceTest { assertThat(capturedMeasurementRequests) .ignoringRepeatedFieldOrder() .ignoringFieldDescriptors( - Measurement.getDescriptor().findFieldByNumber(Measurement.MEASUREMENT_SPEC_FIELD_NUMBER), - Measurement.DataProviderEntry.Value.getDescriptor() - .findFieldByNumber( - Measurement.DataProviderEntry.Value.ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER - ), + MEASUREMENT_SPEC_FIELD, + ENCRYPTED_REQUISITION_SPEC_FIELD, ) .containsExactly( - createMeasurementRequest { measurement = REQUESTING_UNION_ALL_REACH_MEASUREMENT }, + createMeasurementRequest { + measurement = REQUESTING_UNION_ALL_REACH_MEASUREMENT + requestId = INTERNAL_PENDING_UNION_ALL_REACH_MEASUREMENT.cmmsCreateMeasurementRequestId + }, createMeasurementRequest { measurement = REQUESTING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT + requestId = + INTERNAL_PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT + .cmmsCreateMeasurementRequestId }, ) @@ -1606,15 +1610,14 @@ class MetricsServiceTest { assertThat(capturedMeasurementRequests) .ignoringRepeatedFieldOrder() .ignoringFieldDescriptors( - Measurement.getDescriptor().findFieldByNumber(Measurement.MEASUREMENT_SPEC_FIELD_NUMBER), - Measurement.DataProviderEntry.Value.getDescriptor() - .findFieldByNumber( - Measurement.DataProviderEntry.Value.ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER - ), + MEASUREMENT_SPEC_FIELD, + ENCRYPTED_REQUISITION_SPEC_FIELD, ) .containsExactly( createMeasurementRequest { measurement = REQUESTING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT + requestId = + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.cmmsCreateMeasurementRequestId }, ) @@ -1798,14 +1801,15 @@ class MetricsServiceTest { assertThat(capturedMeasurementRequests) .ignoringRepeatedFieldOrder() .ignoringFieldDescriptors( - Measurement.getDescriptor().findFieldByNumber(Measurement.MEASUREMENT_SPEC_FIELD_NUMBER), - Measurement.DataProviderEntry.Value.getDescriptor() - .findFieldByNumber( - Measurement.DataProviderEntry.Value.ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER - ), + MEASUREMENT_SPEC_FIELD, + ENCRYPTED_REQUISITION_SPEC_FIELD, ) .containsExactly( - createMeasurementRequest { measurement = requestingSinglePublisherImpressionMeasurement }, + createMeasurementRequest { + measurement = requestingSinglePublisherImpressionMeasurement + requestId = + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.cmmsCreateMeasurementRequestId + }, ) capturedMeasurementRequests.forEach { capturedMeasurementRequest -> @@ -1893,16 +1897,19 @@ class MetricsServiceTest { assertThat(capturedMeasurementRequests) .ignoringRepeatedFieldOrder() .ignoringFieldDescriptors( - Measurement.getDescriptor().findFieldByNumber(Measurement.MEASUREMENT_SPEC_FIELD_NUMBER), - Measurement.DataProviderEntry.Value.getDescriptor() - .findFieldByNumber( - Measurement.DataProviderEntry.Value.ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER - ), + MEASUREMENT_SPEC_FIELD, + ENCRYPTED_REQUISITION_SPEC_FIELD, ) .containsExactly( - createMeasurementRequest { measurement = REQUESTING_UNION_ALL_REACH_MEASUREMENT }, + createMeasurementRequest { + measurement = REQUESTING_UNION_ALL_REACH_MEASUREMENT + requestId = INTERNAL_PENDING_UNION_ALL_REACH_MEASUREMENT.cmmsCreateMeasurementRequestId + }, createMeasurementRequest { measurement = REQUESTING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT + requestId = + INTERNAL_PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT + .cmmsCreateMeasurementRequestId }, ) @@ -2702,19 +2709,24 @@ class MetricsServiceTest { assertThat(capturedMeasurementRequests) .ignoringRepeatedFieldOrder() .ignoringFieldDescriptors( - Measurement.getDescriptor().findFieldByNumber(Measurement.MEASUREMENT_SPEC_FIELD_NUMBER), - Measurement.DataProviderEntry.Value.getDescriptor() - .findFieldByNumber( - Measurement.DataProviderEntry.Value.ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER - ), + MEASUREMENT_SPEC_FIELD, + ENCRYPTED_REQUISITION_SPEC_FIELD, ) .containsExactly( - createMeasurementRequest { measurement = REQUESTING_UNION_ALL_REACH_MEASUREMENT }, + createMeasurementRequest { + measurement = REQUESTING_UNION_ALL_REACH_MEASUREMENT + requestId = INTERNAL_PENDING_UNION_ALL_REACH_MEASUREMENT.cmmsCreateMeasurementRequestId + }, createMeasurementRequest { measurement = REQUESTING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT + requestId = + INTERNAL_PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT + .cmmsCreateMeasurementRequestId }, createMeasurementRequest { measurement = REQUESTING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT + requestId = + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.cmmsCreateMeasurementRequestId }, ) @@ -4324,6 +4336,30 @@ class MetricsServiceTest { assertThat(exception.status.description) .isEqualTo("At most $MAX_BATCH_SIZE metrics can be supported in a batch.") } + + companion object { + private val MEASUREMENT_SPEC_FIELD = + Measurement.getDescriptor().findFieldByNumber(Measurement.MEASUREMENT_SPEC_FIELD_NUMBER) + private val ENCRYPTED_REQUISITION_SPEC_FIELD = + Measurement.DataProviderEntry.Value.getDescriptor() + .findFieldByNumber( + Measurement.DataProviderEntry.Value.ENCRYPTED_REQUISITION_SPEC_FIELD_NUMBER + ) + } +} + +private class RequestIdMatcher(private val expected: String) : + ArgumentMatcher { + + override fun matches(actual: CreateMeasurementRequest?): Boolean { + return actual?.requestId == expected + } + + companion object { + fun requestIdEq(expected: String): CreateMeasurementRequest { + return argThat(RequestIdMatcher(expected)) + } + } } private fun EventGroupKey.toInternal(): InternalReportingSet.Primitive.EventGroupKey {