diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index cc559bbe5b4..5089f17e06b 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -875,14 +875,10 @@ class MetricsService( } val internalMetric: InternalMetric = - batchGetInternalMetrics( - metricKey.cmmsMeasurementConsumerId, - listOf(apiIdToExternalId(metricKey.metricId)) - ) - .first() + getInternalMetric(metricKey.cmmsMeasurementConsumerId, apiIdToExternalId(metricKey.metricId)) // Early exit when the metric is at a terminal state. - if (internalMetric.state != InternalMetric.State.RUNNING) { + if (determineMetricState(internalMetric) != Metric.State.RUNNING) { return internalMetric.toMetric() } @@ -894,18 +890,19 @@ class MetricsService( internalMeasurement.state == InternalMeasurement.State.PENDING } - measurementSupplier.syncInternalMeasurements( - toBeSyncedInternalMeasurements, - principal.config.apiKey, - principal, - ) - - return batchGetInternalMetrics( - metricKey.cmmsMeasurementConsumerId, - listOf(apiIdToExternalId(metricKey.metricId)) + val anyMeasurementUpdated: Boolean = + measurementSupplier.syncInternalMeasurements( + toBeSyncedInternalMeasurements, + principal.config.apiKey, + principal, ) - .first() - .toMetric() + + return if (anyMeasurementUpdated) { + getInternalMetric(metricKey.cmmsMeasurementConsumerId, apiIdToExternalId(metricKey.metricId)) + .toMetric() + } else { + internalMetric.toMetric() + } } override suspend fun batchGetMetrics(request: BatchGetMetricsRequest): BatchGetMetricsResponse { @@ -945,23 +942,28 @@ class MetricsService( // Only syncs pending measurements which can only be in metrics that are still running. val toBeSyncedInternalMeasurements: List = internalMetrics - .filter { internalMetric -> internalMetric.state == InternalMetric.State.RUNNING } + .filter { internalMetric -> determineMetricState(internalMetric) == Metric.State.RUNNING } .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } .map { weightedMeasurement -> weightedMeasurement.measurement } .filter { internalMeasurement -> internalMeasurement.state == InternalMeasurement.State.PENDING } - measurementSupplier.syncInternalMeasurements( - toBeSyncedInternalMeasurements, - principal.config.apiKey, - principal, - ) + val anyMeasurementUpdated: Boolean = + measurementSupplier.syncInternalMeasurements( + toBeSyncedInternalMeasurements, + principal.config.apiKey, + principal, + ) return batchGetMetricsResponse { metrics += - batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, externalMetricIds) - .map { it.toMetric() } + if (anyMeasurementUpdated) { + batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, externalMetricIds) + .map { it.toMetric() } + } else { + internalMetrics.map { it.toMetric() } + } } } override suspend fun listMetrics(request: ListMetricsRequest): ListMetricsResponse { @@ -1046,7 +1048,7 @@ class MetricsService( } } - /** Gets a batch of [InternalMetric]. */ + /** Gets a batch of [InternalMetric]s. */ private suspend fun batchGetInternalMetrics( cmmsMeasurementConsumerId: String, externalMetricIds: List, @@ -1063,6 +1065,18 @@ class MetricsService( } } + /** Gets an [InternalMetric]. */ + private suspend fun getInternalMetric( + cmmsMeasurementConsumerId: String, + externalMetricId: Long, + ): InternalMetric { + return try { + batchGetInternalMetrics(cmmsMeasurementConsumerId, listOf(externalMetricId)).first() + } catch (e: StatusException) { + throw Exception("Unable to get metrics from the reporting database.", e) + } + } + override suspend fun createMetric(request: CreateMetricRequest): Metric { grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { "Parent is either unspecified or invalid." diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt index f71d3c65586..bd46ea11d1c 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt @@ -1101,7 +1101,13 @@ private val INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC = } private val INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC = - INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.copy { state = InternalMetric.State.FAILED } + INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.copy { + weightedMeasurements.clear() + weightedMeasurements += weightedMeasurement { + weight = 1 + measurement = INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_MEASUREMENT + } + } // Internal Cross Publisher Watch Duration Metrics private val INTERNAL_REQUESTING_CROSS_PUBLISHER_WATCH_DURATION_METRIC = internalMetric { @@ -1140,7 +1146,6 @@ private val INTERNAL_PENDING_INITIAL_CROSS_PUBLISHER_WATCH_DURATION_METRIC = weight = 1 measurement = INTERNAL_PENDING_NOT_CREATED_UNION_ALL_WATCH_DURATION_MEASUREMENT } - state = InternalMetric.State.RUNNING } private val INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC = @@ -1154,17 +1159,12 @@ private val INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC = private val INTERNAL_SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC = INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.copy { - state = InternalMetric.State.SUCCEEDED - details = - InternalMetricKt.details { - filters += this@copy.details.filtersList - result = internalMetricResult { - watchDuration = - InternalMetricResultKt.watchDurationResult { - value = TOTAL_WATCH_DURATION.seconds.toDouble() - } - } - } + weightedMeasurements.clear() + weightedMeasurements += weightedMeasurement { + weight = 1 + measurement = INTERNAL_SUCCEEDED_UNION_ALL_WATCH_DURATION_MEASUREMENT + } + details = InternalMetricKt.details { filters += this@copy.details.filtersList } } // Public Metrics @@ -4172,7 +4172,11 @@ class MetricsServiceTest { whenever(internalMetricsMock.batchGetMetrics(any())) .thenReturn( internalBatchGetMetricsResponse { metrics += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC }, - internalBatchGetMetricsResponse { metrics += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC } + ) + whenever(measurementsMock.getMeasurement(any())) + .thenReturn( + PENDING_UNION_ALL_REACH_MEASUREMENT, + PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT, ) val request = getMetricRequest { name = PENDING_INCREMENTAL_REACH_METRIC.name } @@ -4185,17 +4189,12 @@ class MetricsServiceTest { // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics val batchGetInternalMetricsCaptor: KArgumentCaptor = argumentCaptor() - verifyBlocking(internalMetricsMock, times(2)) { + verifyBlocking(internalMetricsMock, times(1)) { batchGetMetrics(batchGetInternalMetricsCaptor.capture()) } val capturedInternalGetMetricRequests = batchGetInternalMetricsCaptor.allValues assertThat(capturedInternalGetMetricRequests) .containsExactly( - internalBatchGetMetricsRequest { - cmmsMeasurementConsumerId = - INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.cmmsMeasurementConsumerId - externalMetricIds += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.externalMetricId - }, internalBatchGetMetricsRequest { cmmsMeasurementConsumerId = INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.cmmsMeasurementConsumerId @@ -4203,6 +4202,18 @@ class MetricsServiceTest { } ) + // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement + val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() + verifyBlocking(measurementsMock, times(2)) { getMeasurement(getMeasurementCaptor.capture()) } + val capturedGetMeasurementRequests = getMeasurementCaptor.allValues + assertThat(capturedGetMeasurementRequests) + .containsExactly( + getMeasurementRequest { name = PENDING_UNION_ALL_REACH_MEASUREMENT.name }, + getMeasurementRequest { + name = PENDING_UNION_ALL_BUT_LAST_PUBLISHER_REACH_MEASUREMENT.name + }, + ) + // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor() @@ -4295,6 +4306,15 @@ class MetricsServiceTest { } ) + // Verify proto argument of MeasurementsCoroutineImplBase::getMeasurement + val getMeasurementCaptor: KArgumentCaptor = argumentCaptor() + verifyBlocking(measurementsMock, times(1)) { getMeasurement(getMeasurementCaptor.capture()) } + val capturedGetMeasurementRequests = getMeasurementCaptor.allValues + assertThat(capturedGetMeasurementRequests) + .containsExactly( + getMeasurementRequest { name = PENDING_UNION_ALL_WATCH_DURATION_MEASUREMENT.name }, + ) + // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults val batchSetMeasurementResultsCaptor: KArgumentCaptor = argumentCaptor()