diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/PostgresMetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/PostgresMetricsService.kt index b3d838383a1..6fd48040f32 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/PostgresMetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/PostgresMetricsService.kt @@ -17,8 +17,15 @@ package org.wfanet.measurement.reporting.deploy.v2.postgres import io.grpc.Status +import java.lang.IllegalStateException import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.emitAll +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.toList import org.wfanet.measurement.common.db.r2dbc.DatabaseClient +import org.wfanet.measurement.common.db.r2dbc.postgres.SerializableErrors.withSerializableErrorRetries +import org.wfanet.measurement.common.grpc.failGrpc import org.wfanet.measurement.common.grpc.grpcRequire import org.wfanet.measurement.common.identity.IdGenerator import org.wfanet.measurement.internal.reporting.v2.BatchCreateMetricsRequest @@ -31,11 +38,13 @@ import org.wfanet.measurement.internal.reporting.v2.MetricSpec import org.wfanet.measurement.internal.reporting.v2.MetricsGrpcKt.MetricsCoroutineImplBase import org.wfanet.measurement.internal.reporting.v2.StreamMetricsRequest import org.wfanet.measurement.internal.reporting.v2.batchCreateMetricsResponse +import org.wfanet.measurement.internal.reporting.v2.batchGetMetricsResponse +import org.wfanet.measurement.reporting.deploy.v2.postgres.readers.MetricReader import org.wfanet.measurement.reporting.deploy.v2.postgres.writers.CreateMetrics import org.wfanet.measurement.reporting.service.internal.MeasurementConsumerNotFoundException import org.wfanet.measurement.reporting.service.internal.ReportingSetNotFoundException -private const val MAX_BATCH_CREATE_SIZE = 200 +private const val MAX_BATCH_SIZE = 1000 class PostgresMetricsService( private val idGenerator: IdGenerator, @@ -70,7 +79,7 @@ class PostgresMetricsService( override suspend fun batchCreateMetrics( request: BatchCreateMetricsRequest ): BatchCreateMetricsResponse { - grpcRequire(request.requestsList.size <= MAX_BATCH_CREATE_SIZE) { "Too many requests." } + grpcRequire(request.requestsList.size <= MAX_BATCH_SIZE) { "Too many requests." } request.requestsList.forEach { grpcRequire(it.metric.hasTimeInterval()) { "Metric missing time interval." } @@ -106,10 +115,52 @@ class PostgresMetricsService( } override suspend fun batchGetMetrics(request: BatchGetMetricsRequest): BatchGetMetricsResponse { - return super.batchGetMetrics(request) + grpcRequire(request.cmmsMeasurementConsumerId.isNotBlank()) { + "CmmsMeasurementConsumerId is missing." + } + + grpcRequire(request.externalMetricIdsList.size <= MAX_BATCH_SIZE) { "Too many requests." } + + val readContext = client.readTransaction() + val metrics = + try { + MetricReader(readContext) + .batchGetMetrics(request) + .map { it.metric } + .withSerializableErrorRetries() + .toList() + } catch (e: IllegalStateException) { + failGrpc(Status.NOT_FOUND) { "Metric is not found" } + } finally { + readContext.close() + } + + if (metrics.size < request.externalMetricIdsList.size) { + failGrpc(Status.NOT_FOUND) { "Metric is not found" } + } + + return batchGetMetricsResponse { this.metrics += metrics } } override fun streamMetrics(request: StreamMetricsRequest): Flow { - return super.streamMetrics(request) + grpcRequire(request.filter.cmmsMeasurementConsumerId.isNotBlank()) { + "Filter is missing CmmsMeasurementConsumerId" + } + + return flow { + val readContext = client.readTransaction() + try { + emitAll( + MetricReader(readContext) + .readMetrics(request) + .map { it.metric } + .withSerializableErrorRetries() + ) + } catch (e: IllegalStateException) { + failGrpc(Status.NOT_FOUND) { "Metric is not found" } + } finally { + readContext.close() + } + } } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/MetricReader.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/MetricReader.kt index e2543510b1d..a3c8e532d81 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/MetricReader.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/MetricReader.kt @@ -16,11 +16,32 @@ package org.wfanet.measurement.reporting.deploy.v2.postgres.readers +import com.google.protobuf.Timestamp +import java.time.Instant +import java.util.UUID import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.emptyFlow +import kotlinx.coroutines.flow.flow +import org.wfanet.measurement.common.db.r2dbc.BoundStatement import org.wfanet.measurement.common.db.r2dbc.ReadContext +import org.wfanet.measurement.common.db.r2dbc.ResultRow +import org.wfanet.measurement.common.db.r2dbc.boundStatement +import org.wfanet.measurement.common.identity.ExternalId import org.wfanet.measurement.common.identity.InternalId +import org.wfanet.measurement.common.toProtoTime +import org.wfanet.measurement.internal.reporting.v2.BatchGetMetricsRequest +import org.wfanet.measurement.internal.reporting.v2.Measurement import org.wfanet.measurement.internal.reporting.v2.Metric +import org.wfanet.measurement.internal.reporting.v2.MetricKt +import org.wfanet.measurement.internal.reporting.v2.MetricSpec +import org.wfanet.measurement.internal.reporting.v2.MetricSpecKt +import org.wfanet.measurement.internal.reporting.v2.ReportingSetKt +import org.wfanet.measurement.internal.reporting.v2.StreamMetricsRequest +import org.wfanet.measurement.internal.reporting.v2.TimeInterval +import org.wfanet.measurement.internal.reporting.v2.measurement +import org.wfanet.measurement.internal.reporting.v2.metric +import org.wfanet.measurement.internal.reporting.v2.metricSpec +import org.wfanet.measurement.internal.reporting.v2.timeInterval class MetricReader(private val readContext: ReadContext) { data class Result( @@ -30,7 +51,98 @@ class MetricReader(private val readContext: ReadContext) { val metric: Metric ) - suspend fun readMetricsByRequestId( + private data class MetricInfo( + val measurementConsumerId: InternalId, + val cmmsMeasurementConsumerId: String, + val createMetricRequestId: String?, + val externalReportingSetId: ExternalId, + val metricId: InternalId, + val externalMetricId: ExternalId, + val createTime: Timestamp, + val timeInterval: TimeInterval, + val metricSpec: MetricSpec, + val weightedMeasurementInfoMap: MutableMap, + val details: Metric.Details, + ) + + private data class MetricMeasurementKey( + val measurementConsumerId: InternalId, + val metricId: InternalId, + val measurementId: InternalId, + ) + + private data class WeightedMeasurementInfo( + val weight: Int, + val measurementInfo: MeasurementInfo, + ) + + private data class MeasurementInfo( + val cmmsMeasurementId: String?, + val cmmsCreateMeasurementRequestId: String, + val timeInterval: TimeInterval, + // Key is primitiveReportingSetBasisId. + val primitiveReportingSetBasisInfoMap: MutableMap, + val state: Measurement.State, + val details: Measurement.Details, + ) + + private data class PrimitiveReportingSetBasisInfo( + val externalReportingSetId: ExternalId, + val filterSet: MutableSet, + ) + + private val baseSqlSelect: String = + """ + SELECT + CmmsMeasurementConsumerId, + Metrics.MeasurementConsumerId, + Metrics.CreateMetricRequestId, + ReportingSets.ExternalReportingSetId AS MetricsExternalReportingSetId, + Metrics.MetricId, + Metrics.ExternalMetricId, + Metrics.TimeIntervalStart AS MetricsTimeIntervalStart, + Metrics.TimeIntervalEndExclusive AS MetricsTimeIntervalEndExclusive, + Metrics.MetricType, + Metrics.DifferentialPrivacyEpsilon, + Metrics.DifferentialPrivacyDelta, + Metrics.FrequencyDifferentialPrivacyEpsilon, + Metrics.FrequencyDifferentialPrivacyDelta, + Metrics.MaximumFrequencyPerUser, + Metrics.MaximumWatchDurationPerUser, + Metrics.VidSamplingIntervalStart, + Metrics.VidSamplingIntervalWidth, + Metrics.CreateTime, + Metrics.MetricDetails, + MetricMeasurements.Coefficient, + Measurements.MeasurementId, + Measurements.CmmsCreateMeasurementRequestId, + Measurements.CmmsMeasurementId, + Measurements.TimeIntervalStart AS MeasurementsTimeIntervalStart, + Measurements.TimeIntervalEndExclusive AS MeasurementsTimeIntervalEndExclusive, + Measurements.State, + Measurements.MeasurementDetails, + PrimitiveReportingSetBases.PrimitiveReportingSetBasisId, + PrimitiveReportingSets.ExternalReportingSetId AS PrimitiveExternalReportingSetId, + PrimitiveReportingSetBasisFilters.Filter AS PrimitiveReportingSetBasisFilter + """ + + private val baseSqlJoins: String = + """ + JOIN ReportingSets USING(MeasurementConsumerId, ReportingSetId) + JOIN MetricMeasurements USING(MeasurementConsumerId, MetricId) + JOIN Measurements USING(MeasurementConsumerId, MeasurementId) + JOIN MeasurementPrimitiveReportingSetBases USING(MeasurementConsumerId, MeasurementId) + JOIN PrimitiveReportingSetBases USING(MeasurementConsumerId, PrimitiveReportingSetBasisId) + JOIN ReportingSets AS PrimitiveReportingSets + ON PrimitiveReportingSetBases.MeasurementConsumerId = PrimitiveReportingSets.MeasurementConsumerId + AND PrimitiveReportingSetBases.PrimitiveReportingSetId = PrimitiveReportingSets.ReportingSetId + LEFT JOIN PrimitiveReportingSetBasisFilters + ON PrimitiveReportingSetBases.MeasurementConsumerId = PrimitiveReportingSetBasisFilters.MeasurementConsumerId + AND PrimitiveReportingSetBases.PrimitiveReportingSetBasisId = PrimitiveReportingSetBasisFilters.PrimitiveReportingSetBasisId + """ + .trimIndent() + + fun readMetricsByRequestId( measurementConsumerId: InternalId, createMetricRequestIds: Collection ): Flow { @@ -38,7 +150,352 @@ class MetricReader(private val readContext: ReadContext) { return emptyFlow() } - // TODO(tristanvuong2021): implement read metric - return emptyFlow() + val sql = + StringBuilder( + baseSqlSelect + + """ + FROM MeasurementConsumers + JOIN Metrics USING(MeasurementConsumerId) + """ + + baseSqlJoins + + """ + WHERE Metrics.MeasurementConsumerId = $1 + AND CreateMetricRequestId IN + """ + ) + + var i = 2 + val bindingMap = mutableMapOf() + val inList = + createMetricRequestIds.joinToString(separator = ",", prefix = "(", postfix = ")") { + val index = "$$i" + bindingMap[it] = index + i++ + index + } + sql.append(inList) + + val statement = + boundStatement(sql.toString()) { + bind("$1", measurementConsumerId) + createMetricRequestIds.forEach { bind(bindingMap.getValue(it), it) } + } + + return createResultFlow(statement) + } + + fun batchGetMetrics( + request: BatchGetMetricsRequest, + ): Flow { + val sql = + StringBuilder( + baseSqlSelect + + """ + FROM MeasurementConsumers + JOIN Metrics USING(MeasurementConsumerId) + """ + + baseSqlJoins + + """ + WHERE CmmsMeasurementConsumerId = $1 + AND ExternalMetricId IN + """ + ) + + var i = 2 + val bindingMap = mutableMapOf() + val inList = + request.externalMetricIdsList.joinToString(separator = ",", prefix = "(", postfix = ")") { + val index = "$$i" + bindingMap[it] = index + i++ + index + } + sql.append(inList) + + val statement = + boundStatement(sql.toString()) { + bind("$1", request.cmmsMeasurementConsumerId) + request.externalMetricIdsList.forEach { bind(bindingMap.getValue(it), it) } + } + + return createResultFlow(statement) + } + + fun readMetrics( + request: StreamMetricsRequest, + ): Flow { + val statement = + boundStatement( + baseSqlSelect + + """ + FROM ( + SELECT * + FROM MeasurementConsumers + JOIN Metrics USING (MeasurementConsumerId) + WHERE CmmsMeasurementConsumerId = $1 + AND ExternalMetricId > $2 + ORDER BY ExternalMetricId ASC + LIMIT $3 + ) AS Metrics + """ + + baseSqlJoins + + """ + ORDER BY ExternalMetricId ASC + """ + ) { + bind("$1", request.filter.cmmsMeasurementConsumerId) + bind("$2", request.filter.externalMetricIdAfter) + if (request.limit > 0) { + bind("$3", request.limit) + } else { + bind("$3", 50) + } + } + + return createResultFlow(statement) + } + + private fun createResultFlow(statement: BoundStatement): Flow { + return flow { + val metricInfoMap = buildResultMap(statement) + + for (entry in metricInfoMap) { + val metricId = entry.key + val metricInfo = entry.value + + val metric = metric { + cmmsMeasurementConsumerId = metricInfo.cmmsMeasurementConsumerId + externalMetricId = metricInfo.externalMetricId.value + externalReportingSetId = metricInfo.externalReportingSetId.value + createTime = metricInfo.createTime + timeInterval = metricInfo.timeInterval + metricSpec = metricInfo.metricSpec + metricInfo.weightedMeasurementInfoMap.values.forEach { + weightedMeasurements += + MetricKt.weightedMeasurement { + weight = it.weight + measurement = measurement { + cmmsMeasurementConsumerId = metricInfo.cmmsMeasurementConsumerId + if (it.measurementInfo.cmmsMeasurementId != null) { + cmmsMeasurementId = it.measurementInfo.cmmsMeasurementId + } + cmmsCreateMeasurementRequestId = it.measurementInfo.cmmsCreateMeasurementRequestId + timeInterval = it.measurementInfo.timeInterval + it.measurementInfo.primitiveReportingSetBasisInfoMap.values.forEach { + primitiveReportingSetBases += + ReportingSetKt.primitiveReportingSetBasis { + externalReportingSetId = it.externalReportingSetId.value + filters += it.filterSet + } + } + state = it.measurementInfo.state + if (it.measurementInfo.details != Measurement.Details.getDefaultInstance()) { + details = it.measurementInfo.details + } + } + } + } + if (metricInfo.details != Metric.Details.getDefaultInstance()) { + details = metricInfo.details + } + } + + val createMetricRequestId = metricInfo.createMetricRequestId ?: "" + emit( + Result( + measurementConsumerId = metricInfo.measurementConsumerId, + metricId = metricId, + createMetricRequestId = createMetricRequestId, + metric = metric + ) + ) + } + } + } + + /** Returns a map that maintains the order of the query result. */ + private suspend fun buildResultMap(statement: BoundStatement): Map { + // Key is metricId. + val metricInfoMap: MutableMap = linkedMapOf() + + val translate: (row: ResultRow) -> Unit = { row: ResultRow -> + val measurementConsumerId: InternalId = row["MeasurementConsumerId"] + val cmmsMeasurementConsumerId: String = row["CmmsMeasurementConsumerId"] + val createMetricRequestId: String? = row["CreateMetricRequestId"] + val externalReportingSetId: ExternalId = row["MetricsExternalReportingSetId"] + val metricId: InternalId = row["MetricId"] + val externalMetricId: ExternalId = row["ExternalMetricId"] + val metricTimeIntervalStart: Instant = row["MetricsTimeIntervalStart"] + val metricTimeIntervalEnd: Instant = row["MetricsTimeIntervalEndExclusive"] + val metricType: MetricSpec.TypeCase = MetricSpec.TypeCase.forNumber(row["MetricType"]) + val differentialPrivacyEpsilon: Double = row["DifferentialPrivacyEpsilon"] + val differentialPrivacyDelta: Double = row["DifferentialPrivacyDelta"] + val frequencyDifferentialPrivacyEpsilon: Double? = row["FrequencyDifferentialPrivacyEpsilon"] + val frequencyDifferentialPrivacyDelta: Double? = row["FrequencyDifferentialPrivacyDelta"] + val maximumFrequencyPerUser: Int? = row["MaximumFrequencyPerUser"] + val maximumWatchDurationPerUser: Int? = row["MaximumWatchDurationPerUser"] + val vidSamplingStart: Float = row["VidSamplingIntervalStart"] + val vidSamplingWidth: Float = row["VidSamplingIntervalWidth"] + val createTime: Instant = row["CreateTime"] + val metricDetails: Metric.Details = + row.getProtoMessage("MetricDetails", Metric.Details.parser()) + val weight: Int = row["Coefficient"] + val measurementId: InternalId = row["MeasurementId"] + val cmmsCreateMeasurementRequestId: UUID = row["CmmsCreateMeasurementRequestId"] + val cmmsMeasurementId: String? = row["CmmsMeasurementId"] + val measurementTimeIntervalStart: Instant = row["MeasurementsTimeIntervalStart"] + val measurementTimeIntervalEnd: Instant = row["MeasurementsTimeIntervalEndExclusive"] + val measurementState: Measurement.State = Measurement.State.forNumber(row["State"]) + val measurementDetails: Measurement.Details = + row.getProtoMessage("MeasurementDetails", Measurement.Details.parser()) + val primitiveReportingSetBasisId: InternalId = row["PrimitiveReportingSetBasisId"] + val primitiveExternalReportingSetId: ExternalId = row["PrimitiveExternalReportingSetId"] + val primitiveReportingSetBasisFilter: String? = row["PrimitiveReportingSetBasisFilter"] + + val metricInfo = + metricInfoMap.computeIfAbsent(metricId) { + val metricTimeInterval = timeInterval { + startTime = metricTimeIntervalStart.toProtoTime() + endTime = metricTimeIntervalEnd.toProtoTime() + } + + val vidSamplingInterval = + MetricSpecKt.vidSamplingInterval { + start = vidSamplingStart + width = vidSamplingWidth + } + + val metricSpec = metricSpec { + when (metricType) { + MetricSpec.TypeCase.REACH -> + reach = + MetricSpecKt.reachParams { + privacyParams = + MetricSpecKt.differentialPrivacyParams { + epsilon = differentialPrivacyEpsilon + delta = differentialPrivacyDelta + } + } + MetricSpec.TypeCase.FREQUENCY_HISTOGRAM -> { + if ( + frequencyDifferentialPrivacyDelta == null || + frequencyDifferentialPrivacyEpsilon == null || + maximumFrequencyPerUser == null + ) { + throw IllegalStateException() + } + + frequencyHistogram = + MetricSpecKt.frequencyHistogramParams { + reachPrivacyParams = + MetricSpecKt.differentialPrivacyParams { + epsilon = differentialPrivacyEpsilon + delta = differentialPrivacyDelta + } + frequencyPrivacyParams = + MetricSpecKt.differentialPrivacyParams { + epsilon = frequencyDifferentialPrivacyEpsilon + delta = frequencyDifferentialPrivacyDelta + } + this.maximumFrequencyPerUser = maximumFrequencyPerUser + } + } + MetricSpec.TypeCase.IMPRESSION_COUNT -> { + if (maximumFrequencyPerUser == null) { + throw IllegalStateException() + } + + impressionCount = + MetricSpecKt.impressionCountParams { + privacyParams = + MetricSpecKt.differentialPrivacyParams { + epsilon = differentialPrivacyEpsilon + delta = differentialPrivacyDelta + } + this.maximumFrequencyPerUser = maximumFrequencyPerUser + } + } + MetricSpec.TypeCase.WATCH_DURATION -> { + if (maximumWatchDurationPerUser == null) { + throw IllegalStateException() + } + + watchDuration = + MetricSpecKt.watchDurationParams { + privacyParams = + MetricSpecKt.differentialPrivacyParams { + epsilon = differentialPrivacyEpsilon + delta = differentialPrivacyDelta + } + this.maximumWatchDurationPerUser = maximumWatchDurationPerUser + } + } + MetricSpec.TypeCase.TYPE_NOT_SET -> throw IllegalStateException() + } + this.vidSamplingInterval = vidSamplingInterval + } + + MetricInfo( + measurementConsumerId = measurementConsumerId, + cmmsMeasurementConsumerId = cmmsMeasurementConsumerId, + createMetricRequestId = createMetricRequestId, + externalReportingSetId = externalReportingSetId, + metricId = metricId, + externalMetricId = externalMetricId, + createTime = createTime.toProtoTime(), + timeInterval = metricTimeInterval, + metricSpec = metricSpec, + details = metricDetails, + weightedMeasurementInfoMap = mutableMapOf() + ) + } + + val weightedMeasurementInfo = + metricInfo.weightedMeasurementInfoMap.computeIfAbsent( + MetricMeasurementKey( + measurementConsumerId = measurementConsumerId, + measurementId = measurementId, + metricId = metricId, + ) + ) { + val timeInterval = timeInterval { + startTime = measurementTimeIntervalStart.toProtoTime() + endTime = measurementTimeIntervalEnd.toProtoTime() + } + + val measurementInfo = + MeasurementInfo( + cmmsMeasurementId = cmmsMeasurementId, + cmmsCreateMeasurementRequestId = cmmsCreateMeasurementRequestId.toString(), + timeInterval = timeInterval, + state = measurementState, + details = measurementDetails, + primitiveReportingSetBasisInfoMap = mutableMapOf(), + ) + + WeightedMeasurementInfo( + weight = weight, + measurementInfo = measurementInfo, + ) + } + + val primitiveReportingSetBasisInfo = + weightedMeasurementInfo.measurementInfo.primitiveReportingSetBasisInfoMap.computeIfAbsent( + primitiveReportingSetBasisId + ) { + PrimitiveReportingSetBasisInfo( + externalReportingSetId = primitiveExternalReportingSetId, + filterSet = mutableSetOf() + ) + } + + if (primitiveReportingSetBasisFilter != null) { + primitiveReportingSetBasisInfo.filterSet.add(primitiveReportingSetBasisFilter) + } + } + + readContext.executeQuery(statement).consume(translate).collect {} + + return metricInfoMap } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/ReportingSetReader.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/ReportingSetReader.kt index 26ec37e6460..464aa000c44 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/ReportingSetReader.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/readers/ReportingSetReader.kt @@ -187,7 +187,10 @@ class ReportingSetReader(private val readContext: ReadContext) { LIMIT $3 ) AS ReportingSets """ + - baseSqlJoins + baseSqlJoins + + """ + ORDER BY RootExternalReportingSetId ASC + """ ) { bind("$1", request.filter.cmmsMeasurementConsumerId) bind("$2", request.filter.externalReportingSetIdAfter) @@ -277,9 +280,10 @@ class ReportingSetReader(private val readContext: ReadContext) { } } + /** Returns a map that maintains the order of the query result. */ private suspend fun buildResultMap(statement: BoundStatement): Map { // Key is reportingSetId. - val reportingSetInfoMap: MutableMap = mutableMapOf() + val reportingSetInfoMap: MutableMap = linkedMapOf() val translate: Translate = { row: ResultRow -> val measurementConsumerId: InternalId = row["ReportingSetsMeasurementConsumerId"] diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateMetric.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateMetric.kt index 686518d8115..594d59999b0 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateMetric.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateMetric.kt @@ -326,15 +326,17 @@ class CreateMetrics(private val requests: List) : measurementPrimitiveReportingSetBasesBinders.forEach { addBinding(it) } } - transactionContext.run { - executeStatement(statement) - executeStatement(measurementsStatement) - executeStatement(metricMeasurementsStatement) - executeStatement(primitiveReportingSetBasesStatement) - if (primitiveReportingSetBasisFiltersBinders.size > 0) { - executeStatement(primitiveReportingSetBasisFiltersStatement) + if (existingMetricsMap.size < requests.size) { + transactionContext.run { + executeStatement(statement) + executeStatement(measurementsStatement) + executeStatement(metricMeasurementsStatement) + executeStatement(primitiveReportingSetBasesStatement) + if (primitiveReportingSetBasisFiltersBinders.size > 0) { + executeStatement(primitiveReportingSetBasisFiltersStatement) + } + executeStatement(measurementPrimitiveReportingSetBasesStatement) } - executeStatement(measurementPrimitiveReportingSetBasesStatement) } return metrics diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/MetricsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/MetricsServiceTest.kt index 93338bdeb21..b399014f17e 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/MetricsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/MetricsServiceTest.kt @@ -17,27 +17,32 @@ package org.wfanet.measurement.reporting.service.internal.testing.v2 import com.google.common.truth.Truth.assertThat +import com.google.common.truth.extensions.proto.ProtoTruth.assertThat import com.google.protobuf.timestamp import io.grpc.Status import io.grpc.StatusRuntimeException import java.time.Clock import kotlin.random.Random import kotlin.test.assertFailsWith +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.Before -import org.junit.Ignore import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.wfanet.measurement.common.identity.IdGenerator import org.wfanet.measurement.common.identity.RandomIdGenerator +import org.wfanet.measurement.internal.reporting.v2.CreateMetricRequest import org.wfanet.measurement.internal.reporting.v2.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineImplBase import org.wfanet.measurement.internal.reporting.v2.MetricKt import org.wfanet.measurement.internal.reporting.v2.MetricSpecKt import org.wfanet.measurement.internal.reporting.v2.MetricsGrpcKt.MetricsCoroutineImplBase +import org.wfanet.measurement.internal.reporting.v2.ReportingSet import org.wfanet.measurement.internal.reporting.v2.ReportingSetKt import org.wfanet.measurement.internal.reporting.v2.ReportingSetsGrpcKt.ReportingSetsCoroutineImplBase +import org.wfanet.measurement.internal.reporting.v2.StreamMetricsRequestKt import org.wfanet.measurement.internal.reporting.v2.batchCreateMetricsRequest +import org.wfanet.measurement.internal.reporting.v2.batchGetMetricsRequest import org.wfanet.measurement.internal.reporting.v2.copy import org.wfanet.measurement.internal.reporting.v2.createMetricRequest import org.wfanet.measurement.internal.reporting.v2.measurement @@ -45,10 +50,11 @@ import org.wfanet.measurement.internal.reporting.v2.measurementConsumer import org.wfanet.measurement.internal.reporting.v2.metric import org.wfanet.measurement.internal.reporting.v2.metricSpec import org.wfanet.measurement.internal.reporting.v2.reportingSet +import org.wfanet.measurement.internal.reporting.v2.streamMetricsRequest import org.wfanet.measurement.internal.reporting.v2.timeInterval private const val CMMS_MEASUREMENT_CONSUMER_ID = "1234" -private const val MAX_BATCH_CREATE_SIZE = 200 +private const val MAX_BATCH_SIZE = 1000 @RunWith(JUnit4::class) abstract class MetricsServiceTest { @@ -79,24 +85,8 @@ abstract class MetricsServiceTest { @Test fun `createMetric succeeds when MetricSpec type is Reach`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -147,7 +137,7 @@ abstract class MetricsServiceTest { MetricKt.weightedMeasurement { weight = 3 measurement = measurement { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + "2" + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID timeInterval = timeInterval { startTime = timestamp { seconds = 10 } endTime = timestamp { seconds = 100 } @@ -184,24 +174,8 @@ abstract class MetricsServiceTest { @Test fun `createMetric succeeds when MetricSpec type is FrequencyHistogram`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -258,7 +232,7 @@ abstract class MetricsServiceTest { MetricKt.weightedMeasurement { weight = 3 measurement = measurement { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + "2" + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID timeInterval = timeInterval { startTime = timestamp { seconds = 10 } endTime = timestamp { seconds = 100 } @@ -295,24 +269,8 @@ abstract class MetricsServiceTest { @Test fun `createMetric succeeds when MetricSpec type is ImpressionCount`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -364,7 +322,7 @@ abstract class MetricsServiceTest { MetricKt.weightedMeasurement { weight = 3 measurement = measurement { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + "2" + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID timeInterval = timeInterval { startTime = timestamp { seconds = 10 } endTime = timestamp { seconds = 100 } @@ -401,24 +359,8 @@ abstract class MetricsServiceTest { @Test fun `createMetric succeeds when MetricSpec type is WatchDuration`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -470,7 +412,7 @@ abstract class MetricsServiceTest { MetricKt.weightedMeasurement { weight = 3 measurement = measurement { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + "2" + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID timeInterval = timeInterval { startTime = timestamp { seconds = 10 } endTime = timestamp { seconds = 100 } @@ -506,25 +448,9 @@ abstract class MetricsServiceTest { } @Test - fun `createMetric succeeds when no filters in measurements`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + fun `createMetric succeeds when no filters in bases in measurements`() = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -580,28 +506,10 @@ abstract class MetricsServiceTest { } } - /** TODO(tristanvuong2021): implement read methods for metric */ - @Ignore @Test fun `createMetric returns the same metric when using an existing request id`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -658,7 +566,7 @@ abstract class MetricsServiceTest { MetricKt.weightedMeasurement { weight = 3 measurement = measurement { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + "2" + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID timeInterval = timeInterval { startTime = timestamp { seconds = 10 } endTime = timestamp { seconds = 100 } @@ -702,29 +610,13 @@ abstract class MetricsServiceTest { } ) - assertThat(createdMetric).isEqualTo(sameCreatedMetric) + assertThat(createdMetric).ignoringRepeatedFieldOrder().isEqualTo(sameCreatedMetric) } @Test fun `createMetric throws NOT_FOUND when ReportingSet in basis not found`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -789,24 +681,8 @@ abstract class MetricsServiceTest { @Test fun `createMetric throws NOT_FOUND when ReportingSet in metric not found`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -871,24 +747,8 @@ abstract class MetricsServiceTest { @Test fun `createMetric throws INVALID_ARGUMENT when metric missing time interval`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -943,24 +803,8 @@ abstract class MetricsServiceTest { @Test fun `createMetric throws INVALID_ARGUMENT when metric spec missing type`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -1012,24 +856,9 @@ abstract class MetricsServiceTest { @Test fun `createMetric throws INVALID_ARGUMENT when metric spec missing vid sampling interval`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = + createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -1084,24 +913,9 @@ abstract class MetricsServiceTest { @Test fun `createMetric throws INVALID_ARGUMENT when metric missing weighted measurements`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = + createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -1143,24 +957,8 @@ abstract class MetricsServiceTest { @Test fun `createMetric throws FAILED_PRECONDITION when MC not found`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + "2" @@ -1225,24 +1023,8 @@ abstract class MetricsServiceTest { @Test fun `batchCreateMetrics succeeds for one create metric request`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -1304,35 +1086,19 @@ abstract class MetricsServiceTest { } ) + val createdMetric = batchCreateMetricsResponse.metricsList.first() assertThat(batchCreateMetricsResponse.metricsList).hasSize(1) - assertThat(batchCreateMetricsResponse.metricsList.first().externalMetricId).isNotEqualTo(0) - batchCreateMetricsResponse.metricsList.forEach { - it.weightedMeasurementsList.forEach { weightedMeasurement -> - assertThat(weightedMeasurement.measurement.cmmsCreateMeasurementRequestId).isNotEmpty() - } + assertThat(createdMetric.externalMetricId).isNotEqualTo(0) + assertThat(createdMetric.hasCreateTime()).isTrue() + createdMetric.weightedMeasurementsList.forEach { + assertThat(it.measurement.cmmsCreateMeasurementRequestId).isNotEmpty() } } @Test fun `batchCreateMetrics succeeds for two create metric requests`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -1395,9 +1161,13 @@ abstract class MetricsServiceTest { } ) + val createdMetric = batchCreateMetricsResponse.metricsList.first() + val createdMetric2 = batchCreateMetricsResponse.metricsList.last() assertThat(batchCreateMetricsResponse.metricsList).hasSize(2) - assertThat(batchCreateMetricsResponse.metricsList.first().externalMetricId).isNotEqualTo(0) - assertThat(batchCreateMetricsResponse.metricsList.last().externalMetricId).isNotEqualTo(0) + assertThat(createdMetric.externalMetricId).isNotEqualTo(0) + assertThat(createdMetric.hasCreateTime()).isTrue() + assertThat(createdMetric2.externalMetricId).isNotEqualTo(0) + assertThat(createdMetric2.hasCreateTime()).isTrue() batchCreateMetricsResponse.metricsList.forEach { it.weightedMeasurementsList.forEach { weightedMeasurement -> assertThat(weightedMeasurement.measurement.cmmsCreateMeasurementRequestId).isNotEmpty() @@ -1405,29 +1175,12 @@ abstract class MetricsServiceTest { } } - /** TODO(tristanvuong2021): implement read methods for metric */ - @Ignore @Test fun `batchCreateMetrics succeeds for two create metric requests with one already existing`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = + createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -1502,10 +1255,14 @@ abstract class MetricsServiceTest { ) assertThat(batchCreateMetricsResponse.metricsList).hasSize(2) - assertThat(batchCreateMetricsResponse.metricsList.first().externalMetricId) - .isEqualTo(createdMetric.externalMetricId) - assertThat(batchCreateMetricsResponse.metricsList.last().externalMetricId).isNotEqualTo(0) - batchCreateMetricsResponse.metricsList.last().weightedMeasurementsList.forEach { + assertThat(batchCreateMetricsResponse.metricsList.first()) + .ignoringRepeatedFieldOrder() + .isEqualTo(createdMetric) + + val createdMetric2 = batchCreateMetricsResponse.metricsList.last() + assertThat(createdMetric2.externalMetricId).isNotEqualTo(0) + assertThat(createdMetric2.hasCreateTime()).isTrue() + createdMetric2.weightedMeasurementsList.forEach { assertThat(it.measurement.cmmsCreateMeasurementRequestId).isNotEmpty() } } @@ -1513,24 +1270,9 @@ abstract class MetricsServiceTest { @Test fun `batchCreateMetrics throws INVALID_ARGUMENT when metric missing time interval`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = + createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -1601,24 +1343,8 @@ abstract class MetricsServiceTest { @Test fun `batchCreateMetrics throws INVALID_ARGUMENT when metric spec missing type`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -1692,24 +1418,9 @@ abstract class MetricsServiceTest { @Test fun `batchCreateMetrics throws INVALID_ARGUMENT when metric spec missing vid sampling`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = + createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -1784,24 +1495,9 @@ abstract class MetricsServiceTest { @Test fun `batchCreateMetrics throws INVALID_ARGUMENT when metric missing weighted measurements`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = + createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -1875,24 +1571,9 @@ abstract class MetricsServiceTest { @Test fun `batchCreateMetrics throws INVALID_ARGUMENT when cmms mc id doesn't match create request`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = + createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -1966,24 +1647,8 @@ abstract class MetricsServiceTest { @Test fun `batchCreateMetrics throws INVALID_ARGUMENT when too many requests`() = runBlocking { - measurementConsumersService.createMeasurementConsumer( - measurementConsumer { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID } - ) - - val reportingSet = reportingSet { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - primitive = - ReportingSetKt.primitive { - eventGroupKeys += - ReportingSetKt.PrimitiveKt.eventGroupKey { - cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - cmmsDataProviderId = "1235" - cmmsEventGroupId = "1236" - } - } - } - - val createdReportingSet = reportingSetsService.createReportingSet(reportingSet) + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createdReportingSet = createReportingSet(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) val metric = metric { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID @@ -2042,7 +1707,7 @@ abstract class MetricsServiceTest { service.batchCreateMetrics( batchCreateMetricsRequest { cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID - for (i in 1..(MAX_BATCH_CREATE_SIZE + 1)) { + for (i in 1..(MAX_BATCH_SIZE + 1)) { requests += createMetricRequest { this.metric = metric } } } @@ -2052,4 +1717,570 @@ abstract class MetricsServiceTest { assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) assertThat(exception.message).contains("Too many") } + + @Test + fun `batchGetMetrics succeeds when metric spec type is reach`(): Unit = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService).copy { + metric = + metric.copy { + metricSpec = metricSpec { + reach = + MetricSpecKt.reachParams { + privacyParams = + MetricSpecKt.differentialPrivacyParams { + epsilon = 1.0 + delta = 2.0 + } + } + vidSamplingInterval = + MetricSpecKt.vidSamplingInterval { + start = 0.1f + width = 0.5f + } + } + } + } + val createdMetric = service.createMetric(createMetricRequest) + + val retrievedMetrics = + service.batchGetMetrics( + batchGetMetricsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + externalMetricIds += createdMetric.externalMetricId + } + ) + + assertThat(retrievedMetrics.metricsList) + .ignoringRepeatedFieldOrder() + .containsExactly(createdMetric) + } + + @Test + fun `batchGetMetrics succeeds when metric spec type is frequency duration`(): Unit = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService).copy { + metric = + metric.copy { + metricSpec = metricSpec { + frequencyHistogram = + MetricSpecKt.frequencyHistogramParams { + reachPrivacyParams = + MetricSpecKt.differentialPrivacyParams { + epsilon = 1.0 + delta = 2.0 + } + frequencyPrivacyParams = + MetricSpecKt.differentialPrivacyParams { + epsilon = 1.0 + delta = 2.0 + } + maximumFrequencyPerUser = 5 + } + vidSamplingInterval = + MetricSpecKt.vidSamplingInterval { + start = 0.1f + width = 0.5f + } + } + } + } + val createdMetric = service.createMetric(createMetricRequest) + + val retrievedMetrics = + service.batchGetMetrics( + batchGetMetricsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + externalMetricIds += createdMetric.externalMetricId + } + ) + + assertThat(retrievedMetrics.metricsList) + .ignoringRepeatedFieldOrder() + .containsExactly(createdMetric) + } + + @Test + fun `batchGetMetrics succeeds when metric spec type is impression count`(): Unit = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService).copy { + metric = + metric.copy { + metricSpec = metricSpec { + impressionCount = + MetricSpecKt.impressionCountParams { + privacyParams = + MetricSpecKt.differentialPrivacyParams { + epsilon = 1.0 + delta = 2.0 + } + maximumFrequencyPerUser = 5 + } + vidSamplingInterval = + MetricSpecKt.vidSamplingInterval { + start = 0.1f + width = 0.5f + } + } + } + } + val createdMetric = service.createMetric(createMetricRequest) + + val retrievedMetrics = + service.batchGetMetrics( + batchGetMetricsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + externalMetricIds += createdMetric.externalMetricId + } + ) + + assertThat(retrievedMetrics.metricsList) + .ignoringRepeatedFieldOrder() + .containsExactly(createdMetric) + } + + @Test + fun `batchGetMetrics succeeds when metric spec type is watch duration`(): Unit = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService).copy { + metric = + metric.copy { + metricSpec = metricSpec { + watchDuration = + MetricSpecKt.watchDurationParams { + privacyParams = + MetricSpecKt.differentialPrivacyParams { + epsilon = 1.0 + delta = 2.0 + } + maximumWatchDurationPerUser = 100 + } + vidSamplingInterval = + MetricSpecKt.vidSamplingInterval { + start = 0.1f + width = 0.5f + } + } + } + } + val createdMetric = service.createMetric(createMetricRequest) + + val retrievedMetrics = + service.batchGetMetrics( + batchGetMetricsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + externalMetricIds += createdMetric.externalMetricId + } + ) + + assertThat(retrievedMetrics.metricsList) + .ignoringRepeatedFieldOrder() + .containsExactly(createdMetric) + } + + @Test + fun `batchGetMetrics succeeds when asking for two metrics`(): Unit = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) + val createdMetricRequest2 = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) + val createdMetric = service.createMetric(createMetricRequest) + val createdMetric2 = service.createMetric(createdMetricRequest2) + + val retrievedMetrics = + service.batchGetMetrics( + batchGetMetricsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + externalMetricIds += createdMetric.externalMetricId + externalMetricIds += createdMetric2.externalMetricId + } + ) + + assertThat(retrievedMetrics.metricsList) + .ignoringRepeatedFieldOrder() + .containsExactly(createdMetric, createdMetric2) + } + + @Test + fun `batchGetMetrics succeeds when no filters in bases in measurements`(): Unit = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService).copy { + metric = + metric.copy { + val externalPrimitiveReportingSetId = + weightedMeasurements + .first() + .measurement + .primitiveReportingSetBasesList + .first() + .externalReportingSetId + weightedMeasurements.clear() + weightedMeasurements += + MetricKt.weightedMeasurement { + weight = 2 + measurement = measurement { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + timeInterval = timeInterval { + startTime = timestamp { seconds = 10 } + endTime = timestamp { seconds = 100 } + } + primitiveReportingSetBases += + ReportingSetKt.primitiveReportingSetBasis { + externalReportingSetId = externalPrimitiveReportingSetId + } + } + } + } + } + val createdMetric = service.createMetric(createMetricRequest) + + val retrievedMetrics = + service.batchGetMetrics( + batchGetMetricsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + externalMetricIds += createdMetric.externalMetricId + } + ) + + assertThat(retrievedMetrics.metricsList) + .ignoringRepeatedFieldOrder() + .containsExactly(createdMetric) + } + + @Test + fun `batchGetMetrics throws NOT_FOUND when not all metrics found`(): Unit = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) + val createdMetric = service.createMetric(createMetricRequest) + + val exception = + assertFailsWith { + service.batchGetMetrics( + batchGetMetricsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + externalMetricIds += 1L + externalMetricIds += createdMetric.externalMetricId + } + ) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.NOT_FOUND) + assertThat(exception.message).contains("not found") + } + + @Test + fun `batchGetMetrics throws INVALID_ARGUMENT when missing mc id`(): Unit = runBlocking { + val exception = + assertFailsWith { + service.batchGetMetrics(batchGetMetricsRequest { externalMetricIds += 1L }) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) + assertThat(exception.message).contains("CmmsMeasurementConsumerId") + } + + @Test + fun `batchGetMetrics throws INVALID_ARGUMENT when too many to get`(): Unit = runBlocking { + val exception = + assertFailsWith { + service.batchGetMetrics( + batchGetMetricsRequest { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + for (i in 1L..(MAX_BATCH_SIZE + 1)) { + externalMetricIds += i + } + } + ) + } + + assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) + assertThat(exception.message).contains("Too many") + } + + @Test + fun `streamMetrics filters when measurement consumer filter is set`(): Unit = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val differentCmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + 2 + createMeasurementConsumer(differentCmmsMeasurementConsumerId, measurementConsumersService) + + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) + val createMetricRequest2 = + createCreateMetricRequest(differentCmmsMeasurementConsumerId, reportingSetsService) + val createdMetric = service.createMetric(createMetricRequest) + service.createMetric(createMetricRequest2) + + val retrievedMetrics = + service + .streamMetrics( + streamMetricsRequest { + filter = + StreamMetricsRequestKt.filter { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + } + } + ) + .toList() + + assertThat(retrievedMetrics).ignoringRepeatedFieldOrder().containsExactly(createdMetric) + } + + @Test + fun `streamReportingSets filters when id after filter is set`(): Unit = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) + val createMetricRequest2 = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) + + val createdMetric = service.createMetric(createMetricRequest) + val createdMetric2 = service.createMetric(createMetricRequest2) + + val afterId = minOf(createdMetric.externalMetricId, createdMetric2.externalMetricId) + + val retrievedMetrics = + service + .streamMetrics( + streamMetricsRequest { + filter = + StreamMetricsRequestKt.filter { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + externalMetricIdAfter = afterId + } + } + ) + .toList() + + if (createdMetric.externalMetricId == afterId) { + assertThat(retrievedMetrics).ignoringRepeatedFieldOrder().containsExactly(createdMetric2) + } else { + assertThat(retrievedMetrics).ignoringRepeatedFieldOrder().containsExactly(createdMetric) + } + } + + @Test + fun `streamMetrics filters when both mc and after filter are set`(): Unit = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + val differentCmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + 2 + createMeasurementConsumer(differentCmmsMeasurementConsumerId, measurementConsumersService) + + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) + val createMetricRequest2 = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) + val createMetricRequest3 = + createCreateMetricRequest(differentCmmsMeasurementConsumerId, reportingSetsService) + + val createdMetric = service.createMetric(createMetricRequest) + val createdMetric2 = service.createMetric(createMetricRequest2) + service.createMetric(createMetricRequest3) + + val afterId = minOf(createdMetric.externalMetricId, createdMetric2.externalMetricId) + + val retrievedMetrics = + service + .streamMetrics( + streamMetricsRequest { + filter = + StreamMetricsRequestKt.filter { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + externalMetricIdAfter = afterId + } + } + ) + .toList() + + if (createdMetric.externalMetricId == afterId) { + assertThat(retrievedMetrics).ignoringRepeatedFieldOrder().containsExactly(createdMetric2) + } else { + assertThat(retrievedMetrics).ignoringRepeatedFieldOrder().containsExactly(createdMetric) + } + } + + @Test + fun `streamMetrics limits the number of results when limit is set`(): Unit = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + + val createMetricRequest = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) + val createMetricRequest2 = + createCreateMetricRequest(CMMS_MEASUREMENT_CONSUMER_ID, reportingSetsService) + + val createdMetric = service.createMetric(createMetricRequest) + val createdMetric2 = service.createMetric(createMetricRequest2) + + val retrievedMetrics = + service + .streamMetrics( + streamMetricsRequest { + filter = + StreamMetricsRequestKt.filter { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + } + limit = 1 + } + ) + .toList() + + if (createdMetric.externalMetricId < createdMetric2.externalMetricId) { + assertThat(retrievedMetrics).ignoringRepeatedFieldOrder().containsExactly(createdMetric) + } else { + assertThat(retrievedMetrics).ignoringRepeatedFieldOrder().containsExactly(createdMetric2) + } + } + + @Test + fun `streamMetrics returns empty flow when no metrics are found`() = runBlocking { + createMeasurementConsumer(CMMS_MEASUREMENT_CONSUMER_ID, measurementConsumersService) + + val retrievedMetrics = + service + .streamMetrics( + streamMetricsRequest { + filter = + StreamMetricsRequestKt.filter { + cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID + } + } + ) + .toList() + + assertThat(retrievedMetrics).hasSize(0) + } + + @Test + fun `streamMetrics throws INVALID_ARGUMENT when MC filter missing`() = runBlocking { + val exception = + assertFailsWith { service.streamMetrics(streamMetricsRequest {}) } + + assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) + } + + companion object { + private suspend fun createCreateMetricRequest( + cmmsMeasurementConsumerId: String, + reportingSetsService: ReportingSetsCoroutineImplBase, + ): CreateMetricRequest { + val createdReportingSet = createReportingSet(cmmsMeasurementConsumerId, reportingSetsService) + + val metric = metric { + this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId + externalReportingSetId = createdReportingSet.externalReportingSetId + timeInterval = timeInterval { + startTime = timestamp { seconds = 10 } + endTime = timestamp { seconds = 100 } + } + metricSpec = metricSpec { + reach = + MetricSpecKt.reachParams { + privacyParams = + MetricSpecKt.differentialPrivacyParams { + epsilon = 1.0 + delta = 2.0 + } + } + vidSamplingInterval = + MetricSpecKt.vidSamplingInterval { + start = 0.1f + width = 0.5f + } + } + weightedMeasurements += + MetricKt.weightedMeasurement { + weight = 2 + measurement = measurement { + this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId + timeInterval = timeInterval { + startTime = timestamp { seconds = 10 } + endTime = timestamp { seconds = 100 } + } + primitiveReportingSetBases += + ReportingSetKt.primitiveReportingSetBasis { + externalReportingSetId = createdReportingSet.externalReportingSetId + filters += "filter1" + filters += "filter2" + } + primitiveReportingSetBases += + ReportingSetKt.primitiveReportingSetBasis { + externalReportingSetId = createdReportingSet.externalReportingSetId + filters += "filter3" + filters += "filter4" + } + } + } + weightedMeasurements += + MetricKt.weightedMeasurement { + weight = 3 + measurement = measurement { + this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId + timeInterval = timeInterval { + startTime = timestamp { seconds = 10 } + endTime = timestamp { seconds = 100 } + } + primitiveReportingSetBases += + ReportingSetKt.primitiveReportingSetBasis { + externalReportingSetId = createdReportingSet.externalReportingSetId + filters += "filter5" + filters += "filter6" + } + primitiveReportingSetBases += + ReportingSetKt.primitiveReportingSetBasis { + externalReportingSetId = createdReportingSet.externalReportingSetId + filters += "filter7" + filters += "filter8" + } + } + } + details = + MetricKt.details { + filters += "filter1" + filters += "filter2" + } + } + + return createMetricRequest { this.metric = metric } + } + + private suspend fun createReportingSet( + cmmsMeasurementConsumerId: String, + reportingSetsService: ReportingSetsCoroutineImplBase + ): ReportingSet { + val reportingSet = reportingSet { + this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId + primitive = + ReportingSetKt.primitive { + eventGroupKeys += + ReportingSetKt.PrimitiveKt.eventGroupKey { + this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId + cmmsDataProviderId = "1235" + cmmsEventGroupId = cmmsMeasurementConsumerId + "123" + } + } + } + return reportingSetsService.createReportingSet(reportingSet) + } + + private suspend fun createMeasurementConsumer( + cmmsMeasurementConsumerId: String, + measurementConsumersService: MeasurementConsumersCoroutineImplBase, + ) { + measurementConsumersService.createMeasurementConsumer( + measurementConsumer { this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId } + ) + } + } } diff --git a/src/main/proto/wfa/measurement/internal/reporting/v2/metrics_service.proto b/src/main/proto/wfa/measurement/internal/reporting/v2/metrics_service.proto index 69bffd23e2c..000ae2b1ecf 100644 --- a/src/main/proto/wfa/measurement/internal/reporting/v2/metrics_service.proto +++ b/src/main/proto/wfa/measurement/internal/reporting/v2/metrics_service.proto @@ -43,7 +43,7 @@ message BatchCreateMetricsRequest { // `MeasurementConsumer` ID from the CMMS public API. string cmms_measurement_consumer_id = 1; - // Maximum is 200. + // Maximum is 1000. repeated CreateMetricRequest requests = 2; } @@ -55,7 +55,7 @@ message BatchGetMetricsRequest { // `MeasurementConsumer` ID from the CMMS public API. string cmms_measurement_consumer_id = 1; - // Maximum is 200. + // Maximum is 1000. repeated fixed64 external_metric_ids = 2; }