From baaf158c241f64cac126e23049d949c483a151d8 Mon Sep 17 00:00:00 2001 From: Tristan Vuong <85768771+tristanvuong2021@users.noreply.github.com> Date: Mon, 8 Apr 2024 15:05:36 -0700 Subject: [PATCH] Add State to Internal Metric (#1557) --- .../deploy/v2/postgres/readers/BUILD.bazel | 1 + .../v2/postgres/readers/MetricReader.kt | 74 ++++- .../v2/postgres/writers/CreateMetrics.kt | 9 +- .../writers/SetMeasurementFailures.kt | 33 ++ .../postgres/writers/SetMeasurementResults.kt | 41 +++ .../service/api/v2alpha/MetricsService.kt | 198 ++++++------ .../service/api/v2alpha/ProtoConversions.kt | 16 + .../internal/testing/v2/MetricsServiceTest.kt | 293 ++++++++++++++++++ .../internal/reporting/v2/metric.proto | 8 + .../postgres/add-state-column-to-metrics.sql | 40 +++ .../reporting/postgres/changelog-v2.yaml | 3 + .../v2/postgres/PostgresMetricsServiceTest.kt | 1 + .../service/api/v2alpha/MetricsServiceTest.kt | 66 ++++ 13 files changed, 658 insertions(+), 125 deletions(-) create mode 100644 src/main/resources/reporting/postgres/add-state-column-to-metrics.sql diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/BUILD.bazel index 9e2eb652d04..de79d31548b 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/BUILD.bazel @@ -21,6 +21,7 @@ kt_jvm_library( "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", ], ) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/MetricReader.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/MetricReader.kt index 3f9edb8db92..a312ef1bd5f 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/MetricReader.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/MetricReader.kt @@ -30,6 +30,8 @@ import org.wfanet.measurement.common.db.r2dbc.BoundStatement import org.wfanet.measurement.common.db.r2dbc.ReadContext import org.wfanet.measurement.common.db.r2dbc.ResultRow import org.wfanet.measurement.common.db.r2dbc.boundStatement +import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement +import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement import org.wfanet.measurement.common.identity.InternalId import org.wfanet.measurement.common.toInstant import org.wfanet.measurement.common.toProtoDuration @@ -81,6 +83,7 @@ class MetricReader(private val readContext: ReadContext) { val metricSpec: MetricSpec, val weightedMeasurementInfoMap: MutableMap, val details: Metric.Details, + val state: Metric.State, ) private data class MetricMeasurementKey( @@ -133,6 +136,7 @@ class MetricReader(private val readContext: ReadContext) { Metrics.VidSamplingIntervalWidth, Metrics.CreateTime, Metrics.MetricDetails, + Metrics.State as MetricsState, MetricMeasurements.Coefficient, MetricMeasurements.BinaryRepresentation, Measurements.MeasurementId, @@ -140,7 +144,7 @@ class MetricReader(private val readContext: ReadContext) { Measurements.CmmsMeasurementId, Measurements.TimeIntervalStart AS MeasurementsTimeIntervalStart, Measurements.TimeIntervalEndExclusive AS MeasurementsTimeIntervalEndExclusive, - Measurements.State, + Measurements.State as MeasurementsState, Measurements.MeasurementDetails, PrimitiveReportingSetBases.PrimitiveReportingSetBasisId, PrimitiveReportingSets.ExternalReportingSetId AS PrimitiveExternalReportingSetId, @@ -181,26 +185,64 @@ class MetricReader(private val readContext: ReadContext) { JOIN Metrics USING(MeasurementConsumerId) $baseSqlJoins WHERE Metrics.MeasurementConsumerId = $1 - AND CreateMetricRequestId IN + AND CreateMetricRequestId IN (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER}) """ .trimIndent() ) - var i = 2 - val bindingMap = mutableMapOf() - val inList = - createMetricRequestIds.joinToString(separator = ",", prefix = "(", postfix = ")") { - val index = "$$i" - bindingMap[it] = index - i++ - index + val statement = + valuesListBoundStatement(valuesStartIndex = 1, paramCount = 1, sql.toString()) { + bind("$1", measurementConsumerId) + createMetricRequestIds.forEach { addValuesBinding { bindValuesParam(0, it) } } } - sql.append(inList) + + return flow { + val metricInfoMap = buildResultMap(statement) + + for (entry in metricInfoMap) { + val metricInfo = entry.value + + val metric = metricInfo.buildMetric() + + val createMetricRequestId = metricInfo.createMetricRequestId ?: "" + emit( + Result( + measurementConsumerId = metricInfo.measurementConsumerId, + metricId = metricInfo.metricId, + createMetricRequestId = createMetricRequestId, + metric = metric, + ) + ) + } + } + } + + fun readMetricsByCmmsMeasurementId( + measurementConsumerId: InternalId, + cmmsMeasurementIds: Collection, + ): Flow { + if (cmmsMeasurementIds.isEmpty()) { + return emptyFlow() + } + + val sql = + StringBuilder( + """ + $baseSqlSelect + FROM + MeasurementConsumers + JOIN Metrics USING(MeasurementConsumerId) + $baseSqlJoins + WHERE Metrics.MeasurementConsumerId = $1 + AND CmmsMeasurementId IN (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER}) + """ + .trimIndent() + ) val statement = - boundStatement(sql.toString()) { + valuesListBoundStatement(valuesStartIndex = 1, paramCount = 1, sql.toString()) { bind("$1", measurementConsumerId) - createMetricRequestIds.forEach { bind(bindingMap.getValue(it), it) } + cmmsMeasurementIds.forEach { addValuesBinding { bindValuesParam(0, it) } } } return flow { @@ -625,6 +667,7 @@ class MetricReader(private val readContext: ReadContext) { if (metricInfo.details != Metric.Details.getDefaultInstance()) { details = metricInfo.details } + state = metricInfo.state } } @@ -662,12 +705,14 @@ class MetricReader(private val readContext: ReadContext) { val cmmsMeasurementId: String? = row["CmmsMeasurementId"] val measurementTimeIntervalStart: Instant = row["MeasurementsTimeIntervalStart"] val measurementTimeIntervalEnd: Instant = row["MeasurementsTimeIntervalEndExclusive"] - val measurementState: Measurement.State = Measurement.State.forNumber(row["State"]) + val measurementState: Measurement.State = + row.getProtoEnum("MeasurementsState", Measurement.State::forNumber) val measurementDetails: Measurement.Details = row.getProtoMessage("MeasurementDetails", Measurement.Details.parser()) val primitiveReportingSetBasisId: InternalId = row["PrimitiveReportingSetBasisId"] val primitiveExternalReportingSetId: String = row["PrimitiveExternalReportingSetId"] val primitiveReportingSetBasisFilter: String? = row["PrimitiveReportingSetBasisFilter"] + val metricState: Metric.State = row.getProtoEnum("MetricsState", Metric.State::forNumber) val metricInfo = metricInfoMap.computeIfAbsent(externalMetricId) { @@ -767,6 +812,7 @@ class MetricReader(private val readContext: ReadContext) { metricSpec = metricSpec, details = metricDetails, weightedMeasurementInfoMap = mutableMapOf(), + state = metricState, ) } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateMetrics.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateMetrics.kt index 0fc9f570897..bede5115af2 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateMetrics.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateMetrics.kt @@ -193,7 +193,7 @@ class CreateMetrics(private val requests: List) : val statement = valuesListBoundStatement( valuesStartIndex = 0, - paramCount = 20, + paramCount = 21, """ INSERT INTO Metrics ( @@ -216,11 +216,13 @@ class CreateMetrics(private val requests: List) : VidSamplingIntervalWidth, CreateTime, MetricDetails, - MetricDetailsJson + MetricDetailsJson, + State ) VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} """, ) { + val createTime = Instant.now().atOffset(ZoneOffset.UTC).truncatedTo(ChronoUnit.MICROS) requests.forEach { val existingMetric: Metric? = existingMetricsMap[it.requestId] if (existingMetric != null) { @@ -229,7 +231,6 @@ class CreateMetrics(private val requests: List) : val metricId = idGenerator.generateInternalId() val externalMetricId: String = it.externalMetricId val reportingSetId: InternalId? = reportingSetMap[it.metric.externalReportingSetId] - val createTime = Instant.now().atOffset(ZoneOffset.UTC).truncatedTo(ChronoUnit.MICROS) val vidSamplingIntervalStart = if (it.metric.metricSpec.typeCase == MetricSpec.TypeCase.POPULATION_COUNT) 0 else it.metric.metricSpec.vidSamplingInterval.start @@ -317,6 +318,7 @@ class CreateMetrics(private val requests: List) : bindValuesParam(17, createTime) bindValuesParam(18, it.metric.details) bindValuesParam(19, it.metric.details.toJson()) + bindValuesParam(20, Metric.State.RUNNING) } if (it.requestId.isNotEmpty()) { @@ -351,6 +353,7 @@ class CreateMetrics(private val requests: List) : weightedMeasurements.clear() weightedMeasurements.addAll(weightedMeasurementsAndInsertData.weightedMeasurements) this.createTime = createTime.toInstant().toProtoTime() + state = Metric.State.RUNNING } ) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementFailures.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementFailures.kt index 28cbf0a1e9f..0ab0b59190f 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementFailures.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementFailures.kt @@ -19,11 +19,14 @@ package org.wfanet.measurement.reporting.deploy.v2.postgres.writers import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement +import org.wfanet.measurement.common.identity.InternalId import org.wfanet.measurement.common.toJson import org.wfanet.measurement.internal.reporting.v2.BatchSetMeasurementFailuresRequest import org.wfanet.measurement.internal.reporting.v2.Measurement import org.wfanet.measurement.internal.reporting.v2.MeasurementKt +import org.wfanet.measurement.internal.reporting.v2.Metric import org.wfanet.measurement.reporting.deploy.v2.postgres.readers.MeasurementConsumerReader +import org.wfanet.measurement.reporting.deploy.v2.postgres.readers.MetricReader import org.wfanet.measurement.reporting.service.internal.MeasurementConsumerNotFoundException import org.wfanet.measurement.reporting.service.internal.MeasurementNotFoundException @@ -72,5 +75,35 @@ class SetMeasurementFailures(private val request: BatchSetMeasurementFailuresReq if (result.numRowsUpdated < request.measurementFailuresList.size) { throw MeasurementNotFoundException() } + + // Read all metrics tied to Measurements that were updated. + val metricIds: List = buildList { + MetricReader(transactionContext) + .readMetricsByCmmsMeasurementId( + measurementConsumerId, + request.measurementFailuresList.map { it.cmmsMeasurementId }, + ) + .collect { metricReaderResult -> add(metricReaderResult.metricId) } + } + + if (metricIds.isNotEmpty()) { + val metricStateUpdateStatement = + valuesListBoundStatement( + valuesStartIndex = 2, + paramCount = 1, + """ + UPDATE Metrics AS m SET State = $1 + FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER}) + AS c(MetricId) + WHERE MeasurementConsumerId = $2 AND m.MetricId = c.MetricId + """, + ) { + bind("$1", Metric.State.FAILED) + bind("$2", measurementConsumerId) + metricIds.forEach { addValuesBinding { bindValuesParam(0, it) } } + } + + transactionContext.executeStatement(metricStateUpdateStatement) + } } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementResults.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementResults.kt index 37f73c8075b..444824af3ec 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementResults.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementResults.kt @@ -19,11 +19,14 @@ package org.wfanet.measurement.reporting.deploy.v2.postgres.writers import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement +import org.wfanet.measurement.common.identity.InternalId import org.wfanet.measurement.common.toJson import org.wfanet.measurement.internal.reporting.v2.BatchSetMeasurementResultsRequest import org.wfanet.measurement.internal.reporting.v2.Measurement import org.wfanet.measurement.internal.reporting.v2.MeasurementKt +import org.wfanet.measurement.internal.reporting.v2.Metric import org.wfanet.measurement.reporting.deploy.v2.postgres.readers.MeasurementConsumerReader +import org.wfanet.measurement.reporting.deploy.v2.postgres.readers.MetricReader import org.wfanet.measurement.reporting.service.internal.MeasurementConsumerNotFoundException import org.wfanet.measurement.reporting.service.internal.MeasurementNotFoundException @@ -72,5 +75,43 @@ class SetMeasurementResults(private val request: BatchSetMeasurementResultsReque if (result.numRowsUpdated < request.measurementResultsList.size) { throw MeasurementNotFoundException() } + + // Read all metrics tied to Measurements that were updated and determine any state changes. + val metricIds: List = buildList { + MetricReader(transactionContext) + .readMetricsByCmmsMeasurementId( + measurementConsumerId, + request.measurementResultsList.map { it.cmmsMeasurementId }, + ) + .collect { metricReaderResult -> + if (metricReaderResult.metric.state == Metric.State.RUNNING) { + val measurementStates = + metricReaderResult.metric.weightedMeasurementsList.map { it.measurement.state } + if (measurementStates.all { it == Measurement.State.SUCCEEDED }) { + add(metricReaderResult.metricId) + } + } + } + } + + if (metricIds.isNotEmpty()) { + val metricStateUpdateStatement = + valuesListBoundStatement( + valuesStartIndex = 2, + paramCount = 1, + """ + UPDATE Metrics AS m SET State = $1 + FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER}) + AS c(MetricId) + WHERE MeasurementConsumerId = $2 AND m.MetricId = c.MetricId + """, + ) { + bind("$1", Metric.State.SUCCEEDED) + bind("$2", measurementConsumerId) + metricIds.forEach { addValuesBinding { bindValuesParam(0, it) } } + } + + transactionContext.executeStatement(metricStateUpdateStatement) + } } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 7f76794b98d..d85e2e10eda 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -1142,31 +1142,11 @@ class MetricsService( val internalMetric: InternalMetric = getInternalMetric(metricKey.cmmsMeasurementConsumerId, metricKey.metricId) - // Early exit when the metric is at a terminal state. - if (internalMetric.state != Metric.State.RUNNING) { - return internalMetric.toMetric(variances) - } - - // Only syncs pending measurements which can only be in metrics that are still running. - val toBeSyncedInternalMeasurements: List = - internalMetric.weightedMeasurementsList - .map { weightedMeasurement -> weightedMeasurement.measurement } - .filter { internalMeasurement -> - internalMeasurement.state == InternalMeasurement.State.PENDING - } - - val anyMeasurementUpdated: Boolean = - measurementSupplier.syncInternalMeasurements( - toBeSyncedInternalMeasurements, - principal.config.apiKey, + return syncAndConvertInternalMetricsToPublicMetrics( + mapOf(internalMetric.state to listOf(internalMetric)), principal, ) - - return if (anyMeasurementUpdated) { - getInternalMetric(metricKey.cmmsMeasurementConsumerId, metricKey.metricId).toMetric(variances) - } else { - internalMetric.toMetric(variances) - } + .single() } override suspend fun batchGetMetrics(request: BatchGetMetricsRequest): BatchGetMetricsResponse { @@ -1201,39 +1181,13 @@ class MetricsService( metricKey.metricId } - val internalMetrics: List = - batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, metricIds) - - // Only syncs pending measurements which can only be in metrics that are still running. - val toBeSyncedInternalMeasurements: List = - internalMetrics - .filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } - .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } - .map { weightedMeasurement -> weightedMeasurement.measurement } - .filter { internalMeasurement -> - internalMeasurement.state == InternalMeasurement.State.PENDING - } - - val anyMeasurementUpdated: Boolean = - measurementSupplier.syncInternalMeasurements( - toBeSyncedInternalMeasurements, - principal.config.apiKey, - principal, - ) + val internalMetricsByState: Map> = + batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, metricIds).groupBy { + it.state + } return batchGetMetricsResponse { - metrics += - /** - * TODO(@riemanli): a potential improvement can be done by only getting the metrics whose - * measurements are updated. Re-evaluate when a load-test is ready after deployment. - */ - if (anyMeasurementUpdated) { - batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, metricIds).map { - it.toMetric(variances) - } - } else { - internalMetrics.map { it.toMetric(variances) } - } + metrics += syncAndConvertInternalMetricsToPublicMetrics(internalMetricsByState, principal) } } @@ -1283,45 +1237,11 @@ class MetricsService( null } - val subResults: List = - results.subList(0, min(results.size, listMetricsPageToken.pageSize)) - - // Only syncs pending measurements which can only be in metrics that are still running. - val toBeSyncedInternalMeasurements: List = - subResults - .filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } - .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } - .map { weightedMeasurement -> weightedMeasurement.measurement } - .filter { internalMeasurement -> - internalMeasurement.state == InternalMeasurement.State.PENDING - } - - val anyMeasurementUpdated: Boolean = - measurementSupplier.syncInternalMeasurements( - toBeSyncedInternalMeasurements, - apiAuthenticationKey, - principal, - ) - - /** - * If any measurement got updated, pull the list of the up-to-date internal metrics. Otherwise, - * use the original list. - * - * TODO(@riemanli): a potential improvement can be done by only getting the metrics whose - * measurements are updated. Re-evaluate when a load-test is ready after deployment. - */ - val internalMetrics: List = - if (anyMeasurementUpdated) { - batchGetInternalMetrics( - principal.resourceKey.measurementConsumerId, - subResults.map { internalMetric -> internalMetric.externalMetricId }, - ) - } else { - subResults - } + val subResultsByState: Map> = + results.subList(0, min(results.size, listMetricsPageToken.pageSize)).groupBy { it.state } return listMetricsResponse { - metrics += internalMetrics.map { it.toMetric(variances) } + metrics += syncAndConvertInternalMetricsToPublicMetrics(subResultsByState, principal) if (nextPageToken != null) { this.nextPageToken = nextPageToken.toByteString().base64UrlEncode() @@ -1389,7 +1309,7 @@ class MetricsService( buildInternalCreateMetricRequest( principal.resourceKey.measurementConsumerId, request, - batchGetReportingSetsResponse.reportingSetsList.first(), + batchGetReportingSetsResponse.reportingSetsList.single(), ) val internalMetric = @@ -1413,7 +1333,7 @@ class MetricsService( .asRuntimeException() } - if (internalMetric.state == Metric.State.RUNNING) { + if (internalMetric.state == InternalMetric.State.RUNNING) { measurementSupplier.createCmmsMeasurements(listOf(internalMetric), principal) } @@ -1520,7 +1440,9 @@ class MetricsService( } val internalRunningMetrics = - internalMetrics.filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } + internalMetrics.filter { internalMetric -> + internalMetric.state == InternalMetric.State.RUNNING + } if (internalRunningMetrics.isNotEmpty()) { measurementSupplier.createCmmsMeasurements(internalRunningMetrics, principal) } @@ -1651,6 +1573,67 @@ class MetricsService( } } + /** Converts [InternalMetric]s to public [Metric]s after syncing [Measurement]s. */ + private suspend fun syncAndConvertInternalMetricsToPublicMetrics( + metricsByState: Map>, + principal: MeasurementConsumerPrincipal, + ): List { + // Only syncs pending measurements which can only be in metrics that are still running. + val toBeSyncedInternalMeasurements: List = + if (metricsByState.containsKey(InternalMetric.State.RUNNING)) { + metricsByState + .getValue(InternalMetric.State.RUNNING) + .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } + .map { weightedMeasurement -> weightedMeasurement.measurement } + .filter { internalMeasurement -> + internalMeasurement.state == InternalMeasurement.State.PENDING + } + } else { + emptyList() + } + + val anyMeasurementUpdated: Boolean = + measurementSupplier.syncInternalMeasurements( + toBeSyncedInternalMeasurements, + principal.config.apiKey, + principal, + ) + + return buildList { + for (state in metricsByState.keys) { + when (state) { + InternalMetric.State.SUCCEEDED, + InternalMetric.State.FAILED -> + addAll(metricsByState.getValue(state).map { it.toMetric(variances) }) + InternalMetric.State.RUNNING -> { + if (anyMeasurementUpdated) { + val updatedInternalMetrics = + batchGetInternalMetrics( + principal.resourceKey.measurementConsumerId, + metricsByState.getValue(InternalMetric.State.RUNNING).map { it.externalMetricId }, + ) + addAll(updatedInternalMetrics.map { it.toMetric(variances) }) + } else { + addAll(metricsByState.getValue(state).map { it.toMetric(variances) }) + } + } + InternalMetric.State.STATE_UNSPECIFIED -> { + // Metrics created before state was tracked in the database will have the state be + // unspecified. This calculates the correct state for those metrics. + addAll( + metricsByState.getValue(state).map { internalMetric -> + internalMetric + .copy { this.state = internalMetric.calculateState() } + .toMetric(variances) + } + ) + } + InternalMetric.State.UNRECOGNIZED -> error("Invalid Metric State") + } + } + } + } + companion object { private val RESOURCE_ID_REGEX = Regex("^[a-z]([a-z0-9-]{0,61}[a-z0-9])?$") } @@ -1706,7 +1689,7 @@ private fun InternalMetric.toMetric(variances: Variances): Metric { timeInterval = source.timeInterval metricSpec = source.metricSpec.toMetricSpec() filters += source.details.filtersList - state = source.state + state = source.state.toPublic() createTime = source.createTime if (state == Metric.State.SUCCEEDED) { result = buildMetricResult(source, variances) @@ -1871,7 +1854,7 @@ private fun calculateWatchDurationResult( // Only compute univariate statistics for union-only operations, i.e. single source measurement. if (weightedMeasurements.size == 1) { - val weightedMeasurement = weightedMeasurements.first() + val weightedMeasurement = weightedMeasurements.single() val weightedMeasurementVarianceParamsList: List = buildWeightedWatchDurationMeasurementVarianceParamsPerResult( @@ -2051,7 +2034,7 @@ private fun calculateImpressionResult( // Only compute univariate statistics for union-only operations, i.e. single source measurement. if (weightedMeasurements.size == 1) { - val weightedMeasurement = weightedMeasurements.first() + val weightedMeasurement = weightedMeasurements.single() val weightedMeasurementVarianceParamsList: List = buildWeightedImpressionMeasurementVarianceParamsPerResult(weightedMeasurement, metricSpec) @@ -2319,7 +2302,7 @@ fun buildWeightedFrequencyMeasurementVarianceParams( val frequencyResult: InternalMeasurement.Result.Frequency = if (weightedMeasurement.measurement.details.resultsList.size == 1) { - weightedMeasurement.measurement.details.resultsList.first().frequency + weightedMeasurement.measurement.details.resultsList.single().frequency } else if (weightedMeasurement.measurement.details.resultsList.size > 1) { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { "No supported methodology generates more than one frequency result." @@ -2622,14 +2605,13 @@ private operator fun ProtoDuration.plus(other: ProtoDuration): ProtoDuration { return Durations.add(this, other) } -private val InternalMetric.state: Metric.State - get() { - val measurementStates = weightedMeasurementsList.map { it.measurement.state } - return if (measurementStates.all { it == InternalMeasurement.State.SUCCEEDED }) { - Metric.State.SUCCEEDED - } else if (measurementStates.any { it == InternalMeasurement.State.FAILED }) { - Metric.State.FAILED - } else { - Metric.State.RUNNING - } +private fun InternalMetric.calculateState(): InternalMetric.State { + val measurementStates = weightedMeasurementsList.map { it.measurement.state } + return if (measurementStates.all { it == InternalMeasurement.State.SUCCEEDED }) { + InternalMetric.State.SUCCEEDED + } else if (measurementStates.any { it == InternalMeasurement.State.FAILED }) { + InternalMetric.State.FAILED + } else { + InternalMetric.State.RUNNING } +} diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ProtoConversions.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ProtoConversions.kt index 81349a4b673..db80785e7ff 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ProtoConversions.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ProtoConversions.kt @@ -51,6 +51,7 @@ import org.wfanet.measurement.internal.reporting.v2.LiquidLegionsDistribution as import org.wfanet.measurement.internal.reporting.v2.LiquidLegionsV2 import org.wfanet.measurement.internal.reporting.v2.Measurement as InternalMeasurement import org.wfanet.measurement.internal.reporting.v2.MeasurementKt as InternalMeasurementKt +import org.wfanet.measurement.internal.reporting.v2.Metric as InternalMetric import org.wfanet.measurement.internal.reporting.v2.MetricSpec as InternalMetricSpec import org.wfanet.measurement.internal.reporting.v2.MetricSpecKt as InternalMetricSpecKt import org.wfanet.measurement.internal.reporting.v2.NoiseMechanism as InternalNoiseMechanism @@ -86,6 +87,7 @@ import org.wfanet.measurement.reporting.v2alpha.CreateMetricRequest import org.wfanet.measurement.reporting.v2alpha.ListMetricsPageToken import org.wfanet.measurement.reporting.v2alpha.ListReportingSetsPageToken import org.wfanet.measurement.reporting.v2alpha.ListReportsPageToken +import org.wfanet.measurement.reporting.v2alpha.Metric import org.wfanet.measurement.reporting.v2alpha.MetricSpec import org.wfanet.measurement.reporting.v2alpha.MetricSpecKt import org.wfanet.measurement.reporting.v2alpha.Report @@ -369,6 +371,20 @@ fun InternalMetricSpec.WatchDurationParams.toDuration(): MeasurementSpec.Duratio } } +/** Converts an internal [InternalMetric.State] to a public [Metric.State]. */ +fun InternalMetric.State.toPublic(): Metric.State { + return when (this) { + InternalMetric.State.RUNNING -> Metric.State.RUNNING + InternalMetric.State.SUCCEEDED -> Metric.State.SUCCEEDED + InternalMetric.State.FAILED -> Metric.State.FAILED + InternalMetric.State.STATE_UNSPECIFIED -> Metric.State.STATE_UNSPECIFIED + InternalMetric.State.UNRECOGNIZED -> + // State is set by the system so if this is reached, something went wrong. + throw Status.UNKNOWN.withDescription("There is an unknown problem with the Metric") + .asRuntimeException() + } +} + /** Converts a CMM [Measurement.Failure] to an [InternalMeasurement.Failure]. */ fun Measurement.Failure.toInternal(): InternalMeasurement.Failure { val source = this diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/MetricsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/MetricsServiceTest.kt index 40c3919b9ec..8132f3b990b 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/MetricsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/MetricsServiceTest.kt @@ -34,8 +34,14 @@ import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.wfanet.measurement.common.identity.IdGenerator import org.wfanet.measurement.common.identity.RandomIdGenerator +import org.wfanet.measurement.internal.reporting.v2.BatchSetCmmsMeasurementIdsRequestKt +import org.wfanet.measurement.internal.reporting.v2.BatchSetMeasurementFailuresRequestKt +import org.wfanet.measurement.internal.reporting.v2.BatchSetMeasurementResultsRequestKt import org.wfanet.measurement.internal.reporting.v2.CreateMetricRequest import org.wfanet.measurement.internal.reporting.v2.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase +import org.wfanet.measurement.internal.reporting.v2.MeasurementKt +import org.wfanet.measurement.internal.reporting.v2.MeasurementsGrpcKt.MeasurementsCoroutineImplBase +import org.wfanet.measurement.internal.reporting.v2.Metric import org.wfanet.measurement.internal.reporting.v2.MetricKt import org.wfanet.measurement.internal.reporting.v2.MetricSpec import org.wfanet.measurement.internal.reporting.v2.MetricSpecKt @@ -46,6 +52,9 @@ import org.wfanet.measurement.internal.reporting.v2.ReportingSetsGrpcKt.Reportin import org.wfanet.measurement.internal.reporting.v2.StreamMetricsRequestKt import org.wfanet.measurement.internal.reporting.v2.batchCreateMetricsRequest import org.wfanet.measurement.internal.reporting.v2.batchGetMetricsRequest +import org.wfanet.measurement.internal.reporting.v2.batchSetCmmsMeasurementIdsRequest +import org.wfanet.measurement.internal.reporting.v2.batchSetMeasurementFailuresRequest +import org.wfanet.measurement.internal.reporting.v2.batchSetMeasurementResultsRequest import org.wfanet.measurement.internal.reporting.v2.copy import org.wfanet.measurement.internal.reporting.v2.createMetricRequest import org.wfanet.measurement.internal.reporting.v2.createReportingSetRequest @@ -67,6 +76,7 @@ abstract class MetricsServiceTest { val metricsService: T, val reportingSetsService: ReportingSetsCoroutineImplBase, val measurementConsumersService: MeasurementConsumersCoroutineImplBase, + val measurementsService: MeasurementsCoroutineImplBase, ) /** Instance of the service under test. */ @@ -74,6 +84,7 @@ abstract class MetricsServiceTest { private lateinit var reportingSetsService: ReportingSetsCoroutineImplBase private lateinit var measurementConsumersService: MeasurementConsumersCoroutineImplBase + private lateinit var measurementsService: MeasurementsCoroutineImplBase /** Constructs the services being tested. */ protected abstract fun newServices(idGenerator: IdGenerator): Services @@ -84,6 +95,7 @@ abstract class MetricsServiceTest { service = services.metricsService reportingSetsService = services.reportingSetsService measurementConsumersService = services.measurementConsumersService + measurementsService = services.measurementsService } @Test @@ -178,6 +190,7 @@ abstract class MetricsServiceTest { assertThat(createdMetric.externalMetricId).isNotEqualTo(0) assertThat(createdMetric.hasCreateTime()).isTrue() + assertThat(createdMetric.state).isEqualTo(Metric.State.RUNNING) createdMetric.weightedMeasurementsList.forEach { assertThat(it.measurement.cmmsCreateMeasurementRequestId).isNotEmpty() } @@ -281,6 +294,7 @@ abstract class MetricsServiceTest { assertThat(createdMetric.externalMetricId).isNotEqualTo(0) assertThat(createdMetric.hasCreateTime()).isTrue() + assertThat(createdMetric.state).isEqualTo(Metric.State.RUNNING) createdMetric.weightedMeasurementsList.forEach { assertThat(it.measurement.cmmsCreateMeasurementRequestId).isNotEmpty() } @@ -379,6 +393,7 @@ abstract class MetricsServiceTest { assertThat(createdMetric.externalMetricId).isNotEqualTo(0) assertThat(createdMetric.hasCreateTime()).isTrue() + assertThat(createdMetric.state).isEqualTo(Metric.State.RUNNING) createdMetric.weightedMeasurementsList.forEach { assertThat(it.measurement.cmmsCreateMeasurementRequestId).isNotEmpty() } @@ -477,6 +492,7 @@ abstract class MetricsServiceTest { assertThat(createdMetric.externalMetricId).isNotEqualTo(0) assertThat(createdMetric.hasCreateTime()).isTrue() + assertThat(createdMetric.state).isEqualTo(Metric.State.RUNNING) createdMetric.weightedMeasurementsList.forEach { assertThat(it.measurement.cmmsCreateMeasurementRequestId).isNotEmpty() } @@ -527,6 +543,7 @@ abstract class MetricsServiceTest { assertThat(createdMetric.externalMetricId).isNotEmpty() assertThat(createdMetric.hasCreateTime()).isTrue() + assertThat(createdMetric.state).isEqualTo(Metric.State.RUNNING) assertThat( createdMetric.weightedMeasurementsList.first().measurement.cmmsCreateMeasurementRequestId ) @@ -594,6 +611,7 @@ abstract class MetricsServiceTest { assertThat(createdMetric.externalMetricId).isNotEqualTo(0) assertThat(createdMetric.hasCreateTime()).isTrue() + assertThat(createdMetric.state).isEqualTo(Metric.State.RUNNING) createdMetric.weightedMeasurementsList.forEach { assertThat(it.measurement.cmmsCreateMeasurementRequestId).isNotEmpty() } @@ -697,6 +715,7 @@ abstract class MetricsServiceTest { ) assertThat(createdMetric.externalMetricId).isNotEqualTo(0L) + assertThat(createdMetric.state).isEqualTo(Metric.State.RUNNING) val sameCreatedMetric = service.createMetric( @@ -1359,6 +1378,7 @@ abstract class MetricsServiceTest { assertThat(batchCreateMetricsResponse.metricsList).hasSize(1) assertThat(createdMetric.externalMetricId).isNotEqualTo(0) assertThat(createdMetric.hasCreateTime()).isTrue() + assertThat(createdMetric.state).isEqualTo(Metric.State.RUNNING) createdMetric.weightedMeasurementsList.forEach { assertThat(it.measurement.cmmsCreateMeasurementRequestId).isNotEmpty() } @@ -1442,8 +1462,10 @@ abstract class MetricsServiceTest { assertThat(batchCreateMetricsResponse.metricsList).hasSize(2) assertThat(createdMetric.externalMetricId).isNotEqualTo(0) assertThat(createdMetric.hasCreateTime()).isTrue() + assertThat(createdMetric.state).isEqualTo(Metric.State.RUNNING) assertThat(createdMetric2.externalMetricId).isNotEqualTo(0) assertThat(createdMetric2.hasCreateTime()).isTrue() + assertThat(createdMetric2.state).isEqualTo(Metric.State.RUNNING) batchCreateMetricsResponse.metricsList.forEach { it.weightedMeasurementsList.forEach { weightedMeasurement -> assertThat(weightedMeasurement.measurement.cmmsCreateMeasurementRequestId).isNotEmpty() @@ -1544,6 +1566,7 @@ abstract class MetricsServiceTest { val createdMetric2 = batchCreateMetricsResponse.metricsList.last() assertThat(createdMetric2.externalMetricId).isNotEqualTo(0) assertThat(createdMetric2.hasCreateTime()).isTrue() + assertThat(createdMetric2.state).isEqualTo(Metric.State.RUNNING) createdMetric2.weightedMeasurementsList.forEach { assertThat(it.measurement.cmmsCreateMeasurementRequestId).isNotEmpty() } @@ -2505,6 +2528,276 @@ abstract class MetricsServiceTest { .containsExactly(createdMetric) } + @Test + fun `batchGetMetrics returns metric with SUCCEEDED state when all measurements SUCCEEDED`(): + Unit = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService).copy { + val source = this + metric = + source.metric.copy { + weightedMeasurements += + MetricKt.weightedMeasurement { + weight = 2 + binaryRepresentation = 1 + measurement = measurement { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + cmmsCreateMeasurementRequestId = "1234" + timeInterval = interval { + startTime = timestamp { seconds = 10 } + endTime = timestamp { seconds = 100 } + } + primitiveReportingSetBases += + ReportingSetKt.primitiveReportingSetBasis { + externalReportingSetId = source.metric.externalReportingSetId + } + } + } + weightedMeasurements += + MetricKt.weightedMeasurement { + weight = 3 + binaryRepresentation = 2 + measurement = measurement { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + cmmsCreateMeasurementRequestId = "1235" + timeInterval = interval { + startTime = timestamp { seconds = 10 } + endTime = timestamp { seconds = 100 } + } + primitiveReportingSetBases += + ReportingSetKt.primitiveReportingSetBasis { + externalReportingSetId = source.metric.externalReportingSetId + } + } + } + } + } + val createdMetric = service.createMetric(createMetricRequest) + + val suffix = "-1" + val batchSetCmmsMeasurementIdsRequest = batchSetCmmsMeasurementIdsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + createdMetric.weightedMeasurementsList.forEach { + measurementIds += + BatchSetCmmsMeasurementIdsRequestKt.measurementIds { + cmmsCreateMeasurementRequestId = it.measurement.cmmsCreateMeasurementRequestId + cmmsMeasurementId = it.measurement.cmmsCreateMeasurementRequestId + suffix + } + } + } + measurementsService.batchSetCmmsMeasurementIds(batchSetCmmsMeasurementIdsRequest) + + val batchSetMeasurementResultsRequest = batchSetMeasurementResultsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + createdMetric.weightedMeasurementsList.forEach { + measurementResults += + BatchSetMeasurementResultsRequestKt.measurementResult { + cmmsMeasurementId = it.measurement.cmmsCreateMeasurementRequestId + suffix + results += MeasurementKt.result { reach = MeasurementKt.ResultKt.reach { value = 2 } } + } + } + } + measurementsService.batchSetMeasurementResults(batchSetMeasurementResultsRequest) + + val retrievedMetrics = + service.batchGetMetrics( + batchGetMetricsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + externalMetricIds += createdMetric.externalMetricId + } + ) + + assertThat(retrievedMetrics.metricsList.first().state).isEqualTo(Metric.State.SUCCEEDED) + } + + @Test + fun `batchGetMetrics returns metric with FAILED state when measurement FAILED`(): Unit = + runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService).copy { + val source = this + metric = + source.metric.copy { + weightedMeasurements += + MetricKt.weightedMeasurement { + weight = 2 + binaryRepresentation = 1 + measurement = measurement { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + cmmsCreateMeasurementRequestId = "1234" + timeInterval = interval { + startTime = timestamp { seconds = 10 } + endTime = timestamp { seconds = 100 } + } + primitiveReportingSetBases += + ReportingSetKt.primitiveReportingSetBasis { + externalReportingSetId = source.metric.externalReportingSetId + } + } + } + weightedMeasurements += + MetricKt.weightedMeasurement { + weight = 3 + binaryRepresentation = 2 + measurement = measurement { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + cmmsCreateMeasurementRequestId = "1235" + timeInterval = interval { + startTime = timestamp { seconds = 10 } + endTime = timestamp { seconds = 100 } + } + primitiveReportingSetBases += + ReportingSetKt.primitiveReportingSetBasis { + externalReportingSetId = source.metric.externalReportingSetId + } + } + } + } + } + val createdMetric = service.createMetric(createMetricRequest) + + val suffix = "-1" + val batchSetCmmsMeasurementIdsRequest = batchSetCmmsMeasurementIdsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + createdMetric.weightedMeasurementsList.forEach { + measurementIds += + BatchSetCmmsMeasurementIdsRequestKt.measurementIds { + cmmsCreateMeasurementRequestId = it.measurement.cmmsCreateMeasurementRequestId + cmmsMeasurementId = it.measurement.cmmsCreateMeasurementRequestId + suffix + } + } + } + measurementsService.batchSetCmmsMeasurementIds(batchSetCmmsMeasurementIdsRequest) + + val batchSetMeasurementFailuresRequest = batchSetMeasurementFailuresRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + measurementFailures += + BatchSetMeasurementFailuresRequestKt.measurementFailure { + cmmsMeasurementId = + createdMetric.weightedMeasurementsList + .first() + .measurement + .cmmsCreateMeasurementRequestId + suffix + failure = MeasurementKt.failure { message = "failure" } + } + } + measurementsService.batchSetMeasurementFailures(batchSetMeasurementFailuresRequest) + + val retrievedMetrics = + service.batchGetMetrics( + batchGetMetricsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + externalMetricIds += createdMetric.externalMetricId + } + ) + + assertThat(retrievedMetrics.metricsList.first().state).isEqualTo(Metric.State.FAILED) + } + + @Test + fun `batchGetMetrics gets FAILED state when a measurement SUCCEEDED and other FAILED`(): Unit = + runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService).copy { + val source = this + metric = + source.metric.copy { + weightedMeasurements += + MetricKt.weightedMeasurement { + weight = 2 + binaryRepresentation = 1 + measurement = measurement { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + cmmsCreateMeasurementRequestId = "1234" + timeInterval = interval { + startTime = timestamp { seconds = 10 } + endTime = timestamp { seconds = 100 } + } + primitiveReportingSetBases += + ReportingSetKt.primitiveReportingSetBasis { + externalReportingSetId = source.metric.externalReportingSetId + } + } + } + weightedMeasurements += + MetricKt.weightedMeasurement { + weight = 3 + binaryRepresentation = 2 + measurement = measurement { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + cmmsCreateMeasurementRequestId = "1235" + timeInterval = interval { + startTime = timestamp { seconds = 10 } + endTime = timestamp { seconds = 100 } + } + primitiveReportingSetBases += + ReportingSetKt.primitiveReportingSetBasis { + externalReportingSetId = source.metric.externalReportingSetId + } + } + } + } + } + val createdMetric = service.createMetric(createMetricRequest) + + val suffix = "-1" + val batchSetCmmsMeasurementIdsRequest = batchSetCmmsMeasurementIdsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + createdMetric.weightedMeasurementsList.forEach { + measurementIds += + BatchSetCmmsMeasurementIdsRequestKt.measurementIds { + cmmsCreateMeasurementRequestId = it.measurement.cmmsCreateMeasurementRequestId + cmmsMeasurementId = it.measurement.cmmsCreateMeasurementRequestId + suffix + } + } + } + measurementsService.batchSetCmmsMeasurementIds(batchSetCmmsMeasurementIdsRequest) + + val batchSetMeasurementResultsRequest = batchSetMeasurementResultsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + measurementResults += + BatchSetMeasurementResultsRequestKt.measurementResult { + cmmsMeasurementId = + createdMetric.weightedMeasurementsList + .first() + .measurement + .cmmsCreateMeasurementRequestId + suffix + results += MeasurementKt.result { reach = MeasurementKt.ResultKt.reach { value = 2 } } + } + } + measurementsService.batchSetMeasurementResults(batchSetMeasurementResultsRequest) + + val batchSetMeasurementFailuresRequest = batchSetMeasurementFailuresRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + measurementFailures += + BatchSetMeasurementFailuresRequestKt.measurementFailure { + cmmsMeasurementId = + createdMetric.weightedMeasurementsList + .last() + .measurement + .cmmsCreateMeasurementRequestId + suffix + failure = MeasurementKt.failure { message = "failure" } + } + } + measurementsService.batchSetMeasurementFailures(batchSetMeasurementFailuresRequest) + + val retrievedMetrics = + service.batchGetMetrics( + batchGetMetricsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + externalMetricIds += createdMetric.externalMetricId + } + ) + + assertThat(retrievedMetrics.metricsList.first().state).isEqualTo(Metric.State.FAILED) + } + @Test fun `batchGetMetrics throws NOT_FOUND when not all metrics found`(): Unit = runBlocking { createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) diff --git a/src/main/proto/wfa/measurement/internal/reporting/v2/metric.proto b/src/main/proto/wfa/measurement/internal/reporting/v2/metric.proto index cdc2b1b5d1c..1dd42589e5c 100644 --- a/src/main/proto/wfa/measurement/internal/reporting/v2/metric.proto +++ b/src/main/proto/wfa/measurement/internal/reporting/v2/metric.proto @@ -89,4 +89,12 @@ message Metric { repeated string filters = 1; } Details details = 8; + + enum State { + STATE_UNSPECIFIED = 0; + RUNNING = 1; + SUCCEEDED = 2; + FAILED = 3; + } + State state = 9; } diff --git a/src/main/resources/reporting/postgres/add-state-column-to-metrics.sql b/src/main/resources/reporting/postgres/add-state-column-to-metrics.sql new file mode 100644 index 00000000000..4bd3df90132 --- /dev/null +++ b/src/main/resources/reporting/postgres/add-state-column-to-metrics.sql @@ -0,0 +1,40 @@ +-- liquibase formatted sql + +-- Copyright 2024 The Cross-Media Measurement Authors +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- Postgres database schema for the Reporting server. +-- +-- Table hierarchy: +-- Root +-- └── MeasurementConsumers +-- ├── EventGroups +-- ├── ReportingSets +-- │ ├── ReportingSetEventGroups +-- │ ├── PrimitiveReportingSetBases +-- │ │ └── PrimitiveReportingSetBasisFilters +-- │ ├── SetExpressions +-- │ └── WeightedSubsetUnions +-- │ └── WeightedSubsetUnionPrimitiveReportingSetBases +-- ├── Metrics +-- │ └── MetricMeasurements +-- ├── Measurements +-- │ └── MeasurementPrimitiveReportingSetBases +-- ├── MetricCalculationSpecs +-- └── Reports +-- ├── ReportTimeIntervals +-- └── MetricCalculationSpecReportingMetrics + +-- changeset tristanvuong2021:add-state-column-metrics-table dbms:postgresql +ALTER TABLE Metrics ADD COLUMN State integer NOT NULL DEFAULT 0; diff --git a/src/main/resources/reporting/postgres/changelog-v2.yaml b/src/main/resources/reporting/postgres/changelog-v2.yaml index d84b6d9149e..267645e3de6 100644 --- a/src/main/resources/reporting/postgres/changelog-v2.yaml +++ b/src/main/resources/reporting/postgres/changelog-v2.yaml @@ -40,3 +40,6 @@ databaseChangeLog: - include: file: drop-report-time-intervals-table.sql relativeToChangeLogFile: true +- include: + file: add-state-column-to-metrics.sql + relativeToChangeLogFile: true diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/PostgresMetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/PostgresMetricsServiceTest.kt index 0398dbfdfba..79f0c13f55d 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/PostgresMetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/PostgresMetricsServiceTest.kt @@ -33,6 +33,7 @@ class PostgresMetricsServiceTest : MetricsServiceTest() PostgresMetricsService(idGenerator, client), PostgresReportingSetsService(idGenerator, client), PostgresMeasurementConsumersService(idGenerator, client), + PostgresMeasurementsService(idGenerator, client), ) } diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt index 15474945d18..8e24ca9e780 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsServiceTest.kt @@ -162,6 +162,7 @@ import org.wfanet.measurement.internal.reporting.v2.Measurement as InternalMeasu import org.wfanet.measurement.internal.reporting.v2.MeasurementKt as InternalMeasurementKt import org.wfanet.measurement.internal.reporting.v2.MeasurementsGrpcKt as InternalMeasurementsGrpcKt import org.wfanet.measurement.internal.reporting.v2.MeasurementsGrpcKt.MeasurementsCoroutineImplBase as InternalMeasurementsCoroutineImplBase +import org.wfanet.measurement.internal.reporting.v2.Metric as InternalMetric import org.wfanet.measurement.internal.reporting.v2.MetricKt as InternalMetricKt import org.wfanet.measurement.internal.reporting.v2.MetricKt.weightedMeasurement import org.wfanet.measurement.internal.reporting.v2.MetricSpec as InternalMetricSpec @@ -1447,6 +1448,7 @@ private val INTERNAL_PENDING_INITIAL_INCREMENTAL_REACH_METRIC = INTERNAL_REQUESTING_INCREMENTAL_REACH_METRIC.copy { externalMetricId = "331L" createTime = Instant.now().toProtoTime() + state = InternalMetric.State.RUNNING weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -1480,6 +1482,7 @@ private val INTERNAL_PENDING_INCREMENTAL_REACH_METRIC = private val INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC = INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.copy { + state = InternalMetric.State.SUCCEEDED weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -1536,6 +1539,7 @@ private val INTERNAL_PENDING_INITIAL_SINGLE_PUBLISHER_REACH_FREQUENCY_METRIC = INTERNAL_REQUESTING_SINGLE_PUBLISHER_REACH_FREQUENCY_METRIC.copy { externalMetricId = "332L" createTime = Instant.now().toProtoTime() + state = InternalMetric.State.RUNNING weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -1559,6 +1563,7 @@ private val INTERNAL_PENDING_SINGLE_PUBLISHER_REACH_FREQUENCY_METRIC = private val INTERNAL_SUCCEEDED_SINGLE_PUBLISHER_REACH_FREQUENCY_METRIC = INTERNAL_PENDING_SINGLE_PUBLISHER_REACH_FREQUENCY_METRIC.copy { + state = InternalMetric.State.SUCCEEDED weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -1605,6 +1610,7 @@ private val INTERNAL_PENDING_INITIAL_SINGLE_PUBLISHER_IMPRESSION_METRIC = INTERNAL_REQUESTING_SINGLE_PUBLISHER_IMPRESSION_METRIC.copy { externalMetricId = "333L" createTime = Instant.now().toProtoTime() + state = InternalMetric.State.RUNNING weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -1626,6 +1632,7 @@ private val INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC = private val INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC = INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.copy { + state = InternalMetric.State.FAILED weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -1636,6 +1643,7 @@ private val INTERNAL_FAILED_SINGLE_PUBLISHER_IMPRESSION_METRIC = private val INTERNAL_SUCCEEDED_SINGLE_PUBLISHER_IMPRESSION_METRIC = INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.copy { + state = InternalMetric.State.SUCCEEDED weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -1646,6 +1654,7 @@ private val INTERNAL_SUCCEEDED_SINGLE_PUBLISHER_IMPRESSION_METRIC = private val INTERNAL_SUCCEEDED_SINGLE_PUBLISHER_IMPRESSION_METRIC_CUSTOM_CAP = INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.copy { + state = InternalMetric.State.SUCCEEDED weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -1687,6 +1696,7 @@ private val INTERNAL_PENDING_INITIAL_CROSS_PUBLISHER_WATCH_DURATION_METRIC = INTERNAL_REQUESTING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.copy { externalMetricId = "334L" createTime = Instant.now().toProtoTime() + state = InternalMetric.State.RUNNING weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -1707,6 +1717,7 @@ private val INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC = private val INTERNAL_SUCCEEDED_CROSS_PUBLISHER_WATCH_DURATION_METRIC = INTERNAL_PENDING_CROSS_PUBLISHER_WATCH_DURATION_METRIC.copy { + state = InternalMetric.State.SUCCEEDED weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -1741,6 +1752,7 @@ private val INTERNAL_PENDING_INITIAL_POPULATION_METRIC = INTERNAL_REQUESTING_POPULATION_METRIC.copy { externalMetricId = "331L" createTime = Instant.now().toProtoTime() + state = InternalMetric.State.RUNNING weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { @@ -1762,6 +1774,7 @@ val INTERNAL_PENDING_POPULATION_METRIC = val INTERNAL_SUCCEEDED_POPULATION_METRIC = INTERNAL_PENDING_POPULATION_METRIC.copy { + state = InternalMetric.State.SUCCEEDED weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -4737,6 +4750,7 @@ class MetricsServiceTest { internalBatchGetMetricsResponse { metrics += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC.copy { + state = InternalMetric.State.SUCCEEDED weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -4916,6 +4930,7 @@ class MetricsServiceTest { metrics += INTERNAL_PENDING_INCREMENTAL_REACH_METRIC metrics += INTERNAL_PENDING_SINGLE_PUBLISHER_IMPRESSION_METRIC.copy { + state = InternalMetric.State.FAILED weightedMeasurements.clear() weightedMeasurements += weightedMeasurement { weight = 1 @@ -5260,6 +5275,57 @@ class MetricsServiceTest { assertThat(exception).hasMessageThat().contains(AGGREGATOR_CERTIFICATE.name) } + @Test + fun `getMetric returns the metric with SUCCEEDED when the metric has state STATE_UNSPECIFIED`() = + runBlocking { + whenever(internalMetricsMock.batchGetMetrics(any())) + .thenReturn( + internalBatchGetMetricsResponse { + metrics += INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC.copy { clearState() } + } + ) + + val request = getMetricRequest { name = SUCCEEDED_INCREMENTAL_REACH_METRIC.name } + + val result = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMERS.values.first().name, CONFIG) { + runBlocking { service.getMetric(request) } + } + + // Verify proto argument of internal MetricsCoroutineImplBase::batchGetMetrics + val batchGetInternalMetricsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMetricsMock, times(1)) { + batchGetMetrics(batchGetInternalMetricsCaptor.capture()) + } + val capturedInternalGetMetricRequests = batchGetInternalMetricsCaptor.allValues + assertThat(capturedInternalGetMetricRequests) + .containsExactly( + internalBatchGetMetricsRequest { + cmmsMeasurementConsumerId = + INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC.cmmsMeasurementConsumerId + externalMetricIds += INTERNAL_SUCCEEDED_INCREMENTAL_REACH_METRIC.externalMetricId + } + ) + + // Verify proto argument of internal MeasurementsCoroutineImplBase::batchSetMeasurementResults + val batchSetMeasurementResultsCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementResults(batchSetMeasurementResultsCaptor.capture()) + } + + // Verify proto argument of internal + // MeasurementsCoroutineImplBase::batchSetMeasurementFailures + val batchSetMeasurementFailuresCaptor: KArgumentCaptor = + argumentCaptor() + verifyBlocking(internalMeasurementsMock, never()) { + batchSetMeasurementFailures(batchSetMeasurementFailuresCaptor.capture()) + } + + assertThat(result).isEqualTo(SUCCEEDED_INCREMENTAL_REACH_METRIC) + } + @Test fun `getMetric returns the metric with SUCCEEDED when the metric is already succeeded`() = runBlocking {