From 3f2699d33d27b53c81379e612cbc29f498017c34 Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Mon, 13 Mar 2023 19:23:57 +0000 Subject: [PATCH 01/12] Add getMetric and batchGetMetric. --- .../service/api/v2alpha/MetricsService.kt | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 01045e9b85f..a3b4d377974 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -123,7 +123,10 @@ import org.wfanet.measurement.internal.reporting.v2.metricSpec as internalMetric import org.wfanet.measurement.reporting.service.api.EncryptionKeyPairStore import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsRequest import org.wfanet.measurement.reporting.v2alpha.BatchCreateMetricsResponse +import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsRequest +import org.wfanet.measurement.reporting.v2alpha.BatchGetMetricsResponse import org.wfanet.measurement.reporting.v2alpha.CreateMetricRequest +import org.wfanet.measurement.reporting.v2alpha.GetMetricRequest import org.wfanet.measurement.reporting.v2alpha.ListMetricsRequest import org.wfanet.measurement.reporting.v2alpha.ListMetricsResponse import org.wfanet.measurement.reporting.v2alpha.Metric @@ -137,6 +140,7 @@ import org.wfanet.measurement.reporting.v2alpha.MetricResultKt.watchDurationResu import org.wfanet.measurement.reporting.v2alpha.MetricSpec import org.wfanet.measurement.reporting.v2alpha.MetricsGrpcKt.MetricsCoroutineImplBase import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsResponse +import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsResponse import org.wfanet.measurement.reporting.v2alpha.listMetricsResponse import org.wfanet.measurement.reporting.v2alpha.metric import org.wfanet.measurement.reporting.v2alpha.metricResult @@ -853,6 +857,118 @@ class MetricsService( } } + override suspend fun getMetric(request: GetMetricRequest): Metric { + val metricKey = + grpcRequireNotNull(MetricKey.fromName(request.name)) { + "Metric name is either unspecified or invalid." + } + + val principal: ReportingPrincipal = principalFromCurrentContext + when (principal) { + is MeasurementConsumerPrincipal -> { + if (metricKey.cmmsMeasurementConsumerId != principal.resourceKey.measurementConsumerId) { + failGrpc(Status.PERMISSION_DENIED) { + "Cannot get a Metric for another MeasurementConsumer." + } + } + } + } + + val internalMetric: InternalMetric = + getInternalMetric(metricKey.cmmsMeasurementConsumerId, apiIdToExternalId(metricKey.metricId)) + + // Early exit when the metric is at a terminal state. + if (internalMetric.state != InternalMetric.State.RUNNING) { + return internalMetric.toMetric() + } + + val toBeSyncedInternalMeasurements: List = + internalMetric.weightedMeasurementsList.map { weightedMeasurement -> + weightedMeasurement.measurement + } + + measurementSupplier.syncInternalMeasurements( + toBeSyncedInternalMeasurements, + principal.config.apiKey, + principal, + ) + + return getInternalMetric( + metricKey.cmmsMeasurementConsumerId, + apiIdToExternalId(metricKey.metricId) + ) + .toMetric() + } + + override suspend fun batchGetMetrics(request: BatchGetMetricsRequest): BatchGetMetricsResponse { + grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { + "Parent is either unspecified or invalid." + } + + val principal: ReportingPrincipal = principalFromCurrentContext + + when (principal) { + is MeasurementConsumerPrincipal -> { + if (request.parent != principal.resourceKey.toName()) { + failGrpc(Status.PERMISSION_DENIED) { + "Cannot get Metrics for another MeasurementConsumer." + } + } + } + } + + grpcRequire(request.namesList.isNotEmpty()) { "No metric name is provided." } + grpcRequire(request.namesList.size <= MAX_BATCH_SIZE) { + "At most $MAX_BATCH_SIZE metrics can be supported in a batch." + } + + val externalMetricIds: List = + request.namesList.map { metricName -> + val metricKey = + grpcRequireNotNull(MetricKey.fromName(metricName)) { + "Metric name is either unspecified or invalid." + } + apiIdToExternalId(metricKey.metricId) + } + + val internalMetrics: List = + batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, externalMetricIds) + + val toBeSyncedInternalMeasurements: List = + internalMetrics + .filter { internalMetric -> internalMetric.state == InternalMetric.State.RUNNING } + .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } + .map { weightedMeasurement -> weightedMeasurement.measurement } + + measurementSupplier.syncInternalMeasurements( + toBeSyncedInternalMeasurements, + principal.config.apiKey, + principal, + ) + + return batchGetMetricsResponse { + metrics += + batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, externalMetricIds) + .map { it.toMetric() } + } + } + + private suspend fun getInternalMetric( + cmmsMeasurementConsumerId: String, + externalMetricId: Long, + ): org.wfanet.measurement.internal.reporting.v2alpha.Metric { + return try { + internalMetricsStub.getMetric( + internalGetMetricRequest { + this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId + this.externalMetricId = externalMetricId + } + ) + } catch (e: StatusException) { + throw Exception("Unable to get the metric from the reporting database.", e) + } + } + override suspend fun listMetrics(request: ListMetricsRequest): ListMetricsResponse { val listMetricsPageToken: ListMetricsPageToken = request.toListMetricsPageToken() From a8d43c27ddb12c6d32e297ec17f5eb62f8654062 Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Tue, 14 Mar 2023 00:12:00 +0000 Subject: [PATCH 02/12] Add one unit test for getMetric. --- .../service/api/v2alpha/MetricsServiceTest.kt | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt index 2822768ece8..da37d6be98a 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt @@ -177,6 +177,7 @@ import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsRequest import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsResponse import org.wfanet.measurement.reporting.v2alpha.copy import org.wfanet.measurement.reporting.v2alpha.createMetricRequest +import org.wfanet.measurement.reporting.v2alpha.getMetricRequest import org.wfanet.measurement.reporting.v2alpha.listMetricsRequest import org.wfanet.measurement.reporting.v2alpha.listMetricsResponse import org.wfanet.measurement.reporting.v2alpha.metric @@ -3855,6 +3856,50 @@ class MetricsServiceTest { assertThat(exception).hasMessageThat().contains(AGGREGATOR_CERTIFICATE.name) } + + @Test + fun `getMetric returns the metric with SUCCEEDED when the metric is already succeeded`() = + runBlocking { + whenever(internalMetricsMock.getMetric(any())) + .thenReturn(INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC) + + val request = getMetricRequest { name = SUCCEEDED_INCREMENTAL_REACH_METRIC.name } + + val result = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.getMetric(request) } + } + + // Verify proto argument of internal MetricsCoroutineImplBase::getMetric + val getInternalMetricCaptor: KArgumentCaptor = argumentCaptor() + verifyBlocking(internalMetricsMock, times(1)) { getMetric(getInternalMetricCaptor.capture()) } + val capturedInternalGetMetricRequests = getInternalMetricCaptor.allValues + assertThat(capturedInternalGetMetricRequests) + .containsExactly( + internalGetMetricRequest { + cmmsMeasurementConsumerId = + INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC.cmmsMeasurementConsumerId + externalMetricId = INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC.externalMetricId + } + ) + + // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults + val batchSetMeasurementResultsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementResults(batchSetMeasurementResultsCaptor.capture()) + } + + // Verify proto argument of internal + // MeasurementsCoroutineImplBase::batchSetMeasurementFailures + val batchSetMeasurementFailuresCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementFailures(batchSetMeasurementFailuresCaptor.capture()) + } + + assertThat(result).isEqualTo(SUCCEEDED_INCREMENTAL_REACH_METRIC) + } } private fun EventGroupKey.toInternal(): InternalReportingSet.Primitive.EventGroupKey { From 5cc232f079d0504344680e6b711daa9ba8c7c86c Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Wed, 15 Mar 2023 01:27:11 +0000 Subject: [PATCH 03/12] Add unit tests. --- .../service/api/v2alpha/MetricsServiceTest.kt | 298 +++++++++++++++++- 1 file changed, 297 insertions(+), 1 deletion(-) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt index da37d6be98a..e3c54f8743b 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt @@ -70,6 +70,7 @@ import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt import org.wfanet.measurement.api.v2alpha.MeasurementKey import org.wfanet.measurement.api.v2alpha.MeasurementKt import org.wfanet.measurement.api.v2alpha.MeasurementKt.failure +import org.wfanet.measurement.api.v2alpha.MeasurementKt.resultPair import org.wfanet.measurement.api.v2alpha.MeasurementSpec import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt @@ -212,6 +213,9 @@ private const val MAXIMUM_WATCH_DURATION_PER_USER = 4000 private const val DIFFERENTIAL_PRIVACY_DELTA = 1e-12 +private const val MAXIMUM_FREQUENCY_PER_USER = 10 +private const val MAXIMUM_WATCH_DURATION_PER_USER = 300 + private const val SECURE_RANDOM_OUTPUT_INT = 0 private const val SECURE_RANDOM_OUTPUT_LONG = 0L @@ -601,7 +605,7 @@ private val FREQUENCY_DISTRIBUTION = mapOf(1L to 1.0 / 6, 2L to 2.0 / 6, 3L to 3 private const val FIRST_PUBLISHER_IMPRESSION_VALUE = 100L private val IMPRESSION_VALUES = listOf(100L, 150L) private val TOTAL_IMPRESSION_VALUE = IMPRESSION_VALUES.sum() -private val WATCH_DURATION_SECOND_LIST = listOf(100L, 200L) +private val WATCH_DURATION_SECOND_LIST = listOf(100L, 200L, 300L) private val WATCH_DURATION_LIST = WATCH_DURATION_SECOND_LIST.map { duration { seconds = it } } private val TOTAL_WATCH_DURATION = duration { seconds = WATCH_DURATION_SECOND_LIST.sum() } @@ -672,6 +676,40 @@ private val INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT = } } +// Internal cross-publisher watch duration measurements +private val INTERNAL_REQUESTING_UNION_ALL_WATCH_DURATION_MEASUREMENT = internalMeasurement { + cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId + timeInterval = INTERNAL_TIME_INTERVAL + primitiveReportingSetBases += primitiveReportingSetBasis { + externalReportingSetId = INTERNAL_UNION_ALL_REPORTING_SET.externalReportingSetId + filters += listOf(METRIC_FILTER, PRIMITIVE_REPORTING_SET_FILTER) + } +} + +private val INTERNAL_PENDING_NOT_CREATED_UNION_ALL_WATCH_DURATION_MEASUREMENT = + INTERNAL_REQUESTING_UNION_ALL_WATCH_DURATION_MEASUREMENT.copy { + externalMeasurementId = 414L + cmmsCreateMeasurementRequestId = "UNION_ALL_WATCH_DURATION_MEASUREMENT" + state = InternalMeasurement.State.PENDING + } + +private val INTERNAL_PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT = + INTERNAL_PENDING_NOT_CREATED_UNION_ALL_WATCH_DURATION_MEASUREMENT.copy { + cmmsMeasurementId = externalIdToApiId(404L) + } + +private val INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT = + INTERNAL_PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT.copy { + state = InternalMeasurement.State.SUCCEEDED + result = + InternalMeasurementKt.result { + watchDuration = + InternalMeasurementKt.ResultKt.watchDuration { value = TOTAL_WATCH_DURATION } + } + } + +// CMMs measurements + // CMMs incremental reach measurements private val UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT_SPEC = measurementSpec { measurementPublicKey = MEASUREMENT_CONSUMER_PUBLIC_KEY.toByteString() @@ -1004,7 +1042,64 @@ private val INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC = } } +private val INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC = + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.copy { state = InternalMetric.State.FAILED } + +// Internal Cross Publisher Watch Duration Metrics +private val INTERNAL_REQUESTING_CROSS_PUBLISHER_WATCH_DURATION_METRIC = internalMetric { + cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId + metricIdempotencyKey = CROSS_PUBLISHER_WATCH_DURATION_METRIC_IDEMPOTENCY_KEY + externalReportingSetId = INTERNAL_UNION_ALL_REPORTING_SET.externalReportingSetId + timeInterval = INTERNAL_TIME_INTERVAL + metricSpec = internalMetricSpec { + watchDuration = + InternalMetricSpecKt.watchDurationParams { + maximumWatchDurationPerUser = MAXIMUM_WATCH_DURATION_PER_USER + } + } + weightedMeasurements += weightedMeasurement { + weight = 1 + measurement = INTERNAL_REQUESTING_UNION_ALL_WATCH_DURATION_MEASUREMENT + } + details = InternalMetricKt.details { filters += listOf(METRIC_FILTER) } +} + +private val INTERNAL_PENDING_INITIAL_CROSS_PUBLISHER_WATCH_DURATION_METRIC = + INTERNAL_REQUESTING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.copy { + externalMetricId = 334L + createTime = Instant.now().toProtoTime() + weightedMeasurements.clear() + weightedMeasurements += weightedMeasurement { + weight = 1 + measurement = INTERNAL_PENDING_NOT_CREATED_UNION_ALL_WATCH_DURATION_MEASUREMENT + } + state = InternalMetric.State.RUNNING + } + +private val INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC = + INTERNAL_PENDING_INITIAL_CROSS_PUBLISHER_WATCH_DURATION_METRIC.copy { + weightedMeasurements.clear() + weightedMeasurements += weightedMeasurement { + weight = 1 + measurement = INTERNAL_PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT + } + } + +private val INTERNAL_SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC = + INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.copy { + state = InternalMetric.State.SUCCEEDED + details = + InternalMetricKt.details { + filters += this@copy.details.filtersList + result = internalMetricResult { + watchDuration = + InternalMetricResultKt.doubleResult { value = TOTAL_WATCH_DURATION.seconds.toDouble() } + } + } + } + // Public Metrics + // Incremental reach metrics private val REQUESTING_INCREMENTAL_REACH_METRIC = metric { reportingSet = INTERNAL_INCREMENTAL_REPORTING_SET.resourceName @@ -1080,6 +1175,41 @@ private val PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC = createTime = INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.createTime } +private val FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC = + PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.copy { state = Metric.State.FAILED } + +// Cross publisher watch duration metrics +private val REQUESTING_CROSS_PUBLISHER_WATCH_DURATION_METRIC = metric { + reportingSet = INTERNAL_UNION_ALL_REPORTING_SET.resourceName + timeInterval = TIME_INTERVAL + metricSpec = WATCH_DURATION_METRIC_SPEC + filters += INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.details.filtersList +} + +private val PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC = + REQUESTING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.copy { + name = + MetricKey( + MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, + externalIdToApiId(INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.externalMetricId) + ) + .toName() + state = Metric.State.RUNNING + createTime = INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.createTime + } + +private val SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC = + PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.copy { + state = Metric.State.SUCCEEDED + result = metricResult { + watchDuration = doubleResult { + value = + INTERNAL_SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC.details.result.watchDuration + .value + } + } + } + @RunWith(JUnit4::class) class MetricsServiceTest { @@ -3900,6 +4030,172 @@ class MetricsServiceTest { assertThat(result).isEqualTo(SUCCEEDED_INCREMENTAL_REACH_METRIC) } + + @Test + fun `getMetric returns the metric with FAILED when the metric is already failed`() = runBlocking { + whenever(internalMetricsMock.getMetric(any())) + .thenReturn(INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC) + + val request = getMetricRequest { name = FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC.name } + + val result = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.getMetric(request) } + } + + // Verify proto argument of internal MetricsCoroutineImplBase::getMetric + val getInternalMetricCaptor: KArgumentCaptor = argumentCaptor() + verifyBlocking(internalMetricsMock, times(1)) { getMetric(getInternalMetricCaptor.capture()) } + val capturedInternalGetMetricRequests = getInternalMetricCaptor.allValues + assertThat(capturedInternalGetMetricRequests) + .containsExactly( + internalGetMetricRequest { + cmmsMeasurementConsumerId = + INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC.cmmsMeasurementConsumerId + externalMetricId = INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId + } + ) + + // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults + val batchSetMeasurementResultsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementResults(batchSetMeasurementResultsCaptor.capture()) + } + + // Verify proto argument of internal + // MeasurementsCoroutineImplBase::batchSetMeasurementFailures + val batchSetMeasurementFailuresCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementFailures(batchSetMeasurementFailuresCaptor.capture()) + } + + assertThat(result).isEqualTo(FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC) + } + + @Test + fun `getMetric returns the metric with RUNNING when measurements are pending`() = runBlocking { + whenever(internalMetricsMock.getMetric(any())) + .thenReturn( + INTERNAL_PENDING_INCREMENTAL_REACH_METRIC, + INTERNAL_PENDING_INCREMENTAL_REACH_METRIC + ) + + val request = getMetricRequest { name = PENDING_INCREMENTAL_REACH_METRIC.name } + + val result = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.getMetric(request) } + } + + // Verify proto argument of internal MetricsCoroutineImplBase::getMetric + val getInternalMetricCaptor: KArgumentCaptor = argumentCaptor() + verifyBlocking(internalMetricsMock, times(2)) { getMetric(getInternalMetricCaptor.capture()) } + val capturedInternalGetMetricRequests = getInternalMetricCaptor.allValues + assertThat(capturedInternalGetMetricRequests) + .containsExactly( + internalGetMetricRequest { + cmmsMeasurementConsumerId = + INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.cmmsMeasurementConsumerId + externalMetricId = INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.externalMetricId + }, + internalGetMetricRequest { + cmmsMeasurementConsumerId = + INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.cmmsMeasurementConsumerId + externalMetricId = INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.externalMetricId + } + ) + + // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults + val batchSetMeasurementResultsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementResults(batchSetMeasurementResultsCaptor.capture()) + } + + // Verify proto argument of internal + // MeasurementsCoroutineImplBase::batchSetMeasurementFailures + val batchSetMeasurementFailuresCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementFailures(batchSetMeasurementFailuresCaptor.capture()) + } + + assertThat(result).isEqualTo(PENDING_INCREMENTAL_REACH_METRIC) + } + + @Test + fun `getMetric returns the metric with SUCCEEDED when measurements are SUCCEEDED`() = + runBlocking { + whenever(internalMetricsMock.getMetric(any())) + .thenReturn( + INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC, + INTERNAL_SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC + ) + whenever(measurementsMock.getMeasurement(any())) + .thenReturn( + SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT, + ) + whenever(internalMeasurementsMock.batchSetMeasurementResults(any())) + .thenReturn(flowOf(INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT)) + + val request = getMetricRequest { name = PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.name } + + val result = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.getMetric(request) } + } + + // Verify proto argument of internal MetricsCoroutineImplBase::getMetric + val getInternalMetricCaptor: KArgumentCaptor = argumentCaptor() + verifyBlocking(internalMetricsMock, times(2)) { getMetric(getInternalMetricCaptor.capture()) } + val capturedInternalGetMetricRequests = getInternalMetricCaptor.allValues + assertThat(capturedInternalGetMetricRequests) + .containsExactly( + internalGetMetricRequest { + cmmsMeasurementConsumerId = + INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.cmmsMeasurementConsumerId + externalMetricId = + INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.externalMetricId + }, + internalGetMetricRequest { + cmmsMeasurementConsumerId = + INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.cmmsMeasurementConsumerId + externalMetricId = + INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.externalMetricId + } + ) + + // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults + val batchSetMeasurementResultsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, times(1)) { + batchSetMeasurementResults(batchSetMeasurementResultsCaptor.capture()) + } + assertThat(batchSetMeasurementResultsCaptor.allValues) + .containsExactly( + batchSetMeasurementResultsRequest { + cmmsMeasurementConsumerId = + INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT.cmmsMeasurementConsumerId + measurementResults += measurementResult { + externalMeasurementId = + INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT.externalMeasurementId + this.result = INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT.result + } + } + ) + + // Verify proto argument of internal + // MeasurementsCoroutineImplBase::batchSetMeasurementFailures + val batchSetMeasurementFailuresCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementFailures(batchSetMeasurementFailuresCaptor.capture()) + } + + assertThat(result).isEqualTo(SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC) + } } private fun EventGroupKey.toInternal(): InternalReportingSet.Primitive.EventGroupKey { From fd189f55e2057e83169123313a53737439fcff6b Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Mon, 17 Apr 2023 18:12:01 +0000 Subject: [PATCH 04/12] Fix MetricsService after rebasing. --- .../service/api/v2alpha/MetricsService.kt | 41 ++++++++----------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index a3b4d377974..6ea6eb0e802 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -875,17 +875,24 @@ class MetricsService( } val internalMetric: InternalMetric = - getInternalMetric(metricKey.cmmsMeasurementConsumerId, apiIdToExternalId(metricKey.metricId)) + batchGetInternalMetrics( + metricKey.cmmsMeasurementConsumerId, + listOf(apiIdToExternalId(metricKey.metricId)) + ) + .first() // Early exit when the metric is at a terminal state. if (internalMetric.state != InternalMetric.State.RUNNING) { return internalMetric.toMetric() } + // Only syncs pending measurements which can only be in metrics that are still running. val toBeSyncedInternalMeasurements: List = - internalMetric.weightedMeasurementsList.map { weightedMeasurement -> - weightedMeasurement.measurement - } + internalMetric.weightedMeasurementsList + .map { weightedMeasurement -> weightedMeasurement.measurement } + .filter { internalMeasurement -> + internalMeasurement.state == InternalMeasurement.State.PENDING + } measurementSupplier.syncInternalMeasurements( toBeSyncedInternalMeasurements, @@ -893,10 +900,11 @@ class MetricsService( principal, ) - return getInternalMetric( + return batchGetInternalMetrics( metricKey.cmmsMeasurementConsumerId, - apiIdToExternalId(metricKey.metricId) + listOf(apiIdToExternalId(metricKey.metricId)) ) + .first() .toMetric() } @@ -934,11 +942,15 @@ class MetricsService( val internalMetrics: List = batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, externalMetricIds) + // Only syncs pending measurements which can only be in metrics that are still running. val toBeSyncedInternalMeasurements: List = internalMetrics .filter { internalMetric -> internalMetric.state == InternalMetric.State.RUNNING } .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } .map { weightedMeasurement -> weightedMeasurement.measurement } + .filter { internalMeasurement -> + internalMeasurement.state == InternalMeasurement.State.PENDING + } measurementSupplier.syncInternalMeasurements( toBeSyncedInternalMeasurements, @@ -952,23 +964,6 @@ class MetricsService( .map { it.toMetric() } } } - - private suspend fun getInternalMetric( - cmmsMeasurementConsumerId: String, - externalMetricId: Long, - ): org.wfanet.measurement.internal.reporting.v2alpha.Metric { - return try { - internalMetricsStub.getMetric( - internalGetMetricRequest { - this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId - this.externalMetricId = externalMetricId - } - ) - } catch (e: StatusException) { - throw Exception("Unable to get the metric from the reporting database.", e) - } - } - override suspend fun listMetrics(request: ListMetricsRequest): ListMetricsResponse { val listMetricsPageToken: ListMetricsPageToken = request.toListMetricsPageToken() From dd14a89784633d038f0ab783970d283151f5a7a1 Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Mon, 17 Apr 2023 22:02:16 +0000 Subject: [PATCH 05/12] Fix MetricsServiceTest after rebasing. --- .../service/api/v2alpha/MetricsServiceTest.kt | 246 +++++++++++++----- 1 file changed, 187 insertions(+), 59 deletions(-) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt index e3c54f8743b..f71d3c65586 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt @@ -213,9 +213,6 @@ private const val MAXIMUM_WATCH_DURATION_PER_USER = 4000 private const val DIFFERENTIAL_PRIVACY_DELTA = 1e-12 -private const val MAXIMUM_FREQUENCY_PER_USER = 10 -private const val MAXIMUM_WATCH_DURATION_PER_USER = 300 - private const val SECURE_RANDOM_OUTPUT_INT = 0 private const val SECURE_RANDOM_OUTPUT_LONG = 0L @@ -688,7 +685,7 @@ private val INTERNAL_REQUESTING_UNION_ALL_WATCH_DURATION_MEASUREMENT = internalM private val INTERNAL_PENDING_NOT_CREATED_UNION_ALL_WATCH_DURATION_MEASUREMENT = INTERNAL_REQUESTING_UNION_ALL_WATCH_DURATION_MEASUREMENT.copy { - externalMeasurementId = 414L + cmmsMeasurementId = externalIdToApiId(414L) cmmsCreateMeasurementRequestId = "UNION_ALL_WATCH_DURATION_MEASUREMENT" state = InternalMeasurement.State.PENDING } @@ -701,10 +698,13 @@ private val INTERNAL_PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT = private val INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT = INTERNAL_PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT.copy { state = InternalMeasurement.State.SUCCEEDED - result = - InternalMeasurementKt.result { - watchDuration = - InternalMeasurementKt.ResultKt.watchDuration { value = TOTAL_WATCH_DURATION } + details = + InternalMeasurementKt.details { + result = + InternalMeasurementKt.result { + watchDuration = + InternalMeasurementKt.ResultKt.watchDuration { value = TOTAL_WATCH_DURATION } + } } } @@ -864,6 +864,64 @@ private val PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT = state = Measurement.State.COMPUTING } +// CMMs cross publisher watch duration measurements +private val UNION_ALL_WATCH_DURATION_MEASUREMENT_SPEC = measurementSpec { + measurementPublicKey = MEASUREMENT_CONSUMER_PUBLIC_KEY.toByteString() + + nonceHashes.addAll( + listOf( + hashSha256(SECURE_RANDOM_OUTPUT_LONG), + hashSha256(SECURE_RANDOM_OUTPUT_LONG), + hashSha256(SECURE_RANDOM_OUTPUT_LONG) + ) + ) + + duration = + MeasurementSpecKt.duration { + privacyParams = differentialPrivacyParams { + epsilon = WATCH_DURATION_EPSILON + delta = DIFFERENTIAL_PRIVACY_DELTA + } + privacyParams = differentialPrivacyParams { + epsilon = WATCH_DURATION_EPSILON + delta = DIFFERENTIAL_PRIVACY_DELTA + } + maximumWatchDurationPerUser = MAXIMUM_WATCH_DURATION_PER_USER + } + vidSamplingInterval = + MeasurementSpecKt.vidSamplingInterval { + start = WATCH_DURATION_VID_SAMPLING_START + width = WATCH_DURATION_VID_SAMPLING_WIDTH + } +} + +private val REQUESTING_UNION_ALL_WATCH_DURATION_MEASUREMENT = + BASE_MEASUREMENT.copy { + dataProviders += DATA_PROVIDERS.keys.map { DATA_PROVIDER_ENTRIES.getValue(it) } + + measurementSpec = + signMeasurementSpec( + UNION_ALL_WATCH_DURATION_MEASUREMENT_SPEC.copy { + nonceHashes += hashSha256(SECURE_RANDOM_OUTPUT_LONG) + }, + MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE + ) + + measurementReferenceId = + INTERNAL_PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT.cmmsCreateMeasurementRequestId + } + +private val PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT = + REQUESTING_UNION_ALL_WATCH_DURATION_MEASUREMENT.copy { + name = + MeasurementKey( + MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId, + INTERNAL_PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT.cmmsMeasurementId + ) + .toName() + state = Measurement.State.COMPUTING + } + // Metric Specs private val REACH_METRIC_SPEC: MetricSpec = metricSpec { @@ -1048,14 +1106,23 @@ private val INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC = // Internal Cross Publisher Watch Duration Metrics private val INTERNAL_REQUESTING_CROSS_PUBLISHER_WATCH_DURATION_METRIC = internalMetric { cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - metricIdempotencyKey = CROSS_PUBLISHER_WATCH_DURATION_METRIC_IDEMPOTENCY_KEY externalReportingSetId = INTERNAL_UNION_ALL_REPORTING_SET.externalReportingSetId timeInterval = INTERNAL_TIME_INTERVAL metricSpec = internalMetricSpec { watchDuration = InternalMetricSpecKt.watchDurationParams { + privacyParams = + InternalMetricSpecKt.differentialPrivacyParams { + epsilon = WATCH_DURATION_EPSILON + delta = DIFFERENTIAL_PRIVACY_DELTA + } maximumWatchDurationPerUser = MAXIMUM_WATCH_DURATION_PER_USER } + vidSamplingInterval = + InternalMetricSpecKt.vidSamplingInterval { + start = WATCH_DURATION_VID_SAMPLING_START + width = WATCH_DURATION_VID_SAMPLING_WIDTH + } } weightedMeasurements += weightedMeasurement { weight = 1 @@ -1093,7 +1160,9 @@ private val INTERNAL_SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC = filters += this@copy.details.filtersList result = internalMetricResult { watchDuration = - InternalMetricResultKt.doubleResult { value = TOTAL_WATCH_DURATION.seconds.toDouble() } + InternalMetricResultKt.watchDurationResult { + value = TOTAL_WATCH_DURATION.seconds.toDouble() + } } } } @@ -1194,6 +1263,21 @@ private val PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC = externalIdToApiId(INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.externalMetricId) ) .toName() + metricSpec = metricSpec { + watchDuration = watchDurationParams { + privacyParams = + MetricSpecKt.differentialPrivacyParams { + epsilon = WATCH_DURATION_EPSILON + delta = DIFFERENTIAL_PRIVACY_DELTA + } + maximumWatchDurationPerUser = MAXIMUM_WATCH_DURATION_PER_USER + } + vidSamplingInterval = + MetricSpecKt.vidSamplingInterval { + start = WATCH_DURATION_VID_SAMPLING_START + width = WATCH_DURATION_VID_SAMPLING_WIDTH + } + } state = Metric.State.RUNNING createTime = INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.createTime } @@ -1202,11 +1286,8 @@ private val SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC = PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.copy { state = Metric.State.SUCCEEDED result = metricResult { - watchDuration = doubleResult { - value = - INTERNAL_SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC.details.result.watchDuration - .value - } + watchDuration = + MetricResultKt.watchDurationResult { value = TOTAL_WATCH_DURATION.seconds.toDouble() } } } @@ -3990,8 +4071,10 @@ class MetricsServiceTest { @Test fun `getMetric returns the metric with SUCCEEDED when the metric is already succeeded`() = runBlocking { - whenever(internalMetricsMock.getMetric(any())) - .thenReturn(INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC) + whenever(internalMetricsMock.batchGetMetrics(any())) + .thenReturn( + internalBatchGetMetricsResponse { metrics += INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC } + ) val request = getMetricRequest { name = SUCCEEDED_INCREMENTAL_REACH_METRIC.name } @@ -4000,16 +4083,19 @@ class MetricsServiceTest { runBlocking { service.getMetric(request) } } - // Verify proto argument of internal MetricsCoroutineImplBase::getMetric - val getInternalMetricCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(internalMetricsMock, times(1)) { getMetric(getInternalMetricCaptor.capture()) } - val capturedInternalGetMetricRequests = getInternalMetricCaptor.allValues + // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics + val batchGetInternalMetricsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMetricsMock, times(1)) { + batchGetMetrics(batchGetInternalMetricsCaptor.capture()) + } + val capturedInternalGetMetricRequests = batchGetInternalMetricsCaptor.allValues assertThat(capturedInternalGetMetricRequests) .containsExactly( - internalGetMetricRequest { + internalBatchGetMetricsRequest { cmmsMeasurementConsumerId = INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC.cmmsMeasurementConsumerId - externalMetricId = INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC.externalMetricId + externalMetricIds += INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC.externalMetricId } ) @@ -4033,8 +4119,12 @@ class MetricsServiceTest { @Test fun `getMetric returns the metric with FAILED when the metric is already failed`() = runBlocking { - whenever(internalMetricsMock.getMetric(any())) - .thenReturn(INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC) + whenever(internalMetricsMock.batchGetMetrics(any())) + .thenReturn( + internalBatchGetMetricsResponse { + metrics += INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC + } + ) val request = getMetricRequest { name = FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC.name } @@ -4043,16 +4133,19 @@ class MetricsServiceTest { runBlocking { service.getMetric(request) } } - // Verify proto argument of internal MetricsCoroutineImplBase::getMetric - val getInternalMetricCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(internalMetricsMock, times(1)) { getMetric(getInternalMetricCaptor.capture()) } - val capturedInternalGetMetricRequests = getInternalMetricCaptor.allValues + // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics + val batchGetInternalMetricsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMetricsMock, times(1)) { + batchGetMetrics(batchGetInternalMetricsCaptor.capture()) + } + val capturedInternalGetMetricRequests = batchGetInternalMetricsCaptor.allValues assertThat(capturedInternalGetMetricRequests) .containsExactly( - internalGetMetricRequest { + internalBatchGetMetricsRequest { cmmsMeasurementConsumerId = INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC.cmmsMeasurementConsumerId - externalMetricId = INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId + externalMetricIds += INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId } ) @@ -4076,10 +4169,10 @@ class MetricsServiceTest { @Test fun `getMetric returns the metric with RUNNING when measurements are pending`() = runBlocking { - whenever(internalMetricsMock.getMetric(any())) + whenever(internalMetricsMock.batchGetMetrics(any())) .thenReturn( - INTERNAL_PENDING_INCREMENTAL_REACH_METRIC, - INTERNAL_PENDING_INCREMENTAL_REACH_METRIC + internalBatchGetMetricsResponse { metrics += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC }, + internalBatchGetMetricsResponse { metrics += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC } ) val request = getMetricRequest { name = PENDING_INCREMENTAL_REACH_METRIC.name } @@ -4089,21 +4182,24 @@ class MetricsServiceTest { runBlocking { service.getMetric(request) } } - // Verify proto argument of internal MetricsCoroutineImplBase::getMetric - val getInternalMetricCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(internalMetricsMock, times(2)) { getMetric(getInternalMetricCaptor.capture()) } - val capturedInternalGetMetricRequests = getInternalMetricCaptor.allValues + // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics + val batchGetInternalMetricsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMetricsMock, times(2)) { + batchGetMetrics(batchGetInternalMetricsCaptor.capture()) + } + val capturedInternalGetMetricRequests = batchGetInternalMetricsCaptor.allValues assertThat(capturedInternalGetMetricRequests) .containsExactly( - internalGetMetricRequest { + internalBatchGetMetricsRequest { cmmsMeasurementConsumerId = INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.cmmsMeasurementConsumerId - externalMetricId = INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.externalMetricId + externalMetricIds += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.externalMetricId }, - internalGetMetricRequest { + internalBatchGetMetricsRequest { cmmsMeasurementConsumerId = INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.cmmsMeasurementConsumerId - externalMetricId = INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.externalMetricId + externalMetricIds += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.externalMetricId } ) @@ -4128,17 +4224,46 @@ class MetricsServiceTest { @Test fun `getMetric returns the metric with SUCCEEDED when measurements are SUCCEEDED`() = runBlocking { - whenever(internalMetricsMock.getMetric(any())) + whenever(internalMetricsMock.batchGetMetrics(any())) .thenReturn( - INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC, - INTERNAL_SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC + internalBatchGetMetricsResponse { + metrics += INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC + }, + internalBatchGetMetricsResponse { + metrics += INTERNAL_SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC + }, ) + + val succeededUnionAllWatchDurationMeasurement = + PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT.copy { + state = Measurement.State.SUCCEEDED + + results += + DATA_PROVIDERS.keys.zip(WATCH_DURATION_LIST).map { (dataProviderKey, watchDuration) -> + val dataProvider = DATA_PROVIDERS.getValue(dataProviderKey) + resultPair { + val result = + MeasurementKt.result { + this.watchDuration = + MeasurementKt.ResultKt.watchDuration { value = watchDuration } + } + encryptedResult = + encryptResult( + signResult(result, DATA_PROVIDER_SIGNING_KEY), + MEASUREMENT_CONSUMER_PUBLIC_KEY + ) + certificate = dataProvider.certificate + } + } + } whenever(measurementsMock.getMeasurement(any())) + .thenReturn(succeededUnionAllWatchDurationMeasurement) + whenever(internalMeasurementsMock.batchSetMeasurementResults(any())) .thenReturn( - SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT, + batchSetCmmsMeasurementResultsResponse { + measurements += INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT + } ) - whenever(internalMeasurementsMock.batchSetMeasurementResults(any())) - .thenReturn(flowOf(INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT)) val request = getMetricRequest { name = PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.name } @@ -4147,22 +4272,25 @@ class MetricsServiceTest { runBlocking { service.getMetric(request) } } - // Verify proto argument of internal MetricsCoroutineImplBase::getMetric - val getInternalMetricCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(internalMetricsMock, times(2)) { getMetric(getInternalMetricCaptor.capture()) } - val capturedInternalGetMetricRequests = getInternalMetricCaptor.allValues + // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics + val batchGetInternalMetricsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMetricsMock, times(2)) { + batchGetMetrics(batchGetInternalMetricsCaptor.capture()) + } + val capturedInternalGetMetricRequests = batchGetInternalMetricsCaptor.allValues assertThat(capturedInternalGetMetricRequests) .containsExactly( - internalGetMetricRequest { + internalBatchGetMetricsRequest { cmmsMeasurementConsumerId = INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.cmmsMeasurementConsumerId - externalMetricId = + externalMetricIds += INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.externalMetricId }, - internalGetMetricRequest { + internalBatchGetMetricsRequest { cmmsMeasurementConsumerId = INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.cmmsMeasurementConsumerId - externalMetricId = + externalMetricIds += INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.externalMetricId } ) @@ -4179,9 +4307,9 @@ class MetricsServiceTest { cmmsMeasurementConsumerId = INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT.cmmsMeasurementConsumerId measurementResults += measurementResult { - externalMeasurementId = - INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT.externalMeasurementId - this.result = INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT.result + cmmsMeasurementId = + INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT.cmmsMeasurementId + this.result = INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT.details.result } } ) From a5779b8f4fe588ac44ebafc7fb7853e19035ff11 Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Fri, 21 Apr 2023 21:18:55 +0000 Subject: [PATCH 06/12] Remove codes that contain internal metric states. Fix unit tests. --- .../service/api/v2alpha/MetricsService.kt | 66 +++++++++++-------- .../service/api/v2alpha/MetricsServiceTest.kt | 60 +++++++++++------ 2 files changed, 80 insertions(+), 46 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 6ea6eb0e802..496636fd629 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -875,14 +875,10 @@ class MetricsService( } val internalMetric: InternalMetric = - batchGetInternalMetrics( - metricKey.cmmsMeasurementConsumerId, - listOf(apiIdToExternalId(metricKey.metricId)) - ) - .first() + getInternalMetric(metricKey.cmmsMeasurementConsumerId, apiIdToExternalId(metricKey.metricId)) // Early exit when the metric is at a terminal state. - if (internalMetric.state != InternalMetric.State.RUNNING) { + if (determineMetricState(internalMetric) != Metric.State.RUNNING) { return internalMetric.toMetric() } @@ -894,18 +890,19 @@ class MetricsService( internalMeasurement.state == InternalMeasurement.State.PENDING } - measurementSupplier.syncInternalMeasurements( - toBeSyncedInternalMeasurements, - principal.config.apiKey, - principal, - ) - - return batchGetInternalMetrics( - metricKey.cmmsMeasurementConsumerId, - listOf(apiIdToExternalId(metricKey.metricId)) + val anyMeasurementUpdated: Boolean = + measurementSupplier.syncInternalMeasurements( + toBeSyncedInternalMeasurements, + principal.config.apiKey, + principal, ) - .first() - .toMetric() + + return if (anyMeasurementUpdated) { + getInternalMetric(metricKey.cmmsMeasurementConsumerId, apiIdToExternalId(metricKey.metricId)) + .toMetric() + } else { + internalMetric.toMetric() + } } override suspend fun batchGetMetrics(request: BatchGetMetricsRequest): BatchGetMetricsResponse { @@ -945,23 +942,28 @@ class MetricsService( // Only syncs pending measurements which can only be in metrics that are still running. val toBeSyncedInternalMeasurements: List = internalMetrics - .filter { internalMetric -> internalMetric.state == InternalMetric.State.RUNNING } + .filter { internalMetric -> determineMetricState(internalMetric) == Metric.State.RUNNING } .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } .map { weightedMeasurement -> weightedMeasurement.measurement } .filter { internalMeasurement -> internalMeasurement.state == InternalMeasurement.State.PENDING } - measurementSupplier.syncInternalMeasurements( - toBeSyncedInternalMeasurements, - principal.config.apiKey, - principal, - ) + val anyMeasurementUpdated: Boolean = + measurementSupplier.syncInternalMeasurements( + toBeSyncedInternalMeasurements, + principal.config.apiKey, + principal, + ) return batchGetMetricsResponse { metrics += - batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, externalMetricIds) - .map { it.toMetric() } + if (anyMeasurementUpdated) { + batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, externalMetricIds) + .map { it.toMetric() } + } else { + internalMetrics.map { it.toMetric() } + } } } override suspend fun listMetrics(request: ListMetricsRequest): ListMetricsResponse { @@ -1046,7 +1048,7 @@ class MetricsService( } } - /** Gets a batch of [InternalMetric]. */ + /** Gets a batch of [InternalMetric]s. */ private suspend fun batchGetInternalMetrics( cmmsMeasurementConsumerId: String, externalMetricIds: List, @@ -1063,6 +1065,18 @@ class MetricsService( } } + /** Gets an [InternalMetric]. */ + private suspend fun getInternalMetric( + cmmsMeasurementConsumerId: String, + externalMetricId: Long, + ): InternalMetric { + return try { + batchGetInternalMetrics(cmmsMeasurementConsumerId, listOf(externalMetricId)).first() + } catch (e: StatusException) { + throw Exception("Unable to get metrics from the reporting database.", e) + } + } + override suspend fun createMetric(request: CreateMetricRequest): Metric { grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { "Parent is either unspecified or invalid." diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt index f71d3c65586..bd46ea11d1c 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt @@ -1101,7 +1101,13 @@ private val INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC = } private val INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC = - INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.copy { state = InternalMetric.State.FAILED } + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.copy { + weightedMeasurements.clear() + weightedMeasurements += weightedMeasurement { + weight = 1 + measurement = INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT + } + } // Internal Cross Publisher Watch Duration Metrics private val INTERNAL_REQUESTING_CROSS_PUBLISHER_WATCH_DURATION_METRIC = internalMetric { @@ -1140,7 +1146,6 @@ private val INTERNAL_PENDING_INITIAL_CROSS_PUBLISHER_WATCH_DURATION_METRIC = weight = 1 measurement = INTERNAL_PENDING_NOT_CREATED_UNION_ALL_WATCH_DURATION_MEASUREMENT } - state = InternalMetric.State.RUNNING } private val INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC = @@ -1154,17 +1159,12 @@ private val INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC = private val INTERNAL_SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC = INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.copy { - state = InternalMetric.State.SUCCEEDED - details = - InternalMetricKt.details { - filters += this@copy.details.filtersList - result = internalMetricResult { - watchDuration = - InternalMetricResultKt.watchDurationResult { - value = TOTAL_WATCH_DURATION.seconds.toDouble() - } - } - } + weightedMeasurements.clear() + weightedMeasurements += weightedMeasurement { + weight = 1 + measurement = INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT + } + details = InternalMetricKt.details { filters += this@copy.details.filtersList } } // Public Metrics @@ -4172,7 +4172,11 @@ class MetricsServiceTest { whenever(internalMetricsMock.batchGetMetrics(any())) .thenReturn( internalBatchGetMetricsResponse { metrics += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC }, - internalBatchGetMetricsResponse { metrics += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC } + ) + whenever(measurementsMock.getMeasurement(any())) + .thenReturn( + PENDING_UNION_ALL_REACH_MEASUREMENT, + PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT, ) val request = getMetricRequest { name = PENDING_INCREMENTAL_REACH_METRIC.name } @@ -4185,17 +4189,12 @@ class MetricsServiceTest { // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics val batchGetInternalMetricsCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(internalMetricsMock, times(2)) { + verifyBlocking(internalMetricsMock, times(1)) { batchGetMetrics(batchGetInternalMetricsCaptor.capture()) } val capturedInternalGetMetricRequests = batchGetInternalMetricsCaptor.allValues assertThat(capturedInternalGetMetricRequests) .containsExactly( - internalBatchGetMetricsRequest { - cmmsMeasurementConsumerId = - INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.cmmsMeasurementConsumerId - externalMetricIds += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.externalMetricId - }, internalBatchGetMetricsRequest { cmmsMeasurementConsumerId = INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.cmmsMeasurementConsumerId @@ -4203,6 +4202,18 @@ class MetricsServiceTest { } ) + // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement + val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() + verifyBlocking(measurementsMock, times(2)) { getMeasurement(getMeasurementCaptor.capture()) } + val capturedGetMeasurementRequests = getMeasurementCaptor.allValues + assertThat(capturedGetMeasurementRequests) + .containsExactly( + getMeasurementRequest { name = PENDING_UNION_ALL_REACH_MEASUREMENT.name }, + getMeasurementRequest { + name = PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.name + }, + ) + // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -4295,6 +4306,15 @@ class MetricsServiceTest { } ) + // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement + val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() + verifyBlocking(measurementsMock, times(1)) { getMeasurement(getMeasurementCaptor.capture()) } + val capturedGetMeasurementRequests = getMeasurementCaptor.allValues + assertThat(capturedGetMeasurementRequests) + .containsExactly( + getMeasurementRequest { name = PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT.name }, + ) + // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() From db01179b009b70d24a35d9e873bca5650a9495f5 Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Fri, 21 Apr 2023 23:19:58 +0000 Subject: [PATCH 07/12] Add unit tests for getMetric and batchGetMetrics. --- .../service/api/v2alpha/MetricsService.kt | 11 +- .../service/api/v2alpha/MetricsServiceTest.kt | 438 ++++++++++++++++-- 2 files changed, 411 insertions(+), 38 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 496636fd629..4f9bf3bb3b8 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -878,7 +878,7 @@ class MetricsService( getInternalMetric(metricKey.cmmsMeasurementConsumerId, apiIdToExternalId(metricKey.metricId)) // Early exit when the metric is at a terminal state. - if (determineMetricState(internalMetric) != Metric.State.RUNNING) { + if (internalMetric.state != Metric.State.RUNNING) { return internalMetric.toMetric() } @@ -942,7 +942,7 @@ class MetricsService( // Only syncs pending measurements which can only be in metrics that are still running. val toBeSyncedInternalMeasurements: List = internalMetrics - .filter { internalMetric -> determineMetricState(internalMetric) == Metric.State.RUNNING } + .filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } .map { weightedMeasurement -> weightedMeasurement.measurement } .filter { internalMeasurement -> @@ -1073,7 +1073,12 @@ class MetricsService( return try { batchGetInternalMetrics(cmmsMeasurementConsumerId, listOf(externalMetricId)).first() } catch (e: StatusException) { - throw Exception("Unable to get metrics from the reporting database.", e) + val metricName = + MetricKey(cmmsMeasurementConsumerId, externalIdToApiId(externalMetricId)).toName() + throw Exception( + "Unable to get the metric with name = [${metricName}] from the reporting database.", + e + ) } } diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt index bd46ea11d1c..89b06c42caf 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt @@ -176,6 +176,8 @@ import org.wfanet.measurement.reporting.v2alpha.MetricSpecKt.reachParams import org.wfanet.measurement.reporting.v2alpha.MetricSpecKt.watchDurationParams import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsRequest import org.wfanet.measurement.reporting.v2alpha.batchCreateMetricsResponse +import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsRequest +import org.wfanet.measurement.reporting.v2alpha.batchGetMetricsResponse import org.wfanet.measurement.reporting.v2alpha.copy import org.wfanet.measurement.reporting.v2alpha.createMetricRequest import org.wfanet.measurement.reporting.v2alpha.getMetricRequest @@ -313,8 +315,6 @@ private val AGGREGATOR_CERTIFICATE = certificate { private val AGGREGATOR_ROOT_CERTIFICATE: X509Certificate = readCertificate(SECRETS_DIR.resolve("aggregator_root.pem")) -private val INVALID_MEASUREMENT_PUBLIC_KEY_DATA = "Invalid public key".toByteStringUtf8() - // Measurement consumer crypto private val TRUSTED_MEASUREMENT_CONSUMER_ISSUER: X509Certificate = @@ -598,10 +598,6 @@ private const val UNION_ALL_REACH_VALUE = 100_000L private const val UNION_ALL_BUT_LAST_PUBLISHER_REACH_VALUE = 70_000L private const val INCREMENTAL_REACH_VALUE = UNION_ALL_REACH_VALUE - UNION_ALL_BUT_LAST_PUBLISHER_REACH_VALUE -private val FREQUENCY_DISTRIBUTION = mapOf(1L to 1.0 / 6, 2L to 2.0 / 6, 3L to 3.0 / 6) -private const val FIRST_PUBLISHER_IMPRESSION_VALUE = 100L -private val IMPRESSION_VALUES = listOf(100L, 150L) -private val TOTAL_IMPRESSION_VALUE = IMPRESSION_VALUES.sum() private val WATCH_DURATION_SECOND_LIST = listOf(100L, 200L, 300L) private val WATCH_DURATION_LIST = WATCH_DURATION_SECOND_LIST.map { duration { seconds = it } } private val TOTAL_WATCH_DURATION = duration { seconds = WATCH_DURATION_SECOND_LIST.sum() } @@ -790,32 +786,29 @@ private val SUCCEEDED_UNION_ALL_REACH_MEASUREMENT = PENDING_UNION_ALL_REACH_MEASUREMENT.copy { state = Measurement.State.SUCCEEDED - results += - MeasurementKt.resultPair { - val result = - MeasurementKt.result { - reach = MeasurementKt.ResultKt.reach { value = UNION_ALL_REACH_VALUE } - } - encryptedResult = - encryptResult(signResult(result, AGGREGATOR_SIGNING_KEY), MEASUREMENT_CONSUMER_PUBLIC_KEY) - certificate = AGGREGATOR_CERTIFICATE.name - } + results += resultPair { + val result = + MeasurementKt.result { + reach = MeasurementKt.ResultKt.reach { value = UNION_ALL_REACH_VALUE } + } + encryptedResult = + encryptResult(signResult(result, AGGREGATOR_SIGNING_KEY), MEASUREMENT_CONSUMER_PUBLIC_KEY) + certificate = AGGREGATOR_CERTIFICATE.name + } } private val SUCCEEDED_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT = PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.copy { state = Measurement.State.SUCCEEDED - results += - MeasurementKt.resultPair { - val result = - MeasurementKt.result { - reach = - MeasurementKt.ResultKt.reach { value = UNION_ALL_BUT_LAST_PUBLISHER_REACH_VALUE } - } - encryptedResult = - encryptResult(signResult(result, AGGREGATOR_SIGNING_KEY), MEASUREMENT_CONSUMER_PUBLIC_KEY) - certificate = AGGREGATOR_CERTIFICATE.name - } + results += resultPair { + val result = + MeasurementKt.result { + reach = MeasurementKt.ResultKt.reach { value = UNION_ALL_BUT_LAST_PUBLISHER_REACH_VALUE } + } + encryptedResult = + encryptResult(signResult(result, AGGREGATOR_SIGNING_KEY), MEASUREMENT_CONSUMER_PUBLIC_KEY) + certificate = AGGREGATOR_CERTIFICATE.name + } } // CMMs single publisher impression measurements @@ -942,10 +935,6 @@ private val WATCH_DURATION_METRIC_SPEC: MetricSpec = metricSpec { // Metric idempotency keys private const val INCREMENTAL_REACH_METRIC_IDEMPOTENCY_KEY = "TEST_INCREMENTAL_REACH_METRIC" -private const val SINGLE_PUBLISHER_IMPRESSION_METRIC_IDEMPOTENCY_KEY = - "TEST_SINGLE_PUBLISHER_IMPRESSION_METRIC" -private const val IMPRESSION_METRIC_IDEMPOTENCY_KEY = "TEST_IMPRESSION_METRIC" -private const val WATCH_DURATION_METRIC_IDEMPOTENCY_KEY = "TEST_WATCH_DURATION_METRIC" // Internal Incremental Metrics private val INTERNAL_REQUESTING_INCREMENTAL_REACH_METRIC = internalMetric { @@ -1021,7 +1010,6 @@ private val INTERNAL_PENDING_INCREMENTAL_REACH_METRIC = private val INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC = INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.copy { - details = InternalMetricKt.details { filters += this@copy.details.filtersList } weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -1164,7 +1152,6 @@ private val INTERNAL_SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC = weight = 1 measurement = INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT } - details = InternalMetricKt.details { filters += this@copy.details.filtersList } } // Public Metrics @@ -3034,7 +3021,7 @@ class MetricsServiceTest { } @Test - fun `batchCreateMetric throws exception when number of requests exceeds limit`() = runBlocking { + fun `batchCreateMetrics throws exception when number of requests exceeds limit`() = runBlocking { val request = batchCreateMetricsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name @@ -4233,7 +4220,7 @@ class MetricsServiceTest { } @Test - fun `getMetric returns the metric with SUCCEEDED when measurements are SUCCEEDED`() = + fun `getMetric returns the metric with SUCCEEDED when measurements are updated to SUCCEEDED`() = runBlocking { whenever(internalMetricsMock.batchGetMetrics(any())) .thenReturn( @@ -4344,6 +4331,387 @@ class MetricsServiceTest { assertThat(result).isEqualTo(SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC) } + + @Test + fun `getMetric returns the metric with FAILED when measurements are updated to FAILED`() = + runBlocking { + whenever(internalMetricsMock.batchGetMetrics(any())) + .thenReturn( + internalBatchGetMetricsResponse { + metrics += INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC + }, + internalBatchGetMetricsResponse { + metrics += INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC + }, + ) + + val failedSinglePublisherImpressionMeasurement = + PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.copy { + state = Measurement.State.FAILED + failure = failure { + reason = Measurement.Failure.Reason.REQUISITION_REFUSED + message = + INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.details.failure.message + } + } + + whenever(measurementsMock.getMeasurement(any())) + .thenReturn(failedSinglePublisherImpressionMeasurement) + whenever(internalMeasurementsMock.batchSetMeasurementFailures(any())) + .thenReturn( + batchSetCmmsMeasurementFailuresResponse { + measurements += INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT + } + ) + + val request = getMetricRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.name } + + val result = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.getMetric(request) } + } + + // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics + val batchGetInternalMetricsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMetricsMock, times(2)) { + batchGetMetrics(batchGetInternalMetricsCaptor.capture()) + } + val capturedInternalGetMetricRequests = batchGetInternalMetricsCaptor.allValues + assertThat(capturedInternalGetMetricRequests) + .containsExactly( + internalBatchGetMetricsRequest { + cmmsMeasurementConsumerId = + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.cmmsMeasurementConsumerId + externalMetricIds += + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId + }, + internalBatchGetMetricsRequest { + cmmsMeasurementConsumerId = + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.cmmsMeasurementConsumerId + externalMetricIds += + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId + } + ) + + // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement + val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() + verifyBlocking(measurementsMock, times(1)) { getMeasurement(getMeasurementCaptor.capture()) } + val capturedGetMeasurementRequests = getMeasurementCaptor.allValues + assertThat(capturedGetMeasurementRequests) + .containsExactly( + getMeasurementRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name }, + ) + + // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults + val batchSetMeasurementResultsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementResults(batchSetMeasurementResultsCaptor.capture()) + } + + // Verify proto argument of internal + // MeasurementsCoroutineImplBase::batchSetMeasurementFailures + val batchSetMeasurementFailuresCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, times(1)) { + batchSetMeasurementFailures(batchSetMeasurementFailuresCaptor.capture()) + } + assertThat(batchSetMeasurementFailuresCaptor.allValues) + .containsExactly( + batchSetMeasurementFailuresRequest { + cmmsMeasurementConsumerId = + INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.cmmsMeasurementConsumerId + measurementFailures += measurementFailure { + cmmsMeasurementId = + INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.cmmsMeasurementId + this.failure = INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.details.failure + } + } + ) + + assertThat(result).isEqualTo(FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC) + } + + @Test + fun `getMetric returns the metric with SUCCEEDED when measurements are already SUCCEEDED`() = + runBlocking { + whenever(internalMetricsMock.batchGetMetrics(any())) + .thenReturn( + internalBatchGetMetricsResponse { + metrics += INTERNAL_SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC + }, + ) + + val request = getMetricRequest { name = PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.name } + + val result = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.getMetric(request) } + } + + // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics + val batchGetInternalMetricsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMetricsMock, times(1)) { + batchGetMetrics(batchGetInternalMetricsCaptor.capture()) + } + val capturedInternalGetMetricRequests = batchGetInternalMetricsCaptor.allValues + assertThat(capturedInternalGetMetricRequests) + .containsExactly( + internalBatchGetMetricsRequest { + cmmsMeasurementConsumerId = + INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.cmmsMeasurementConsumerId + externalMetricIds += + INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.externalMetricId + }, + ) + + // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement + val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() + verifyBlocking(measurementsMock, never()) { getMeasurement(getMeasurementCaptor.capture()) } + + // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults + val batchSetMeasurementResultsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementResults(batchSetMeasurementResultsCaptor.capture()) + } + + // Verify proto argument of internal + // MeasurementsCoroutineImplBase::batchSetMeasurementFailures + val batchSetMeasurementFailuresCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementFailures(batchSetMeasurementFailuresCaptor.capture()) + } + + assertThat(result).isEqualTo(SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC) + } + + @Test + fun `getMetric throws INVALID_ARGUMENT when Report name is invalid`() { + val request = getMetricRequest { name = "invalid_metric_name" } + + val exception = + assertFailsWith { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.getMetric(request) } + } + } + assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) + } + + @Test + fun `getMetric throws PERMISSION_DENIED when MeasurementConsumer's identity does not match`() { + val request = getMetricRequest { name = PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.name } + + val exception = + assertFailsWith { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.last().name, CONFIG) { + runBlocking { service.getMetric(request) } + } + } + + assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) + } + + @Test + fun `getMetric throws UNAUTHENTICATED when the caller is not a MeasurementConsumer`() { + val request = getMetricRequest { name = PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.name } + + val exception = + assertFailsWith { + withDataProviderPrincipal(DATA_PROVIDERS.values.first().name) { + runBlocking { service.getMetric(request) } + } + } + + assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) + } + + @Test + fun `getMetric throws FAILED_PRECONDITION when the measurement public key is not valid`() = + runBlocking { + whenever(measurementsMock.getMeasurement(any())) + .thenReturn( + SUCCEEDED_UNION_ALL_REACH_MEASUREMENT, + SUCCEEDED_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.copy { + measurementSpec = + signMeasurementSpec( + UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT_SPEC.copy { + measurementPublicKey = + MEASUREMENT_CONSUMER_PUBLIC_KEY.copy { clearData() }.toByteString() + }, + MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE + ) + }, + ) + + val request = getMetricRequest { name = PENDING_INCREMENTAL_REACH_METRIC.name } + + val exception = + assertFailsWith { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.getMetric(request) } + } + } + + assertThat(exception.status.code).isEqualTo(Status.Code.FAILED_PRECONDITION) + assertThat(exception) + .hasMessageThat() + .contains(SUCCEEDED_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.name) + } + + @Test + fun `batchGetMetrics returns metrics with SUCCEEDED when the metric is already succeeded`() = + runBlocking { + whenever(internalMetricsMock.batchGetMetrics(any())) + .thenReturn( + internalBatchGetMetricsResponse { + metrics += INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC + metrics += INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC + } + ) + + val request = batchGetMetricsRequest { + parent = MEASUREMENT_CONSUMERS.values.first().name + names += SUCCEEDED_INCREMENTAL_REACH_METRIC.name + names += PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.name + } + + val result = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.batchGetMetrics(request) } + } + + // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics + val batchGetInternalMetricsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMetricsMock, times(1)) { + batchGetMetrics(batchGetInternalMetricsCaptor.capture()) + } + val capturedInternalGetMetricRequests = batchGetInternalMetricsCaptor.allValues + assertThat(capturedInternalGetMetricRequests) + .containsExactly( + internalBatchGetMetricsRequest { + cmmsMeasurementConsumerId = + INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC.cmmsMeasurementConsumerId + externalMetricIds += INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC.externalMetricId + externalMetricIds += + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId + } + ) + + // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults + val batchSetMeasurementResultsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementResults(batchSetMeasurementResultsCaptor.capture()) + } + + // Verify proto argument of internal + // MeasurementsCoroutineImplBase::batchSetMeasurementFailures + val batchSetMeasurementFailuresCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementFailures(batchSetMeasurementFailuresCaptor.capture()) + } + + assertThat(result) + .isEqualTo( + batchGetMetricsResponse { + metrics += SUCCEEDED_INCREMENTAL_REACH_METRIC + metrics += PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC + } + ) + } + + @Test + fun `batchGetMetrics returns metrics with RUNNING when measurements are pending`() = runBlocking { + val request = batchGetMetricsRequest { + parent = MEASUREMENT_CONSUMERS.values.first().name + names += PENDING_INCREMENTAL_REACH_METRIC.name + names += PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.name + } + + val result = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.batchGetMetrics(request) } + } + + // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics + val batchGetInternalMetricsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMetricsMock, times(1)) { + batchGetMetrics(batchGetInternalMetricsCaptor.capture()) + } + val capturedInternalGetMetricRequests = batchGetInternalMetricsCaptor.allValues + assertThat(capturedInternalGetMetricRequests) + .containsExactly( + internalBatchGetMetricsRequest { + cmmsMeasurementConsumerId = + INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.cmmsMeasurementConsumerId + externalMetricIds += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.externalMetricId + externalMetricIds += INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId + } + ) + + // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement + val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() + verifyBlocking(measurementsMock, times(3)) { getMeasurement(getMeasurementCaptor.capture()) } + val capturedGetMeasurementRequests = getMeasurementCaptor.allValues + assertThat(capturedGetMeasurementRequests) + .containsExactly( + getMeasurementRequest { name = PENDING_UNION_ALL_REACH_MEASUREMENT.name }, + getMeasurementRequest { + name = PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.name + }, + getMeasurementRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name }, + ) + + // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults + val batchSetMeasurementResultsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementResults(batchSetMeasurementResultsCaptor.capture()) + } + + // Verify proto argument of internal + // MeasurementsCoroutineImplBase::batchSetMeasurementFailures + val batchSetMeasurementFailuresCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementFailures(batchSetMeasurementFailuresCaptor.capture()) + } + + assertThat(result) + .isEqualTo( + batchGetMetricsResponse { + metrics += PENDING_INCREMENTAL_REACH_METRIC + metrics += PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC + } + ) + } + + @Test + fun `batchGetMetrics throws exception when number of requests exceeds limit`() = runBlocking { + val request = batchGetMetricsRequest { + parent = MEASUREMENT_CONSUMERS.values.first().name + names += List(MAX_BATCH_SIZE + 1) { "metric_name" } + } + + val exception = + assertFailsWith { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.batchGetMetrics(request) } + } + } + + assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) + assertThat(exception.status.description) + .isEqualTo("At most $MAX_BATCH_SIZE metrics can be supported in a batch.") + } } private fun EventGroupKey.toInternal(): InternalReportingSet.Primitive.EventGroupKey { From 8e73f0122b616215132ae2cc7cd1fbaf565c167f Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Mon, 24 Apr 2023 19:20:04 +0000 Subject: [PATCH 08/12] Add more info to different error status. --- .../service/api/v2alpha/MetricsService.kt | 93 +++++++++++++++---- .../service/api/v2alpha/MetricsServiceTest.kt | 17 ++-- 2 files changed, 84 insertions(+), 26 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 4f9bf3bb3b8..3aa7ee4d953 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -298,7 +298,21 @@ class MetricsService( .withAuthenticationKey(principal.config.apiKey) .createMeasurement(createMeasurementRequest) } catch (e: StatusException) { - throw Exception("Unable to create a CMMS measurement.", e) + throw when (e.status.code) { + Status.Code.INVALID_ARGUMENT -> + Status.INVALID_ARGUMENT.withDescription("Required field unspecified or invalid.") + Status.Code.PERMISSION_DENIED -> + Status.PERMISSION_DENIED.withDescription( + "Cannot create a CMMS Measurement for another MeasurementConsumer." + ) + Status.Code.FAILED_PRECONDITION -> + Status.FAILED_PRECONDITION.withDescription("Failed precondition.") + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("${measurementConsumer.name} is not found.") + else -> Status.UNKNOWN.withDescription("Unable to create a CMMS measurement.") + } + .withCause(e) + .asRuntimeException() } } @@ -409,8 +423,8 @@ class MetricsService( } catch (e: StatusException) { throw when (e.status.code) { Status.Code.NOT_FOUND -> - Status.FAILED_PRECONDITION.withDescription("$dataProviderName not found") - else -> Status.UNKNOWN.withDescription("Unable to retrieve $dataProviderName") + Status.FAILED_PRECONDITION.withDescription("$dataProviderName not found.") + else -> Status.UNKNOWN.withDescription("Unable to retrieve $dataProviderName.") } .withCause(e) .asRuntimeException() @@ -422,7 +436,16 @@ class MetricsService( .withAuthenticationKey(apiAuthenticationKey) .getCertificate(getCertificateRequest { name = dataProvider.certificate }) } catch (e: StatusException) { - throw Exception("Unable to retrieve Certificate ${dataProvider.certificate}", e) + throw when (e.status.code) { + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("${dataProvider.certificate} not found.") + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve Certificate ${dataProvider.certificate}." + ) + } + .withCause(e) + .asRuntimeException() } if ( certificate.revocationState != Certificate.RevocationState.REVOCATION_STATE_UNSPECIFIED @@ -539,10 +562,16 @@ class MetricsService( getMeasurementConsumerRequest { name = principal.resourceKey.toName() } ) } catch (e: StatusException) { - throw Exception( - "Unable to retrieve the measurement consumer " + "[${principal.resourceKey.toName()}].", - e - ) + throw when (e.status.code) { + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("${principal.resourceKey.toName()} not found.") + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve the measurement consumer [${principal.resourceKey.toName()}]." + ) + } + .withCause(e) + .asRuntimeException() } } @@ -584,11 +613,19 @@ class MetricsService( .getCertificate(getCertificateRequest { name = principal.config.signingCertificateName }) .x509Der } catch (e: StatusException) { - throw Exception( - "Unable to retrieve the signing certificate for the measurement consumer " + - "[$principal.config.signingCertificateName].", - e - ) + throw when (e.status.code) { + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription( + "${principal.config.signingCertificateName} not found." + ) + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve the signing certificate " + + "[${principal.config.signingCertificateName}] for the measurement consumer." + ) + } + .withCause(e) + .asRuntimeException() } } @@ -724,7 +761,20 @@ class MetricsService( .withAuthenticationKey(apiAuthenticationKey) .getMeasurement(getMeasurementRequest { name = measurementResourceName }) } catch (e: StatusException) { - throw Exception("Unable to retrieve the measurement [$measurementResourceName].", e) + throw when (e.status.code) { + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("$measurementResourceName not found.") + Status.Code.PERMISSION_DENIED -> + Status.PERMISSION_DENIED.withDescription( + "Doesn't have permission to get $measurementResourceName." + ) + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve the measurement [$measurementResourceName]." + ) + } + .withCause(e) + .asRuntimeException() } } } @@ -770,10 +820,17 @@ class MetricsService( .withAuthenticationKey(apiAuthenticationKey) .getCertificate(getCertificateRequest { name = measurementResultPair.certificate }) } catch (e: StatusException) { - throw Exception( - "Unable to retrieve the certificate [${measurementResultPair.certificate}].", - e - ) + throw when (e.status.code) { + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("${measurementResultPair.certificate} not found.") + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve the certificate " + + "[${measurementResultPair.certificate}] for the measurement consumer." + ) + } + .withCause(e) + .asRuntimeException() } val signedResult = diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt index 89b06c42caf..ff5c4b01747 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt @@ -102,6 +102,7 @@ import org.wfanet.measurement.common.crypto.subjectKeyIdentifier import org.wfanet.measurement.common.crypto.testing.loadSigningKey import org.wfanet.measurement.common.crypto.tink.loadPrivateKey import org.wfanet.measurement.common.getRuntimePath +import org.wfanet.measurement.common.grpc.grpcStatusCode import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule import org.wfanet.measurement.common.grpc.testing.mockService import org.wfanet.measurement.common.identity.ExternalId @@ -2737,7 +2738,7 @@ class MetricsServiceTest { } @Test - fun `createMetric throws exception when the CMMs createMeasurement throws exception`() = + fun `createMetric throws exception when the CMMs createMeasurement throws INVALID_ARGUMENT`() = runBlocking { whenever(measurementsMock.createMeasurement(any())) .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) @@ -2753,8 +2754,9 @@ class MetricsServiceTest { runBlocking { service.createMetric(request) } } } - val expectedExceptionDescription = "Unable to create a CMMS measurement." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) + assertThat(exception.grpcStatusCode()).isEqualTo(Status.Code.INVALID_ARGUMENT) + val expectedExceptionDescription = "Required field unspecified or invalid." + assertThat(exception.message).contains(expectedExceptionDescription) } @Test @@ -2780,9 +2782,9 @@ class MetricsServiceTest { } @Test - fun `createMetric throws exception when getMeasurementConsumer throws exception`() = runBlocking { + fun `createMetric throws exception when getMeasurementConsumer throws NOT_FOUND`() = runBlocking { whenever(measurementConsumersMock.getMeasurementConsumer(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) + .thenThrow(StatusRuntimeException(Status.NOT_FOUND)) val request = createMetricRequest { parent = MEASUREMENT_CONSUMERS.values.first().name @@ -2795,9 +2797,8 @@ class MetricsServiceTest { runBlocking { service.createMetric(request) } } } - val expectedExceptionDescription = - "Unable to retrieve the measurement consumer [${MEASUREMENT_CONSUMERS.values.first().name}]." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) + assertThat(exception.grpcStatusCode()).isEqualTo(Status.Code.NOT_FOUND) + assertThat(exception.message).contains(MEASUREMENT_CONSUMERS.values.first().name) } @Test From c079de0fdfc7e39cdc7498648bc4aaa8817c5b71 Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Tue, 25 Apr 2023 21:06:31 +0000 Subject: [PATCH 09/12] Fix unit tests. --- .../service/api/v2alpha/MetricsServiceTest.kt | 683 ++++-------------- 1 file changed, 151 insertions(+), 532 deletions(-) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt index ff5c4b01747..87392c56e5a 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt @@ -150,7 +150,6 @@ import org.wfanet.measurement.internal.reporting.v2.batchCreateMetricsRequest as import org.wfanet.measurement.internal.reporting.v2.batchCreateMetricsResponse as internalBatchCreateMetricsResponse import org.wfanet.measurement.internal.reporting.v2.batchGetMetricsRequest as internalBatchGetMetricsRequest import org.wfanet.measurement.internal.reporting.v2.batchGetMetricsResponse as internalBatchGetMetricsResponse -import org.wfanet.measurement.internal.reporting.v2.batchGetReportingSetsRequest import org.wfanet.measurement.internal.reporting.v2.batchGetReportingSetsResponse import org.wfanet.measurement.internal.reporting.v2.batchSetCmmsMeasurementFailuresResponse import org.wfanet.measurement.internal.reporting.v2.batchSetCmmsMeasurementIdsRequest @@ -1314,13 +1313,26 @@ class MetricsServiceTest { InternalReportingSetsGrpcKt.ReportingSetsCoroutineImplBase = mockService { onBlocking { batchGetReportingSets(any()) } - .thenReturn( - batchGetReportingSetsResponse { reportingSets += INTERNAL_INCREMENTAL_REPORTING_SET }, + .thenAnswer { + val request = it.arguments[0] as BatchGetReportingSetsRequest + val internalReportingSetsMap = + mapOf( + INTERNAL_INCREMENTAL_REPORTING_SET.externalReportingSetId to + INTERNAL_INCREMENTAL_REPORTING_SET, + INTERNAL_UNION_ALL_REPORTING_SET.externalReportingSetId to + INTERNAL_UNION_ALL_REPORTING_SET, + INTERNAL_UNION_ALL_BUT_LAST_PUBLISHER_REPORTING_SET.externalReportingSetId to + INTERNAL_UNION_ALL_BUT_LAST_PUBLISHER_REPORTING_SET, + INTERNAL_SINGLE_PUBLISHER_REPORTING_SET.externalReportingSetId to + INTERNAL_SINGLE_PUBLISHER_REPORTING_SET + ) batchGetReportingSetsResponse { - reportingSets += INTERNAL_UNION_ALL_REPORTING_SET - reportingSets += INTERNAL_UNION_ALL_BUT_LAST_PUBLISHER_REPORTING_SET + reportingSets += + request.externalReportingSetIdsList.map { externalReportingSetId -> + internalReportingSetsMap.getValue(externalReportingSetId) + } } - ) + } } private val internalMeasurementsMock: InternalMeasurementsCoroutineImplBase = mockService { @@ -1346,12 +1358,16 @@ class MetricsServiceTest { } private val measurementsMock: MeasurementsCoroutineImplBase = mockService { - onBlocking { getMeasurement(any()) } - .thenReturn( + for (pendingMeasurement in + listOf( PENDING_UNION_ALL_REACH_MEASUREMENT, PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT, - PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT - ) + PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT, + PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT + )) { + onBlocking { getMeasurement(eq(getMeasurementRequest { name = pendingMeasurement.name })) } + .thenReturn(pendingMeasurement) + } onBlocking { createMeasurement(any()) } .thenAnswer { @@ -1371,7 +1387,12 @@ class MetricsServiceTest { private val measurementConsumersMock: MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase = mockService { - onBlocking { getMeasurementConsumer(any()) }.thenReturn(MEASUREMENT_CONSUMERS.values.first()) + onBlocking { + getMeasurementConsumer( + eq(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) + ) + } + .thenReturn(MEASUREMENT_CONSUMERS.values.first()) } private val dataProvidersMock: DataProvidersGrpcKt.DataProvidersCoroutineImplBase = mockService { @@ -1461,28 +1482,6 @@ class MetricsServiceTest { val expected = PENDING_INCREMENTAL_REACH_METRIC - // Verify proto argument of the internal ReportingSetsCoroutineImplBase::batchGetReportingSets - val batchGetReportingSetsCaptor: KArgumentCaptor = - argumentCaptor() - verifyBlocking(internalReportingSetsMock, times(2)) { - batchGetReportingSets(batchGetReportingSetsCaptor.capture()) - } - - val capturedBatchGetReportingSetsRequests = batchGetReportingSetsCaptor.allValues - assertThat(capturedBatchGetReportingSetsRequests) - .containsExactly( - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_INCREMENTAL_REPORTING_SET.externalReportingSetId - }, - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_UNION_ALL_REPORTING_SET.externalReportingSetId - externalReportingSetIds += - INTERNAL_UNION_ALL_BUT_LAST_PUBLISHER_REPORTING_SET.externalReportingSetId - } - ) - // Verify proto argument of the internal MetricsCoroutineImplBase::createMetric verifyProtoArgument(internalMetricsMock, MetricsCoroutineImplBase::createMetric) .ignoringRepeatedFieldOrder() @@ -1490,13 +1489,6 @@ class MetricsServiceTest { internalCreateMetricRequest { metric = INTERNAL_REQUESTING_INCREMENTAL_REACH_METRIC } ) - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement val measurementsCaptor: KArgumentCaptor = argumentCaptor() verifyBlocking(measurementsMock, times(2)) { createMeasurement(measurementsCaptor.capture()) } @@ -1585,10 +1577,6 @@ class MetricsServiceTest { @Test fun `createMetric creates CMMS measurements for single pub impression metric`() = runBlocking { - whenever(internalReportingSetsMock.batchGetReportingSets(any())) - .thenReturn( - batchGetReportingSetsResponse { reportingSets += INTERNAL_SINGLE_PUBLISHER_REPORTING_SET } - ) whenever(internalMetricsMock.createMetric(any())) .thenReturn(INTERNAL_PENDING_INITIAL_SINGLE_PUBLISHER_IMPRESSION_METRIC) whenever(measurementsMock.createMeasurement(any())) @@ -1606,27 +1594,6 @@ class MetricsServiceTest { val expected = PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC - // Verify proto argument of the internal ReportingSetsCoroutineImplBase::batchGetReportingSets - val batchGetReportingSetsCaptor: KArgumentCaptor = - argumentCaptor() - verifyBlocking(internalReportingSetsMock, times(2)) { - batchGetReportingSets(batchGetReportingSetsCaptor.capture()) - } - - val capturedBatchGetReportingSetsRequests = batchGetReportingSetsCaptor.allValues - assertThat(capturedBatchGetReportingSetsRequests) - .containsExactly( - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_SINGLE_PUBLISHER_REPORTING_SET.externalReportingSetId - }, - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += - INTERNAL_REQUESTING_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalReportingSetId - } - ) - // Verify proto argument of the internal MetricsCoroutineImplBase::createMetric verifyProtoArgument(internalMetricsMock, MetricsCoroutineImplBase::createMetric) .ignoringRepeatedFieldOrder() @@ -1636,13 +1603,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement val measurementsCaptor: KArgumentCaptor = argumentCaptor() verifyBlocking(measurementsMock, times(1)) { createMeasurement(measurementsCaptor.capture()) } @@ -1777,10 +1737,6 @@ class MetricsServiceTest { signMeasurementSpec(cmmsMeasurementSpec, MEASUREMENT_CONSUMER_SIGNING_KEY_HANDLE) } - whenever(internalReportingSetsMock.batchGetReportingSets(any())) - .thenReturn( - batchGetReportingSetsResponse { reportingSets += INTERNAL_SINGLE_PUBLISHER_REPORTING_SET } - ) whenever(internalMetricsMock.createMetric(any())) .thenReturn(internalPendingInitialSinglePublisherImpressionMetric) whenever(measurementsMock.createMeasurement(any())) @@ -1832,27 +1788,6 @@ class MetricsServiceTest { } } - // Verify proto argument of the internal ReportingSetsCoroutineImplBase::batchGetReportingSets - val batchGetReportingSetsCaptor: KArgumentCaptor = - argumentCaptor() - verifyBlocking(internalReportingSetsMock, times(2)) { - batchGetReportingSets(batchGetReportingSetsCaptor.capture()) - } - - val capturedBatchGetReportingSetsRequests = batchGetReportingSetsCaptor.allValues - assertThat(capturedBatchGetReportingSetsRequests) - .containsExactly( - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_SINGLE_PUBLISHER_REPORTING_SET.externalReportingSetId - }, - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += - internalRequestingSinglePublisherImpressionMetric.externalReportingSetId - } - ) - // Verify proto argument of the internal MetricsCoroutineImplBase::createMetric verifyProtoArgument(internalMetricsMock, MetricsCoroutineImplBase::createMetric) .ignoringRepeatedFieldOrder() @@ -1860,13 +1795,6 @@ class MetricsServiceTest { internalCreateMetricRequest { metric = internalRequestingSinglePublisherImpressionMetric } ) - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement val measurementsCaptor: KArgumentCaptor = argumentCaptor() verifyBlocking(measurementsMock, times(1)) { createMeasurement(measurementsCaptor.capture()) } @@ -1952,28 +1880,6 @@ class MetricsServiceTest { val expected = PENDING_INCREMENTAL_REACH_METRIC - // Verify proto argument of the internal ReportingSetsCoroutineImplBase::batchGetReportingSets - val batchGetReportingSetsCaptor: KArgumentCaptor = - argumentCaptor() - verifyBlocking(internalReportingSetsMock, times(2)) { - batchGetReportingSets(batchGetReportingSetsCaptor.capture()) - } - - val capturedBatchGetReportingSetsRequests = batchGetReportingSetsCaptor.allValues - assertThat(capturedBatchGetReportingSetsRequests) - .containsExactly( - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_INCREMENTAL_REPORTING_SET.externalReportingSetId - }, - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_UNION_ALL_REPORTING_SET.externalReportingSetId - externalReportingSetIds += - INTERNAL_UNION_ALL_BUT_LAST_PUBLISHER_REPORTING_SET.externalReportingSetId - } - ) - // Verify proto argument of the internal MetricsCoroutineImplBase::createMetric verifyProtoArgument(internalMetricsMock, MetricsCoroutineImplBase::createMetric) .ignoringRepeatedFieldOrder() @@ -1984,13 +1890,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement val measurementsCaptor: KArgumentCaptor = argumentCaptor() verifyBlocking(measurementsMock, times(2)) { createMeasurement(measurementsCaptor.capture()) } @@ -2094,28 +1993,6 @@ class MetricsServiceTest { val expected = PENDING_INCREMENTAL_REACH_METRIC - // Verify proto argument of the internal ReportingSetsCoroutineImplBase::batchGetReportingSets - val batchGetReportingSetsCaptor: KArgumentCaptor = - argumentCaptor() - verifyBlocking(internalReportingSetsMock, times(2)) { - batchGetReportingSets(batchGetReportingSetsCaptor.capture()) - } - - val capturedBatchGetReportingSetsRequests = batchGetReportingSetsCaptor.allValues - assertThat(capturedBatchGetReportingSetsRequests) - .containsExactly( - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_INCREMENTAL_REPORTING_SET.externalReportingSetId - }, - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_UNION_ALL_REPORTING_SET.externalReportingSetId - externalReportingSetIds += - INTERNAL_UNION_ALL_BUT_LAST_PUBLISHER_REPORTING_SET.externalReportingSetId - } - ) - // Verify proto argument of the internal MetricsCoroutineImplBase::createMetric verifyProtoArgument(internalMetricsMock, MetricsCoroutineImplBase::createMetric) .ignoringRepeatedFieldOrder() @@ -2123,13 +2000,6 @@ class MetricsServiceTest { internalCreateMetricRequest { metric = INTERNAL_REQUESTING_INCREMENTAL_REACH_METRIC } ) - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - // Verify proto argument of DataProvidersCoroutineImplBase::getDataProvider val dataProvidersCaptor: KArgumentCaptor = argumentCaptor() verifyBlocking(dataProvidersMock, never()) { getDataProvider(dataProvidersCaptor.capture()) } @@ -2166,28 +2036,6 @@ class MetricsServiceTest { val expected = PENDING_INCREMENTAL_REACH_METRIC - // Verify proto argument of the internal ReportingSetsCoroutineImplBase::batchGetReportingSets - val batchGetReportingSetsCaptor: KArgumentCaptor = - argumentCaptor() - verifyBlocking(internalReportingSetsMock, times(2)) { - batchGetReportingSets(batchGetReportingSetsCaptor.capture()) - } - - val capturedBatchGetReportingSetsRequests = batchGetReportingSetsCaptor.allValues - assertThat(capturedBatchGetReportingSetsRequests) - .containsExactly( - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_INCREMENTAL_REPORTING_SET.externalReportingSetId - }, - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_UNION_ALL_REPORTING_SET.externalReportingSetId - externalReportingSetIds += - INTERNAL_UNION_ALL_BUT_LAST_PUBLISHER_REPORTING_SET.externalReportingSetId - } - ) - // Verify proto argument of the internal MetricsCoroutineImplBase::createMetric verifyProtoArgument(internalMetricsMock, MetricsCoroutineImplBase::createMetric) .ignoringRepeatedFieldOrder() @@ -2198,13 +2046,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - // Verify proto argument of DataProvidersCoroutineImplBase::getDataProvider val dataProvidersCaptor: KArgumentCaptor = argumentCaptor() verifyBlocking(dataProvidersMock, never()) { getDataProvider(dataProvidersCaptor.capture()) } @@ -2248,15 +2089,6 @@ class MetricsServiceTest { batchGetReportingSets(batchGetReportingSetsCaptor.capture()) } - val capturedBatchGetReportingSetsRequests = batchGetReportingSetsCaptor.allValues - assertThat(capturedBatchGetReportingSetsRequests) - .containsExactly( - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_INCREMENTAL_REPORTING_SET.externalReportingSetId - } - ) - // Verify proto argument of the internal MetricsCoroutineImplBase::createMetric verifyProtoArgument(internalMetricsMock, MetricsCoroutineImplBase::createMetric) .ignoringRepeatedFieldOrder() @@ -2718,24 +2550,22 @@ class MetricsServiceTest { } @Test - fun `createMetric throws exception when internal createMetric throws exception`() = runBlocking { - whenever(internalMetricsMock.createMetric(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) + fun `createMetric throws exception when internal createMetric throws exception`(): Unit = + runBlocking { + whenever(internalMetricsMock.createMetric(any())) + .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - val request = createMetricRequest { - parent = MEASUREMENT_CONSUMERS.values.first().name - metric = REQUESTING_INCREMENTAL_REACH_METRIC - } + val request = createMetricRequest { + parent = MEASUREMENT_CONSUMERS.values.first().name + metric = REQUESTING_INCREMENTAL_REACH_METRIC + } - val exception = assertFailsWith(Exception::class) { withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { runBlocking { service.createMetric(request) } } } - val expectedExceptionDescription = "Unable to create the metric in the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) - } + } @Test fun `createMetric throws exception when the CMMs createMeasurement throws INVALID_ARGUMENT`() = @@ -2755,12 +2585,10 @@ class MetricsServiceTest { } } assertThat(exception.grpcStatusCode()).isEqualTo(Status.Code.INVALID_ARGUMENT) - val expectedExceptionDescription = "Required field unspecified or invalid." - assertThat(exception.message).contains(expectedExceptionDescription) } @Test - fun `createMetric throws exception when batchSetCmmsMeasurementId throws exception`() = + fun `createMetric throws exception when batchSetCmmsMeasurementId throws exception`(): Unit = runBlocking { whenever(internalMeasurementsMock.batchSetCmmsMeasurementIds(any())) .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) @@ -2770,15 +2598,11 @@ class MetricsServiceTest { metric = REQUESTING_INCREMENTAL_REACH_METRIC } - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.createMetric(request) } - } + assertFailsWith(Exception::class) { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.createMetric(request) } } - val expectedExceptionDescription = - "Unable to set the CMMS measurement IDs for the measurements in the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) + } } @Test @@ -2840,17 +2664,6 @@ class MetricsServiceTest { @Test fun `batchCreateMetrics creates CMMS measurements`() = runBlocking { - whenever(internalReportingSetsMock.batchGetReportingSets(any())) - .thenReturn( - batchGetReportingSetsResponse { reportingSets += INTERNAL_INCREMENTAL_REPORTING_SET }, - batchGetReportingSetsResponse { reportingSets += INTERNAL_SINGLE_PUBLISHER_REPORTING_SET }, - batchGetReportingSetsResponse { - reportingSets += INTERNAL_UNION_ALL_REPORTING_SET - reportingSets += INTERNAL_UNION_ALL_BUT_LAST_PUBLISHER_REPORTING_SET - reportingSets += INTERNAL_SINGLE_PUBLISHER_REPORTING_SET - } - ) - val request = batchCreateMetricsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name requests += createMetricRequest { @@ -2873,34 +2686,6 @@ class MetricsServiceTest { metrics += PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC } - // Verify proto argument of the internal ReportingSetsCoroutineImplBase::batchGetReportingSets - val batchGetReportingSetsCaptor: KArgumentCaptor = - argumentCaptor() - verifyBlocking(internalReportingSetsMock, times(3)) { - batchGetReportingSets(batchGetReportingSetsCaptor.capture()) - } - - val capturedBatchGetReportingSetsRequests = batchGetReportingSetsCaptor.allValues - assertThat(capturedBatchGetReportingSetsRequests) - .containsExactly( - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_INCREMENTAL_REPORTING_SET.externalReportingSetId - }, - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_SINGLE_PUBLISHER_REPORTING_SET.externalReportingSetId - }, - batchGetReportingSetsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalReportingSetIds += INTERNAL_UNION_ALL_REPORTING_SET.externalReportingSetId - externalReportingSetIds += - INTERNAL_UNION_ALL_BUT_LAST_PUBLISHER_REPORTING_SET.externalReportingSetId - externalReportingSetIds += - INTERNAL_REQUESTING_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalReportingSetId - } - ) - // Verify proto argument of the internal MetricsCoroutineImplBase::batchCreateMetrics verifyProtoArgument(internalMetricsMock, MetricsCoroutineImplBase::batchCreateMetrics) .ignoringRepeatedFieldOrder() @@ -2916,13 +2701,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - verifyProtoArgument( - measurementConsumersMock, - MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase::getMeasurementConsumer - ) - .isEqualTo(getMeasurementConsumerRequest { name = MEASUREMENT_CONSUMERS.values.first().name }) - // Verify proto argument of MeasurementsCoroutineImplBase::createMeasurement val measurementsCaptor: KArgumentCaptor = argumentCaptor() verifyBlocking(measurementsMock, times(3)) { createMeasurement(measurementsCaptor.capture()) } @@ -3072,19 +2850,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement - val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(3)) { getMeasurement(getMeasurementCaptor.capture()) } - val capturedGetMeasurementRequests = getMeasurementCaptor.allValues - assertThat(capturedGetMeasurementRequests) - .containsExactly( - getMeasurementRequest { name = PENDING_UNION_ALL_REACH_MEASUREMENT.name }, - getMeasurementRequest { - name = PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.name - }, - getMeasurementRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name }, - ) - // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -3154,18 +2919,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement - val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(2)) { getMeasurement(getMeasurementCaptor.capture()) } - val capturedGetMeasurementRequests = getMeasurementCaptor.allValues - assertThat(capturedGetMeasurementRequests) - .containsExactly( - getMeasurementRequest { name = PENDING_UNION_ALL_REACH_MEASUREMENT.name }, - getMeasurementRequest { - name = PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.name - }, - ) - // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -3238,15 +2991,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement - val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(1)) { getMeasurement(getMeasurementCaptor.capture()) } - val capturedGetMeasurementRequests = getMeasurementCaptor.allValues - assertThat(capturedGetMeasurementRequests) - .containsExactly( - getMeasurementRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name }, - ) - // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -3301,19 +3045,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement - val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(3)) { getMeasurement(getMeasurementCaptor.capture()) } - val capturedGetMeasurementRequests = getMeasurementCaptor.allValues - assertThat(capturedGetMeasurementRequests) - .containsExactly( - getMeasurementRequest { name = PENDING_UNION_ALL_REACH_MEASUREMENT.name }, - getMeasurementRequest { - name = PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.name - }, - getMeasurementRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name }, - ) - // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -3386,15 +3117,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement - val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(1)) { getMeasurement(getMeasurementCaptor.capture()) } - val capturedGetMeasurementRequests = getMeasurementCaptor.allValues - assertThat(capturedGetMeasurementRequests) - .containsExactly( - getMeasurementRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name }, - ) - // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -3469,15 +3191,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement - val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(1)) { getMeasurement(getMeasurementCaptor.capture()) } - val capturedGetMeasurementRequests = getMeasurementCaptor.allValues - assertThat(capturedGetMeasurementRequests) - .containsExactly( - getMeasurementRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name }, - ) - // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -3519,13 +3232,32 @@ class MetricsServiceTest { } } - whenever(measurementsMock.getMeasurement(any())) - .thenReturn( - SUCCEEDED_UNION_ALL_REACH_MEASUREMENT, - SUCCEEDED_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT, - PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT + val measurementsMap = + mapOf( + SUCCEEDED_UNION_ALL_REACH_MEASUREMENT.name to SUCCEEDED_UNION_ALL_REACH_MEASUREMENT, + SUCCEEDED_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.name to + SUCCEEDED_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT, + PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name to + PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT, + ) + + whenever(measurementsMock.getMeasurement(any())).thenAnswer { + val request = it.arguments[0] as GetMeasurementRequest + measurementsMap[request.name] + } + + whenever( + internalMetricsMock.batchGetMetrics( + eq( + internalBatchGetMetricsRequest { + cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId + externalMetricIds += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.externalMetricId + externalMetricIds += + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId + } + ) + ) ) - whenever(internalMetricsMock.batchGetMetrics(any())) .thenReturn( internalBatchGetMetricsResponse { metrics += @@ -3567,19 +3299,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement - val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(3)) { getMeasurement(getMeasurementCaptor.capture()) } - val capturedGetMeasurementRequests = getMeasurementCaptor.allValues - assertThat(capturedGetMeasurementRequests) - .containsExactly( - getMeasurementRequest { name = PENDING_UNION_ALL_REACH_MEASUREMENT.name }, - getMeasurementRequest { - name = PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.name - }, - getMeasurementRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name }, - ) - // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -3611,16 +3330,6 @@ class MetricsServiceTest { batchSetMeasurementFailures(batchSetMeasurementFailuresCaptor.capture()) } - // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics - verifyProtoArgument(internalMetricsMock, MetricsCoroutineImplBase::batchGetMetrics) - .isEqualTo( - internalBatchGetMetricsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalMetricIds += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.externalMetricId - externalMetricIds += INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId - } - ) - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) } @@ -3633,8 +3342,6 @@ class MetricsServiceTest { INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC ) ) - whenever(measurementsMock.getMeasurement(any())) - .thenReturn(PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT) whenever(internalMetricsMock.batchGetMetrics(any())) .thenReturn( internalBatchGetMetricsResponse { @@ -3666,15 +3373,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement - val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(1)) { getMeasurement(getMeasurementCaptor.capture()) } - val capturedGetMeasurementRequests = getMeasurementCaptor.allValues - assertThat(capturedGetMeasurementRequests) - .containsExactly( - getMeasurementRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name }, - ) - // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -3700,20 +3398,38 @@ class MetricsServiceTest { @Test fun `listMetrics returns failed metrics when the measurement is FAILED`() = runBlocking { - whenever(measurementsMock.getMeasurement(any())) - .thenReturn( - PENDING_UNION_ALL_REACH_MEASUREMENT, - PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT, - PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.copy { - state = Measurement.State.FAILED - failure = failure { - reason = Measurement.Failure.Reason.REQUISITION_REFUSED - message = - INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.details.failure.message + val measurementsMap = + mapOf( + PENDING_UNION_ALL_REACH_MEASUREMENT.name to PENDING_UNION_ALL_REACH_MEASUREMENT, + PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.name to + PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT, + PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name to + PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.copy { + state = Measurement.State.FAILED + failure = failure { + reason = Measurement.Failure.Reason.REQUISITION_REFUSED + message = + INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.details.failure.message + } } - } ) - whenever(internalMetricsMock.batchGetMetrics(any())) + whenever(measurementsMock.getMeasurement(any())).thenAnswer { + val request = it.arguments[0] as GetMeasurementRequest + measurementsMap[request.name] + } + + whenever( + internalMetricsMock.batchGetMetrics( + eq( + internalBatchGetMetricsRequest { + cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId + externalMetricIds += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.externalMetricId + externalMetricIds += + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId + } + ) + ) + ) .thenReturn( internalBatchGetMetricsResponse { metrics += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC @@ -3751,19 +3467,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement - val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(3)) { getMeasurement(getMeasurementCaptor.capture()) } - val capturedGetMeasurementRequests = getMeasurementCaptor.allValues - assertThat(capturedGetMeasurementRequests) - .containsExactly( - getMeasurementRequest { name = PENDING_UNION_ALL_REACH_MEASUREMENT.name }, - getMeasurementRequest { - name = PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.name - }, - getMeasurementRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name }, - ) - // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -3791,16 +3494,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics - verifyProtoArgument(internalMetricsMock, MetricsCoroutineImplBase::batchGetMetrics) - .isEqualTo( - internalBatchGetMetricsRequest { - cmmsMeasurementConsumerId = MEASUREMENT_CONSUMERS.keys.first().measurementConsumerId - externalMetricIds += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.externalMetricId - externalMetricIds += INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId - } - ) - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) } @@ -3892,42 +3585,31 @@ class MetricsServiceTest { } @Test - fun `listMetrics throws Exception when the internal streamMetrics throws Exception`() { + fun `listMetrics throws Exception when the internal streamMetrics throws Exception`(): Unit = runBlocking { whenever(internalMetricsMock.streamMetrics(any())) .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) val request = listMetricsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listMetrics(request) } - } + assertFailsWith(Exception::class) { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.listMetrics(request) } } - - val expectedExceptionDescription = "Unable to list metrics from the reporting database." - assertThat(exception.message).isEqualTo(expectedExceptionDescription) + } } - } @Test - fun `listMetrics throws Exception when getMeasurement throws Exception`() { - runBlocking { - whenever(measurementsMock.getMeasurement(any())) - .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - - val request = listMetricsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } + fun `listMetrics throws Exception when getMeasurement throws Exception`(): Unit = runBlocking { + whenever(measurementsMock.getMeasurement(any())) + .thenThrow(StatusRuntimeException(Status.INVALID_ARGUMENT)) - val exception = - assertFailsWith(Exception::class) { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { - runBlocking { service.listMetrics(request) } - } - } + val request = listMetricsRequest { parent = MEASUREMENT_CONSUMERS.values.first().name } - val expectedExceptionDescription = "Unable to retrieve the measurement" - assertThat(exception.message).contains(expectedExceptionDescription) + assertFailsWith(Exception::class) { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.listMetrics(request) } + } } } @@ -3980,7 +3662,7 @@ class MetricsServiceTest { } @Test - fun `listMetrics throws Exception when internal batchGetMetrics throws Exception`() { + fun `listMetrics throws Exception when internal batchGetMetrics throws Exception`(): Unit = runBlocking { whenever(measurementsMock.getMeasurement(any())) .thenReturn( @@ -3999,7 +3681,6 @@ class MetricsServiceTest { } } } - } @Test fun `listMetrics throws FAILED_PRECONDITION when the measurement public key is not valid`() = @@ -4161,11 +3842,6 @@ class MetricsServiceTest { .thenReturn( internalBatchGetMetricsResponse { metrics += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC }, ) - whenever(measurementsMock.getMeasurement(any())) - .thenReturn( - PENDING_UNION_ALL_REACH_MEASUREMENT, - PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT, - ) val request = getMetricRequest { name = PENDING_INCREMENTAL_REACH_METRIC.name } @@ -4190,18 +3866,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement - val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(2)) { getMeasurement(getMeasurementCaptor.capture()) } - val capturedGetMeasurementRequests = getMeasurementCaptor.allValues - assertThat(capturedGetMeasurementRequests) - .containsExactly( - getMeasurementRequest { name = PENDING_UNION_ALL_REACH_MEASUREMENT.name }, - getMeasurementRequest { - name = PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.name - }, - ) - // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -4223,7 +3887,18 @@ class MetricsServiceTest { @Test fun `getMetric returns the metric with SUCCEEDED when measurements are updated to SUCCEEDED`() = runBlocking { - whenever(internalMetricsMock.batchGetMetrics(any())) + whenever( + internalMetricsMock.batchGetMetrics( + eq( + internalBatchGetMetricsRequest { + cmmsMeasurementConsumerId = + INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.cmmsMeasurementConsumerId + externalMetricIds += + INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.externalMetricId + } + ) + ) + ) .thenReturn( internalBatchGetMetricsResponse { metrics += INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC @@ -4255,7 +3930,11 @@ class MetricsServiceTest { } } } - whenever(measurementsMock.getMeasurement(any())) + whenever( + measurementsMock.getMeasurement( + eq(getMeasurementRequest { name = PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT.name }) + ) + ) .thenReturn(succeededUnionAllWatchDurationMeasurement) whenever(internalMeasurementsMock.batchSetMeasurementResults(any())) .thenReturn( @@ -4271,38 +3950,6 @@ class MetricsServiceTest { runBlocking { service.getMetric(request) } } - // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics - val batchGetInternalMetricsCaptor: KArgumentCaptor = - argumentCaptor() - verifyBlocking(internalMetricsMock, times(2)) { - batchGetMetrics(batchGetInternalMetricsCaptor.capture()) - } - val capturedInternalGetMetricRequests = batchGetInternalMetricsCaptor.allValues - assertThat(capturedInternalGetMetricRequests) - .containsExactly( - internalBatchGetMetricsRequest { - cmmsMeasurementConsumerId = - INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.cmmsMeasurementConsumerId - externalMetricIds += - INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.externalMetricId - }, - internalBatchGetMetricsRequest { - cmmsMeasurementConsumerId = - INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.cmmsMeasurementConsumerId - externalMetricIds += - INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.externalMetricId - } - ) - - // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement - val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(1)) { getMeasurement(getMeasurementCaptor.capture()) } - val capturedGetMeasurementRequests = getMeasurementCaptor.allValues - assertThat(capturedGetMeasurementRequests) - .containsExactly( - getMeasurementRequest { name = PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT.name }, - ) - // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -4336,7 +3983,18 @@ class MetricsServiceTest { @Test fun `getMetric returns the metric with FAILED when measurements are updated to FAILED`() = runBlocking { - whenever(internalMetricsMock.batchGetMetrics(any())) + whenever( + internalMetricsMock.batchGetMetrics( + eq( + internalBatchGetMetricsRequest { + cmmsMeasurementConsumerId = + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.cmmsMeasurementConsumerId + externalMetricIds += + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId + } + ) + ) + ) .thenReturn( internalBatchGetMetricsResponse { metrics += INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC @@ -4356,7 +4014,13 @@ class MetricsServiceTest { } } - whenever(measurementsMock.getMeasurement(any())) + whenever( + measurementsMock.getMeasurement( + eq( + getMeasurementRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name } + ) + ) + ) .thenReturn(failedSinglePublisherImpressionMeasurement) whenever(internalMeasurementsMock.batchSetMeasurementFailures(any())) .thenReturn( @@ -4372,38 +4036,6 @@ class MetricsServiceTest { runBlocking { service.getMetric(request) } } - // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics - val batchGetInternalMetricsCaptor: KArgumentCaptor = - argumentCaptor() - verifyBlocking(internalMetricsMock, times(2)) { - batchGetMetrics(batchGetInternalMetricsCaptor.capture()) - } - val capturedInternalGetMetricRequests = batchGetInternalMetricsCaptor.allValues - assertThat(capturedInternalGetMetricRequests) - .containsExactly( - internalBatchGetMetricsRequest { - cmmsMeasurementConsumerId = - INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.cmmsMeasurementConsumerId - externalMetricIds += - INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId - }, - internalBatchGetMetricsRequest { - cmmsMeasurementConsumerId = - INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.cmmsMeasurementConsumerId - externalMetricIds += - INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.externalMetricId - } - ) - - // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement - val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(1)) { getMeasurement(getMeasurementCaptor.capture()) } - val capturedGetMeasurementRequests = getMeasurementCaptor.allValues - assertThat(capturedGetMeasurementRequests) - .containsExactly( - getMeasurementRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name }, - ) - // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -4658,19 +4290,6 @@ class MetricsServiceTest { } ) - // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement - val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(measurementsMock, times(3)) { getMeasurement(getMeasurementCaptor.capture()) } - val capturedGetMeasurementRequests = getMeasurementCaptor.allValues - assertThat(capturedGetMeasurementRequests) - .containsExactly( - getMeasurementRequest { name = PENDING_UNION_ALL_REACH_MEASUREMENT.name }, - getMeasurementRequest { - name = PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.name - }, - getMeasurementRequest { name = PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT.name }, - ) - // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() From 56127b08f1998d4f888667a69f31dd6a94de5fb8 Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Wed, 26 Apr 2023 21:51:39 +0000 Subject: [PATCH 10/12] Make some input arguments not memeber variables. --- .../service/api/v2alpha/MetricsService.kt | 22 +++++++++---------- .../service/api/v2alpha/MetricsServiceTest.kt | 4 ++-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 3aa7ee4d953..6a699e04cd5 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -152,19 +152,19 @@ private const val MAX_PAGE_SIZE = 1000 private const val NANOS_PER_SECOND = 1_000_000_000 class MetricsService( + private val metricSpecConfig: MetricSpecConfig, private val internalReportingSetsStub: InternalReportingSetsCoroutineStub, - private val internalMeasurementsStub: InternalMeasurementsCoroutineStub, private val internalMetricsStub: InternalMetricsCoroutineStub, - private val dataProvidersStub: DataProvidersCoroutineStub, - private val measurementsStub: MeasurementsCoroutineStub, - private val certificatesStub: CertificatesCoroutineStub, - private val measurementConsumersStub: MeasurementConsumersCoroutineStub, - private val encryptionKeyPairStore: EncryptionKeyPairStore, - private val secureRandom: SecureRandom, - private val signingPrivateKeyDir: File, - private val trustedCertificates: Map, - private val metricSpecConfig: MetricSpecConfig, - private val coroutineContext: @BlockingExecutor CoroutineContext = Dispatchers.IO, + internalMeasurementsStub: InternalMeasurementsCoroutineStub, + dataProvidersStub: DataProvidersCoroutineStub, + measurementsStub: MeasurementsCoroutineStub, + certificatesStub: CertificatesCoroutineStub, + measurementConsumersStub: MeasurementConsumersCoroutineStub, + encryptionKeyPairStore: EncryptionKeyPairStore, + secureRandom: SecureRandom, + signingPrivateKeyDir: File, + trustedCertificates: Map, + coroutineContext: @BlockingExecutor CoroutineContext = Dispatchers.IO, ) : MetricsCoroutineImplBase() { private val measurementSupplier = diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt index 87392c56e5a..e567cb1ce12 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt @@ -1451,9 +1451,10 @@ class MetricsServiceTest { service = MetricsService( + METRIC_SPEC_CONFIG, InternalReportingSetsGrpcKt.ReportingSetsCoroutineStub(grpcTestServerRule.channel), - InternalMeasurementsGrpcKt.MeasurementsCoroutineStub(grpcTestServerRule.channel), InternalMetricsGrpcKt.MetricsCoroutineStub(grpcTestServerRule.channel), + InternalMeasurementsGrpcKt.MeasurementsCoroutineStub(grpcTestServerRule.channel), DataProvidersGrpcKt.DataProvidersCoroutineStub(grpcTestServerRule.channel), MeasurementsGrpcKt.MeasurementsCoroutineStub(grpcTestServerRule.channel), CertificatesGrpcKt.CertificatesCoroutineStub(grpcTestServerRule.channel), @@ -1464,7 +1465,6 @@ class MetricsServiceTest { listOf(AGGREGATOR_ROOT_CERTIFICATE, DATA_PROVIDER_ROOT_CERTIFICATE).associateBy { it.subjectKeyIdentifier!! }, - METRIC_SPEC_CONFIG ) } From fa8f2ceab78a2fc544aa6934dae52c839f204b55 Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Wed, 26 Apr 2023 22:31:57 +0000 Subject: [PATCH 11/12] Fix unit tests. --- .../service/api/v2alpha/MetricsServiceTest.kt | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt index e567cb1ce12..e57ba83e61f 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt @@ -710,13 +710,12 @@ private val INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT = private val UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT_SPEC = measurementSpec { measurementPublicKey = MEASUREMENT_CONSUMER_PUBLIC_KEY.toByteString() - nonceHashes.addAll( + nonceHashes += listOf( hashSha256(SECURE_RANDOM_OUTPUT_LONG), hashSha256(SECURE_RANDOM_OUTPUT_LONG), hashSha256(SECURE_RANDOM_OUTPUT_LONG) ) - ) reach = MeasurementSpecKt.reach { @@ -861,13 +860,12 @@ private val PENDING_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT = private val UNION_ALL_WATCH_DURATION_MEASUREMENT_SPEC = measurementSpec { measurementPublicKey = MEASUREMENT_CONSUMER_PUBLIC_KEY.toByteString() - nonceHashes.addAll( + nonceHashes += listOf( hashSha256(SECURE_RANDOM_OUTPUT_LONG), hashSha256(SECURE_RANDOM_OUTPUT_LONG), hashSha256(SECURE_RANDOM_OUTPUT_LONG) ) - ) duration = MeasurementSpecKt.duration { @@ -1525,9 +1523,7 @@ class MetricsServiceTest { .isEqualTo( UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT_SPEC.copy { nonceHashes.clear() - nonceHashes.addAll( - List(dataProvidersList.size) { hashSha256(SECURE_RANDOM_OUTPUT_LONG) } - ) + nonceHashes += List(dataProvidersList.size) { hashSha256(SECURE_RANDOM_OUTPUT_LONG) } } ) @@ -1926,9 +1922,7 @@ class MetricsServiceTest { .isEqualTo( UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT_SPEC.copy { nonceHashes.clear() - nonceHashes.addAll( - List(dataProvidersList.size) { hashSha256(SECURE_RANDOM_OUTPUT_LONG) } - ) + nonceHashes += List(dataProvidersList.size) { hashSha256(SECURE_RANDOM_OUTPUT_LONG) } } ) @@ -2742,9 +2736,7 @@ class MetricsServiceTest { else UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT_SPEC.copy { nonceHashes.clear() - nonceHashes.addAll( - List(dataProvidersList.size) { hashSha256(SECURE_RANDOM_OUTPUT_LONG) } - ) + nonceHashes += List(dataProvidersList.size) { hashSha256(SECURE_RANDOM_OUTPUT_LONG) } } ) From dd44a9085b9d248de73ceb3ce0684a5fd4a4b9ed Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Thu, 27 Apr 2023 20:00:06 +0000 Subject: [PATCH 12/12] Add TODO for a potential performance improvement. --- .../reporting/service/api/v2alpha/MetricsService.kt | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 6a699e04cd5..ac9e0770a07 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -1015,6 +1015,10 @@ class MetricsService( return batchGetMetricsResponse { metrics += + /** + * TODO(@riemanli): a potential improvement can be done by only getting the metrics whose + * measurements are updated. Re-evaluate when a load-test is ready after deployment. + */ if (anyMeasurementUpdated) { batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, externalMetricIds) .map { it.toMetric() } @@ -1084,8 +1088,13 @@ class MetricsService( principal, ) - // If any measurement got updated, pull the list of the up-to-date internal metrics. Otherwise, - // use the original list. + /** + * If any measurement got updated, pull the list of the up-to-date internal metrics. Otherwise, + * use the original list. + * + * TODO(@riemanli): a potential improvement can be done by only getting the metrics whose + * measurements are updated. Re-evaluate when a load-test is ready after deployment. + */ val internalMetrics: List = if (anyMeasurementUpdated) { batchGetInternalMetrics(