From 9898ebff2c422b7b3ef92e8b0687f1be4cbb2c09 Mon Sep 17 00:00:00 2001 From: Tristan Vuong Date: Thu, 8 Feb 2024 21:21:34 +0000 Subject: [PATCH 01/13] Refactor SubmitBatchRequests to use Coroutines --- .../service/api/SubmitBatchRequests.kt | 29 ++++++++++++++----- .../service/api/v2alpha/MetricsService.kt | 18 ++++-------- .../api/v2alpha/ReportSchedulesService.kt | 4 +-- .../service/api/v2alpha/ReportsService.kt | 16 +++++----- .../internal/testing/v2/ReportsServiceTest.kt | 5 ++-- .../service/api/SubmitBatchRequestsTest.kt | 15 ++++------ 6 files changed, 46 insertions(+), 41 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt index 12fdae3c352..37a2e5400a3 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt @@ -16,11 +16,14 @@ package org.wfanet.measurement.reporting.service.api -import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.asFlow -import kotlinx.coroutines.flow.flatMapConcat import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.sync.Semaphore +import kotlinx.coroutines.sync.withPermit class BatchRequestException(message: String? = null, cause: Throwable? = null) : Exception(message, cause) @@ -48,13 +51,12 @@ fun Flow.chunked(chunkSize: Int): Flow> { } /** Submits multiple RPCs by dividing the input items to batches. */ -@OptIn(ExperimentalCoroutinesApi::class) // For `flatMapConcat`. suspend fun submitBatchRequests( - items: Flow, + items: Collection, limit: Int, callRpc: suspend (List) -> RESP, parseResponse: (RESP) -> List, -): Flow { +): List { if (limit <= 0) { throw BatchRequestException( "Invalid limit", @@ -62,5 +64,18 @@ suspend fun submitBatchRequests( ) } - return items.chunked(limit).flatMapConcat { batch -> parseResponse(callRpc(batch)).asFlow() } + // For network requests, the number of concurrent coroutines needs to be capped. To be on the safe + // side, a low number is chosen. + val batchSemaphore = Semaphore(3) + return coroutineScope { + val deferred: List>> = items.chunked(limit).map { batch: List -> + async { + batchSemaphore.withPermit { + parseResponse(callRpc(batch)) + } + } + } + val responses: List> = deferred.awaitAll() + responses.flatten() + } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 8c76fdf5467..f322cb65480 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -293,7 +293,7 @@ class MetricsService( val measurementConsumer: MeasurementConsumer = getMeasurementConsumer(principal) // Gets all external IDs of primitive reporting sets from the metric list. - val externalPrimitiveReportingSetIds: Flow = + val externalPrimitiveReportingSetIds: List = internalMetricsList .flatMap { internalMetric -> internalMetric.weightedMeasurementsList.flatMap { weightedMeasurement -> @@ -303,7 +303,6 @@ class MetricsService( } } .distinct() - .asFlow() val callBatchGetInternalReportingSetsRpc: suspend (List) -> BatchGetReportingSetsResponse = @@ -319,7 +318,6 @@ class MetricsService( ) { response: BatchGetReportingSetsResponse -> response.reportingSetsList } - .toList() .associateBy { it.externalReportingSetId } val dataProviderNames = mutableSetOf() @@ -357,9 +355,9 @@ class MetricsService( batchCreateCmmsMeasurements(principal, items) } - val cmmsMeasurements: Flow = + val cmmsMeasurements: List = submitBatchRequests( - cmmsCreateMeasurementRequests.asFlow(), + cmmsCreateMeasurementRequests, BATCH_KINGDOM_MEASUREMENTS_LIMIT, callBatchCreateMeasurementsRpc, ) { response: BatchCreateMeasurementsResponse -> @@ -385,7 +383,6 @@ class MetricsService( ) { response: BatchSetCmmsMeasurementIdsResponse -> response.measurementsList } - .toList() } /** Sets a batch of CMMS [MeasurementIds] to the [InternalMeasurement] table. */ @@ -792,13 +789,12 @@ class MetricsService( batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal) } submitBatchRequests( - measurementsList.asFlow(), + measurementsList, BATCH_SET_MEASUREMENT_RESULTS_LIMIT, callBatchSetInternalMeasurementResultsRpc, ) { response: BatchSetCmmsMeasurementResultsResponse -> response.measurementsList } - .toList() anyUpdate = true } @@ -815,13 +811,12 @@ class MetricsService( ) } submitBatchRequests( - measurementsList.asFlow(), + measurementsList, BATCH_SET_MEASUREMENT_FAILURES_LIMIT, callBatchSetInternalMeasurementFailuresRpc, ) { response: BatchSetCmmsMeasurementFailuresResponse -> response.measurementsList } - .toList() anyUpdate = true } @@ -918,13 +913,12 @@ class MetricsService( } return submitBatchRequests( - measurementNames.asFlow(), + measurementNames, BATCH_KINGDOM_MEASUREMENTS_LIMIT, callBatchGetMeasurementsRpc, ) { response: BatchGetMeasurementsResponse -> response.measurementsList } - .toList() } /** Batch get CMMS measurements. */ diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt index 50921b4900f..83e6b904656 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt @@ -674,14 +674,14 @@ class ReportSchedulesService( retrievedExternalReportingSetIdSet.addAll(externalReportingSetIdSet) submitBatchRequests( - externalReportingSetIdSet.asFlow(), + externalReportingSetIdSet, BATCH_GET_REPORTING_SETS_LIMIT, callRpc, ) { response -> externalReportingSetIdSet.clear() response.reportingSetsList } - .collect { + .forEach { if (it.hasComposite()) { val lhsExternalReportingSetId = it.composite.lhs.externalReportingSetId if (lhsExternalReportingSetId.isNotEmpty()) { diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt index 702bd952140..ec025690e44 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt @@ -164,8 +164,8 @@ class ReportsService( results.subList(0, min(results.size, listReportsPageToken.pageSize)) // Get metrics. - val metricNames: Flow = - subResults.flatMap { internalReport -> internalReport.metricNames }.distinct().asFlow() + val metricNames: List = + subResults.flatMap { internalReport -> internalReport.metricNames }.distinct() val callRpc: suspend (List) -> BatchGetMetricsResponse = { items -> batchGetMetrics(principal.resourceKey.toName(), items) @@ -174,7 +174,6 @@ class ReportsService( submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response -> response.metricsList } - .toList() .associateBy { checkNotNull(MetricKey.fromName(it.name)).metricId } return listReportsResponse { @@ -230,7 +229,7 @@ class ReportsService( } // Get metrics. - val metricNames: Flow = internalReport.metricNames.distinct().asFlow() + val metricNames: List = internalReport.metricNames.distinct() val callRpc: suspend (List) -> BatchGetMetricsResponse = { items -> batchGetMetrics(principal.resourceKey.toName(), items) @@ -239,7 +238,6 @@ class ReportsService( submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response -> response.metricsList } - .toList() .associateBy { checkNotNull(MetricKey.fromName(it.name)).metricId } // Convert the internal report to public and return. @@ -309,6 +307,8 @@ class ReportsService( key.metricCalculationSpecId } } + .distinct() + val externalIdToMetricCalculationSpecMap: Map = createExternalIdToMetricCalculationSpecMap( parentKey.measurementConsumerId, @@ -360,7 +360,7 @@ class ReportsService( } // Create metrics. - val createMetricRequests: Flow = + val createMetricRequests: List = internalReport.reportingMetricEntriesMap .flatMap { (reportingSetId, reportingMetricCalculationSpec) -> reportingMetricCalculationSpec.metricCalculationSpecReportingMetricsList.flatMap { @@ -377,7 +377,6 @@ class ReportsService( } } } - .asFlow() val callRpc: suspend (List) -> BatchCreateMetricsResponse = { items -> batchCreateMetrics(request.parent, items) @@ -387,7 +386,6 @@ class ReportsService( response: BatchCreateMetricsResponse -> response.metricsList } - .toList() .associateBy { checkNotNull(MetricKey.fromName(it.name)).metricId } // Once all metrics are created, get the updated internal report with the metric IDs filled. @@ -482,6 +480,8 @@ class ReportsService( it.externalMetricCalculationSpecId } } + .distinct() + val externalIdToMetricCalculationMap: Map = createExternalIdToMetricCalculationSpecMap( internalReport.cmmsMeasurementConsumerId, diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt index f3bf4853844..1f6a78db895 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt @@ -635,7 +635,7 @@ abstract class ReportsServiceTest { ) var metricIndex = 0 - val createMetricsRequests: Flow = + val createMetricsRequests: List = createdReport.reportingMetricEntriesMap.entries .flatMap { entry -> val reportingSet = createdReportingSetsByExternalId.getValue(entry.key) @@ -660,7 +660,7 @@ abstract class ReportsServiceTest { } } } - .asFlow() + val callRpc: suspend (List) -> BatchCreateMetricsResponse = { items -> metricsService.batchCreateMetrics( batchCreateMetricsRequest { @@ -672,7 +672,6 @@ abstract class ReportsServiceTest { submitBatchRequests(createMetricsRequests, MAX_BATCH_SIZE, callRpc) { response -> response.metricsList } - .toList() val retrievedReport = service.getReport( diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt index a9537be540f..26197539406 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt @@ -89,7 +89,7 @@ class SubmitBatchRequestsTest { ceil(INTERNAL_PRIMITIVE_REPORTING_SETS.size / BATCH_GET_REPORTING_SETS_LIMIT.toFloat()) .toInt() - val items = INTERNAL_PRIMITIVE_REPORTING_SETS.map { it.externalReportingSetId }.asFlow() + val items = INTERNAL_PRIMITIVE_REPORTING_SETS.map { it.externalReportingSetId } val parseResponse: (BatchGetReportingSetsResponse) -> List = { response -> @@ -103,7 +103,6 @@ class SubmitBatchRequestsTest { ::batchGetReportingSets, parseResponse, ) - .toList() val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() @@ -121,7 +120,7 @@ class SubmitBatchRequestsTest { val expectedReportingSets = INTERNAL_PRIMITIVE_REPORTING_SETS.subList(0, numberTargetReportingSet) val expectedNumberBatches = 1 - val items = expectedReportingSets.map { it.externalReportingSetId }.asFlow() + val items = expectedReportingSets.map { it.externalReportingSetId } val parseResponse: (BatchGetReportingSetsResponse) -> List = { response -> @@ -135,7 +134,6 @@ class SubmitBatchRequestsTest { ::batchGetReportingSets, parseResponse, ) - .toList() val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() @@ -154,12 +152,11 @@ class SubmitBatchRequestsTest { val result: List = submitBatchRequests( - flow {}, - BATCH_GET_REPORTING_SETS_LIMIT, - ::batchGetReportingSets, - parseResponse, + emptyList(), + BATCH_GET_REPORTING_SETS_LIMIT, + ::batchGetReportingSets, + parseResponse, ) - .toList() val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() From 51faadea09b01b50acdf46deb36e01b86691a669 Mon Sep 17 00:00:00 2001 From: Tristan Vuong Date: Thu, 8 Feb 2024 21:27:18 +0000 Subject: [PATCH 02/13] lint fix --- .../service/api/SubmitBatchRequests.kt | 9 +-- .../service/api/v2alpha/MetricsService.kt | 60 +++++++++---------- .../api/v2alpha/ReportSchedulesService.kt | 8 +-- .../service/api/v2alpha/ReportsService.kt | 56 ++++++++--------- .../service/api/SubmitBatchRequestsTest.kt | 25 ++++---- 5 files changed, 73 insertions(+), 85 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt index 37a2e5400a3..50629ae785b 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt @@ -68,13 +68,10 @@ suspend fun submitBatchRequests( // side, a low number is chosen. val batchSemaphore = Semaphore(3) return coroutineScope { - val deferred: List>> = items.chunked(limit).map { batch: List -> - async { - batchSemaphore.withPermit { - parseResponse(callRpc(batch)) - } + val deferred: List>> = + items.chunked(limit).map { batch: List -> + async { batchSemaphore.withPermit { parseResponse(callRpc(batch)) } } } - } val responses: List> = deferred.awaitAll() responses.flatten() } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index f322cb65480..b0cc55a8317 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -47,8 +47,6 @@ import kotlinx.coroutines.asExecutor import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll import kotlinx.coroutines.coroutineScope -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList import kotlinx.coroutines.withContext @@ -372,17 +370,17 @@ class MetricsService( } submitBatchRequests( - cmmsMeasurements.map { - measurementIds { - cmmsCreateMeasurementRequestId = it.measurementReferenceId - cmmsMeasurementId = MeasurementKey.fromName(it.name)!!.measurementId - } - }, - BATCH_SET_CMMS_MEASUREMENT_IDS_LIMIT, - callBatchSetCmmsMeasurementIdsRpc, - ) { response: BatchSetCmmsMeasurementIdsResponse -> - response.measurementsList - } + cmmsMeasurements.map { + measurementIds { + cmmsCreateMeasurementRequestId = it.measurementReferenceId + cmmsMeasurementId = MeasurementKey.fromName(it.name)!!.measurementId + } + }, + BATCH_SET_CMMS_MEASUREMENT_IDS_LIMIT, + callBatchSetCmmsMeasurementIdsRpc, + ) { response: BatchSetCmmsMeasurementIdsResponse -> + response.measurementsList + } } /** Sets a batch of CMMS [MeasurementIds] to the [InternalMeasurement] table. */ @@ -789,12 +787,12 @@ class MetricsService( batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal) } submitBatchRequests( - measurementsList, - BATCH_SET_MEASUREMENT_RESULTS_LIMIT, - callBatchSetInternalMeasurementResultsRpc, - ) { response: BatchSetCmmsMeasurementResultsResponse -> - response.measurementsList - } + measurementsList, + BATCH_SET_MEASUREMENT_RESULTS_LIMIT, + callBatchSetInternalMeasurementResultsRpc, + ) { response: BatchSetCmmsMeasurementResultsResponse -> + response.measurementsList + } anyUpdate = true } @@ -811,12 +809,12 @@ class MetricsService( ) } submitBatchRequests( - measurementsList, - BATCH_SET_MEASUREMENT_FAILURES_LIMIT, - callBatchSetInternalMeasurementFailuresRpc, - ) { response: BatchSetCmmsMeasurementFailuresResponse -> - response.measurementsList - } + measurementsList, + BATCH_SET_MEASUREMENT_FAILURES_LIMIT, + callBatchSetInternalMeasurementFailuresRpc, + ) { response: BatchSetCmmsMeasurementFailuresResponse -> + response.measurementsList + } anyUpdate = true } @@ -913,12 +911,12 @@ class MetricsService( } return submitBatchRequests( - measurementNames, - BATCH_KINGDOM_MEASUREMENTS_LIMIT, - callBatchGetMeasurementsRpc, - ) { response: BatchGetMeasurementsResponse -> - response.measurementsList - } + measurementNames, + BATCH_KINGDOM_MEASUREMENTS_LIMIT, + callBatchGetMeasurementsRpc, + ) { response: BatchGetMeasurementsResponse -> + response.measurementsList + } } /** Batch get CMMS measurements. */ diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt index 83e6b904656..b3330858e22 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt @@ -33,7 +33,6 @@ import java.time.ZonedDateTime import java.time.temporal.TemporalAdjusters import java.time.zone.ZoneRulesException import kotlin.math.min -import kotlinx.coroutines.flow.asFlow import org.projectnessie.cel.Env import org.wfanet.measurement.api.v2alpha.DataProvider import org.wfanet.measurement.api.v2alpha.DataProviderKey @@ -673,11 +672,8 @@ class ReportSchedulesService( while (externalReportingSetIdSet.isNotEmpty()) { retrievedExternalReportingSetIdSet.addAll(externalReportingSetIdSet) - submitBatchRequests( - externalReportingSetIdSet, - BATCH_GET_REPORTING_SETS_LIMIT, - callRpc, - ) { response -> + submitBatchRequests(externalReportingSetIdSet, BATCH_GET_REPORTING_SETS_LIMIT, callRpc) { + response -> externalReportingSetIdSet.clear() response.reportingSetsList } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt index ec025690e44..542b8ef651e 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt @@ -35,8 +35,6 @@ import java.time.temporal.Temporal import java.time.temporal.TemporalAdjusters import java.time.zone.ZoneRulesException import kotlin.math.min -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.toList import org.projectnessie.cel.Env import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey @@ -298,15 +296,16 @@ class ReportsService( validateTime(request.report) val externalMetricCalculationSpecIds: List = - request.report.reportingMetricEntriesList.flatMap { reportingMetricEntry -> - reportingMetricEntry.value.metricCalculationSpecsList.map { - val key = - grpcRequireNotNull(MetricCalculationSpecKey.fromName(it)) { - "MetricCalculationSpec name $it is invalid." - } - key.metricCalculationSpecId + request.report.reportingMetricEntriesList + .flatMap { reportingMetricEntry -> + reportingMetricEntry.value.metricCalculationSpecsList.map { + val key = + grpcRequireNotNull(MetricCalculationSpecKey.fromName(it)) { + "MetricCalculationSpec name $it is invalid." + } + key.metricCalculationSpecId + } } - } .distinct() val externalIdToMetricCalculationSpecMap: Map = @@ -361,22 +360,22 @@ class ReportsService( // Create metrics. val createMetricRequests: List = - internalReport.reportingMetricEntriesMap - .flatMap { (reportingSetId, reportingMetricCalculationSpec) -> - reportingMetricCalculationSpec.metricCalculationSpecReportingMetricsList.flatMap { - metricCalculationSpecReportingMetrics -> - metricCalculationSpecReportingMetrics.reportingMetricsList.map { - it.toCreateMetricRequest( - principal.resourceKey, - reportingSetId, - externalIdToMetricCalculationSpecMap - .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) - .details - .filter, - ) - } + internalReport.reportingMetricEntriesMap.flatMap { + (reportingSetId, reportingMetricCalculationSpec) -> + reportingMetricCalculationSpec.metricCalculationSpecReportingMetricsList.flatMap { + metricCalculationSpecReportingMetrics -> + metricCalculationSpecReportingMetrics.reportingMetricsList.map { + it.toCreateMetricRequest( + principal.resourceKey, + reportingSetId, + externalIdToMetricCalculationSpecMap + .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) + .details + .filter, + ) } } + } val callRpc: suspend (List) -> BatchCreateMetricsResponse = { items -> batchCreateMetrics(request.parent, items) @@ -475,11 +474,12 @@ class ReportsService( if (state == Report.State.SUCCEEDED || state == Report.State.FAILED) { val externalMetricCalculationSpecIds = - internalReport.reportingMetricEntriesMap.flatMap { reportingMetricCalculationSpec -> - reportingMetricCalculationSpec.value.metricCalculationSpecReportingMetricsList.map { - it.externalMetricCalculationSpecId + internalReport.reportingMetricEntriesMap + .flatMap { reportingMetricCalculationSpec -> + reportingMetricCalculationSpec.value.metricCalculationSpecReportingMetricsList.map { + it.externalMetricCalculationSpecId + } } - } .distinct() val externalIdToMetricCalculationMap: Map = diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt index 26197539406..a0c9416369a 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt @@ -20,9 +20,6 @@ import com.google.common.truth.extensions.proto.ProtoTruth.assertThat import io.grpc.Status import io.grpc.StatusException import kotlin.math.ceil -import kotlinx.coroutines.flow.asFlow -import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.Before import org.junit.Rule @@ -98,11 +95,11 @@ class SubmitBatchRequestsTest { val result = submitBatchRequests( - items, - BATCH_GET_REPORTING_SETS_LIMIT, - ::batchGetReportingSets, - parseResponse, - ) + items, + BATCH_GET_REPORTING_SETS_LIMIT, + ::batchGetReportingSets, + parseResponse, + ) val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() @@ -129,11 +126,11 @@ class SubmitBatchRequestsTest { val result = submitBatchRequests( - items, - BATCH_GET_REPORTING_SETS_LIMIT, - ::batchGetReportingSets, - parseResponse, - ) + items, + BATCH_GET_REPORTING_SETS_LIMIT, + ::batchGetReportingSets, + parseResponse, + ) val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() @@ -156,7 +153,7 @@ class SubmitBatchRequestsTest { BATCH_GET_REPORTING_SETS_LIMIT, ::batchGetReportingSets, parseResponse, - ) + ) val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() From aaa1d70fa52be870b1f817b8cc0b58ad5e9a615a Mon Sep 17 00:00:00 2001 From: Tristan Vuong Date: Thu, 8 Feb 2024 21:39:38 +0000 Subject: [PATCH 03/13] lint fix --- .../service/api/v2alpha/MetricsService.kt | 1 - .../internal/testing/v2/ReportsServiceTest.kt | 49 +++++++++---------- 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index b0cc55a8317..b26c81e5c76 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -47,7 +47,6 @@ import kotlinx.coroutines.asExecutor import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll import kotlinx.coroutines.coroutineScope -import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList import kotlinx.coroutines.withContext import org.jetbrains.annotations.BlockingExecutor diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt index 1f6a78db895..d42f8db0cd0 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt @@ -30,8 +30,6 @@ import io.grpc.StatusRuntimeException import java.time.Clock import kotlin.random.Random import kotlin.test.assertFailsWith -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.Before @@ -636,30 +634,29 @@ abstract class ReportsServiceTest { var metricIndex = 0 val createMetricsRequests: List = - createdReport.reportingMetricEntriesMap.entries - .flatMap { entry -> - val reportingSet = createdReportingSetsByExternalId.getValue(entry.key) - - entry.value.metricCalculationSpecReportingMetricsList.flatMap { - metricCalculationSpecReportingMetrics -> - val metricCalculationSpecFilter = - createdMetricCalculationSpecsByExternalId - .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) - .details - .filter - metricCalculationSpecReportingMetrics.reportingMetricsList.map { reportingMetric -> - val externalMetricId = "externalMetricId$metricIndex" - metricIndex++ - buildCreateMetricRequest( - createdReport.cmmsMeasurementConsumerId, - externalMetricId, - reportingSet, - reportingMetric, - metricCalculationSpecFilter, - ) - } + createdReport.reportingMetricEntriesMap.entries.flatMap { entry -> + val reportingSet = createdReportingSetsByExternalId.getValue(entry.key) + + entry.value.metricCalculationSpecReportingMetricsList.flatMap { + metricCalculationSpecReportingMetrics -> + val metricCalculationSpecFilter = + createdMetricCalculationSpecsByExternalId + .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) + .details + .filter + metricCalculationSpecReportingMetrics.reportingMetricsList.map { reportingMetric -> + val externalMetricId = "externalMetricId$metricIndex" + metricIndex++ + buildCreateMetricRequest( + createdReport.cmmsMeasurementConsumerId, + externalMetricId, + reportingSet, + reportingMetric, + metricCalculationSpecFilter, + ) } } + } val callRpc: suspend (List) -> BatchCreateMetricsResponse = { items -> metricsService.batchCreateMetrics( @@ -670,8 +667,8 @@ abstract class ReportsServiceTest { ) } submitBatchRequests(createMetricsRequests, MAX_BATCH_SIZE, callRpc) { response -> - response.metricsList - } + response.metricsList + } val retrievedReport = service.getReport( From 200bb1b5d68ea02a7f2cc9f3cade68638cf8e507 Mon Sep 17 00:00:00 2001 From: Tristan Vuong Date: Mon, 12 Feb 2024 22:48:26 +0000 Subject: [PATCH 04/13] Flow of lists is returned now. --- .../service/api/SubmitBatchRequests.kt | 19 +++++++----- .../service/api/v2alpha/MetricsService.kt | 20 +++++++----- .../api/v2alpha/ReportSchedulesService.kt | 26 +++++++++------- .../service/api/v2alpha/ReportsService.kt | 31 ++++++++++++++----- .../internal/testing/v2/ReportsServiceTest.kt | 2 +- .../service/api/SubmitBatchRequestsTest.kt | 7 +++-- 6 files changed, 67 insertions(+), 38 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt index 50629ae785b..f2eafb31af8 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt @@ -18,7 +18,6 @@ package org.wfanet.measurement.reporting.service.api import kotlinx.coroutines.Deferred import kotlinx.coroutines.async -import kotlinx.coroutines.awaitAll import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow @@ -56,7 +55,7 @@ suspend fun submitBatchRequests( limit: Int, callRpc: suspend (List) -> RESP, parseResponse: (RESP) -> List, -): List { +): Flow> { if (limit <= 0) { throw BatchRequestException( "Invalid limit", @@ -67,12 +66,16 @@ suspend fun submitBatchRequests( // For network requests, the number of concurrent coroutines needs to be capped. To be on the safe // side, a low number is chosen. val batchSemaphore = Semaphore(3) - return coroutineScope { - val deferred: List>> = - items.chunked(limit).map { batch: List -> - async { batchSemaphore.withPermit { parseResponse(callRpc(batch)) } } + return flow { + coroutineScope { + val deferred: List>> = + items.chunked(limit).map { batch: List -> + async { batchSemaphore.withPermit { parseResponse(callRpc(batch)) } } + } + + deferred.forEach { + emit(it.await()) } - val responses: List> = deferred.awaitAll() - responses.flatten() + } } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index b26c81e5c76..50241040349 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -308,14 +308,20 @@ class MetricsService( } val internalPrimitiveReportingSetMap: Map = - submitBatchRequests( + buildMap { + submitBatchRequests( externalPrimitiveReportingSetIds, BATCH_GET_REPORTING_SETS_LIMIT, callBatchGetInternalReportingSetsRpc, ) { response: BatchGetReportingSetsResponse -> response.reportingSetsList } - .associateBy { it.externalReportingSetId } + .collect { reportingSets: List -> + for (reportingSet in reportingSets) { + computeIfAbsent(reportingSet.externalReportingSetId) { reportingSet } + } + } + } val dataProviderNames = mutableSetOf() for (internalPrimitiveReportingSet in internalPrimitiveReportingSetMap.values) { @@ -359,7 +365,7 @@ class MetricsService( callBatchCreateMeasurementsRpc, ) { response: BatchCreateMeasurementsResponse -> response.measurementsList - } + }.toList().flatten() // Set CMMS measurement IDs. val callBatchSetCmmsMeasurementIdsRpc: @@ -379,7 +385,7 @@ class MetricsService( callBatchSetCmmsMeasurementIdsRpc, ) { response: BatchSetCmmsMeasurementIdsResponse -> response.measurementsList - } + }.collect {} } /** Sets a batch of CMMS [MeasurementIds] to the [InternalMeasurement] table. */ @@ -791,7 +797,7 @@ class MetricsService( callBatchSetInternalMeasurementResultsRpc, ) { response: BatchSetCmmsMeasurementResultsResponse -> response.measurementsList - } + }.collect {} anyUpdate = true } @@ -813,7 +819,7 @@ class MetricsService( callBatchSetInternalMeasurementFailuresRpc, ) { response: BatchSetCmmsMeasurementFailuresResponse -> response.measurementsList - } + }.collect {} anyUpdate = true } @@ -915,7 +921,7 @@ class MetricsService( callBatchGetMeasurementsRpc, ) { response: BatchGetMeasurementsResponse -> response.measurementsList - } + }.toList().flatten() } /** Batch get CMMS measurements. */ diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt index b3330858e22..bbd78879add 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt @@ -677,23 +677,25 @@ class ReportSchedulesService( externalReportingSetIdSet.clear() response.reportingSetsList } - .forEach { - if (it.hasComposite()) { - val lhsExternalReportingSetId = it.composite.lhs.externalReportingSetId - if (lhsExternalReportingSetId.isNotEmpty()) { - if (!retrievedExternalReportingSetIdSet.contains(lhsExternalReportingSetId)) { - externalReportingSetIdSet.add(lhsExternalReportingSetId) + .collect { internalReportingSets: List -> + for (internalReportingSet in internalReportingSets) { + if (internalReportingSet.hasComposite()) { + val lhsExternalReportingSetId = internalReportingSet.composite.lhs.externalReportingSetId + if (lhsExternalReportingSetId.isNotEmpty()) { + if (!retrievedExternalReportingSetIdSet.contains(lhsExternalReportingSetId)) { + externalReportingSetIdSet.add(lhsExternalReportingSetId) + } } - } - val rhsExternalReportingSetId = it.composite.rhs.externalReportingSetId - if (rhsExternalReportingSetId.isNotEmpty()) { - if (!retrievedExternalReportingSetIdSet.contains(rhsExternalReportingSetId)) { - externalReportingSetIdSet.add(rhsExternalReportingSetId) + val rhsExternalReportingSetId = internalReportingSet.composite.rhs.externalReportingSetId + if (rhsExternalReportingSetId.isNotEmpty()) { + if (!retrievedExternalReportingSetIdSet.contains(rhsExternalReportingSetId)) { + externalReportingSetIdSet.add(rhsExternalReportingSetId) + } } } + reportingSets.add(internalReportingSet) } - reportingSets.add(it) } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt index 542b8ef651e..9a40e616e61 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt @@ -169,10 +169,16 @@ class ReportsService( batchGetMetrics(principal.resourceKey.toName(), items) } val externalIdToMetricMap: Map = - submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response -> + buildMap { + submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response -> response.metricsList } - .associateBy { checkNotNull(MetricKey.fromName(it.name)).metricId } + .collect { metrics: List -> + for (metric in metrics) { + computeIfAbsent(checkNotNull(MetricKey.fromName(metric.name)).metricId) { metric } + } + } + } return listReportsResponse { reports += @@ -233,10 +239,16 @@ class ReportsService( batchGetMetrics(principal.resourceKey.toName(), items) } val externalIdToMetricMap: Map = - submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response -> + buildMap { + submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response -> response.metricsList } - .associateBy { checkNotNull(MetricKey.fromName(it.name)).metricId } + .collect { metrics: List -> + for (metric in metrics) { + computeIfAbsent(checkNotNull(MetricKey.fromName(metric.name)).metricId) { metric } + } + } + } // Convert the internal report to public and return. return convertInternalReportToPublic(internalReport, externalIdToMetricMap) @@ -381,11 +393,16 @@ class ReportsService( batchCreateMetrics(request.parent, items) } val externalIdToMetricMap: Map = - submitBatchRequests(createMetricRequests, BATCH_CREATE_METRICS_LIMIT, callRpc) { - response: BatchCreateMetricsResponse -> + buildMap { + submitBatchRequests(createMetricRequests, BATCH_CREATE_METRICS_LIMIT, callRpc) { response -> response.metricsList } - .associateBy { checkNotNull(MetricKey.fromName(it.name)).metricId } + .collect { metrics: List -> + for (metric in metrics) { + computeIfAbsent(checkNotNull(MetricKey.fromName(metric.name)).metricId) { metric } + } + } + } // Once all metrics are created, get the updated internal report with the metric IDs filled. val updatedInternalReport = diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt index d42f8db0cd0..62231a4bf79 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt @@ -668,7 +668,7 @@ abstract class ReportsServiceTest { } submitBatchRequests(createMetricsRequests, MAX_BATCH_SIZE, callRpc) { response -> response.metricsList - } + }.collect {} val retrievedReport = service.getReport( diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt index a0c9416369a..bf4eef1b307 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt @@ -41,6 +41,7 @@ import org.wfanet.measurement.internal.reporting.v2.ReportingSetsGrpcKt import org.wfanet.measurement.internal.reporting.v2.batchGetReportingSetsRequest import org.wfanet.measurement.internal.reporting.v2.batchGetReportingSetsResponse import org.wfanet.measurement.internal.reporting.v2.reportingSet as internalReportingSet +import kotlinx.coroutines.flow.toList import org.wfanet.measurement.reporting.service.api.v2alpha.ReportingSetsService private const val MEASUREMENT_CONSUMER_ID = "mc_id" @@ -99,7 +100,7 @@ class SubmitBatchRequestsTest { BATCH_GET_REPORTING_SETS_LIMIT, ::batchGetReportingSets, parseResponse, - ) + ).toList().flatten() val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() @@ -130,7 +131,7 @@ class SubmitBatchRequestsTest { BATCH_GET_REPORTING_SETS_LIMIT, ::batchGetReportingSets, parseResponse, - ) + ).toList().flatten() val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() @@ -153,7 +154,7 @@ class SubmitBatchRequestsTest { BATCH_GET_REPORTING_SETS_LIMIT, ::batchGetReportingSets, parseResponse, - ) + ).toList().flatten() val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() From 896ee818122ca1f904f48246f65b68fcf484e112 Mon Sep 17 00:00:00 2001 From: Tristan Vuong Date: Mon, 12 Feb 2024 22:52:03 +0000 Subject: [PATCH 05/13] lint fix --- .../service/api/SubmitBatchRequests.kt | 4 +- .../service/api/v2alpha/MetricsService.kt | 92 ++++++++++--------- .../api/v2alpha/ReportSchedulesService.kt | 6 +- .../service/api/v2alpha/ReportsService.kt | 45 +++++---- .../internal/testing/v2/ReportsServiceTest.kt | 5 +- .../service/api/SubmitBatchRequestsTest.kt | 38 ++++---- 6 files changed, 100 insertions(+), 90 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt index f2eafb31af8..3abaf303781 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt @@ -73,9 +73,7 @@ suspend fun submitBatchRequests( async { batchSemaphore.withPermit { parseResponse(callRpc(batch)) } } } - deferred.forEach { - emit(it.await()) - } + deferred.forEach { emit(it.await()) } } } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 50241040349..cae4caad1cd 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -307,21 +307,20 @@ class MetricsService( batchGetInternalReportingSets(principal.resourceKey.measurementConsumerId, items) } - val internalPrimitiveReportingSetMap: Map = - buildMap { - submitBatchRequests( + val internalPrimitiveReportingSetMap: Map = buildMap { + submitBatchRequests( externalPrimitiveReportingSetIds, BATCH_GET_REPORTING_SETS_LIMIT, callBatchGetInternalReportingSetsRpc, ) { response: BatchGetReportingSetsResponse -> response.reportingSetsList } - .collect { reportingSets: List -> - for (reportingSet in reportingSets) { - computeIfAbsent(reportingSet.externalReportingSetId) { reportingSet } - } + .collect { reportingSets: List -> + for (reportingSet in reportingSets) { + computeIfAbsent(reportingSet.externalReportingSetId) { reportingSet } } - } + } + } val dataProviderNames = mutableSetOf() for (internalPrimitiveReportingSet in internalPrimitiveReportingSetMap.values) { @@ -360,12 +359,14 @@ class MetricsService( val cmmsMeasurements: List = submitBatchRequests( - cmmsCreateMeasurementRequests, - BATCH_KINGDOM_MEASUREMENTS_LIMIT, - callBatchCreateMeasurementsRpc, - ) { response: BatchCreateMeasurementsResponse -> - response.measurementsList - }.toList().flatten() + cmmsCreateMeasurementRequests, + BATCH_KINGDOM_MEASUREMENTS_LIMIT, + callBatchCreateMeasurementsRpc, + ) { response: BatchCreateMeasurementsResponse -> + response.measurementsList + } + .toList() + .flatten() // Set CMMS measurement IDs. val callBatchSetCmmsMeasurementIdsRpc: @@ -375,17 +376,18 @@ class MetricsService( } submitBatchRequests( - cmmsMeasurements.map { - measurementIds { - cmmsCreateMeasurementRequestId = it.measurementReferenceId - cmmsMeasurementId = MeasurementKey.fromName(it.name)!!.measurementId - } - }, - BATCH_SET_CMMS_MEASUREMENT_IDS_LIMIT, - callBatchSetCmmsMeasurementIdsRpc, - ) { response: BatchSetCmmsMeasurementIdsResponse -> - response.measurementsList - }.collect {} + cmmsMeasurements.map { + measurementIds { + cmmsCreateMeasurementRequestId = it.measurementReferenceId + cmmsMeasurementId = MeasurementKey.fromName(it.name)!!.measurementId + } + }, + BATCH_SET_CMMS_MEASUREMENT_IDS_LIMIT, + callBatchSetCmmsMeasurementIdsRpc, + ) { response: BatchSetCmmsMeasurementIdsResponse -> + response.measurementsList + } + .collect {} } /** Sets a batch of CMMS [MeasurementIds] to the [InternalMeasurement] table. */ @@ -792,12 +794,13 @@ class MetricsService( batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal) } submitBatchRequests( - measurementsList, - BATCH_SET_MEASUREMENT_RESULTS_LIMIT, - callBatchSetInternalMeasurementResultsRpc, - ) { response: BatchSetCmmsMeasurementResultsResponse -> - response.measurementsList - }.collect {} + measurementsList, + BATCH_SET_MEASUREMENT_RESULTS_LIMIT, + callBatchSetInternalMeasurementResultsRpc, + ) { response: BatchSetCmmsMeasurementResultsResponse -> + response.measurementsList + } + .collect {} anyUpdate = true } @@ -814,12 +817,13 @@ class MetricsService( ) } submitBatchRequests( - measurementsList, - BATCH_SET_MEASUREMENT_FAILURES_LIMIT, - callBatchSetInternalMeasurementFailuresRpc, - ) { response: BatchSetCmmsMeasurementFailuresResponse -> - response.measurementsList - }.collect {} + measurementsList, + BATCH_SET_MEASUREMENT_FAILURES_LIMIT, + callBatchSetInternalMeasurementFailuresRpc, + ) { response: BatchSetCmmsMeasurementFailuresResponse -> + response.measurementsList + } + .collect {} anyUpdate = true } @@ -916,12 +920,14 @@ class MetricsService( } return submitBatchRequests( - measurementNames, - BATCH_KINGDOM_MEASUREMENTS_LIMIT, - callBatchGetMeasurementsRpc, - ) { response: BatchGetMeasurementsResponse -> - response.measurementsList - }.toList().flatten() + measurementNames, + BATCH_KINGDOM_MEASUREMENTS_LIMIT, + callBatchGetMeasurementsRpc, + ) { response: BatchGetMeasurementsResponse -> + response.measurementsList + } + .toList() + .flatten() } /** Batch get CMMS measurements. */ diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt index bbd78879add..710be54f309 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt @@ -680,14 +680,16 @@ class ReportSchedulesService( .collect { internalReportingSets: List -> for (internalReportingSet in internalReportingSets) { if (internalReportingSet.hasComposite()) { - val lhsExternalReportingSetId = internalReportingSet.composite.lhs.externalReportingSetId + val lhsExternalReportingSetId = + internalReportingSet.composite.lhs.externalReportingSetId if (lhsExternalReportingSetId.isNotEmpty()) { if (!retrievedExternalReportingSetIdSet.contains(lhsExternalReportingSetId)) { externalReportingSetIdSet.add(lhsExternalReportingSetId) } } - val rhsExternalReportingSetId = internalReportingSet.composite.rhs.externalReportingSetId + val rhsExternalReportingSetId = + internalReportingSet.composite.rhs.externalReportingSetId if (rhsExternalReportingSetId.isNotEmpty()) { if (!retrievedExternalReportingSetIdSet.contains(rhsExternalReportingSetId)) { externalReportingSetIdSet.add(rhsExternalReportingSetId) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt index 9a40e616e61..78832999473 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt @@ -168,17 +168,16 @@ class ReportsService( val callRpc: suspend (List) -> BatchGetMetricsResponse = { items -> batchGetMetrics(principal.resourceKey.toName(), items) } - val externalIdToMetricMap: Map = - buildMap { - submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response -> + val externalIdToMetricMap: Map = buildMap { + submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response -> response.metricsList } - .collect { metrics: List -> - for (metric in metrics) { - computeIfAbsent(checkNotNull(MetricKey.fromName(metric.name)).metricId) { metric } - } + .collect { metrics: List -> + for (metric in metrics) { + computeIfAbsent(checkNotNull(MetricKey.fromName(metric.name)).metricId) { metric } } - } + } + } return listReportsResponse { reports += @@ -238,17 +237,16 @@ class ReportsService( val callRpc: suspend (List) -> BatchGetMetricsResponse = { items -> batchGetMetrics(principal.resourceKey.toName(), items) } - val externalIdToMetricMap: Map = - buildMap { - submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response -> + val externalIdToMetricMap: Map = buildMap { + submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response -> response.metricsList } - .collect { metrics: List -> - for (metric in metrics) { - computeIfAbsent(checkNotNull(MetricKey.fromName(metric.name)).metricId) { metric } - } + .collect { metrics: List -> + for (metric in metrics) { + computeIfAbsent(checkNotNull(MetricKey.fromName(metric.name)).metricId) { metric } } - } + } + } // Convert the internal report to public and return. return convertInternalReportToPublic(internalReport, externalIdToMetricMap) @@ -392,17 +390,16 @@ class ReportsService( val callRpc: suspend (List) -> BatchCreateMetricsResponse = { items -> batchCreateMetrics(request.parent, items) } - val externalIdToMetricMap: Map = - buildMap { - submitBatchRequests(createMetricRequests, BATCH_CREATE_METRICS_LIMIT, callRpc) { response -> + val externalIdToMetricMap: Map = buildMap { + submitBatchRequests(createMetricRequests, BATCH_CREATE_METRICS_LIMIT, callRpc) { response -> response.metricsList } - .collect { metrics: List -> - for (metric in metrics) { - computeIfAbsent(checkNotNull(MetricKey.fromName(metric.name)).metricId) { metric } - } + .collect { metrics: List -> + for (metric in metrics) { + computeIfAbsent(checkNotNull(MetricKey.fromName(metric.name)).metricId) { metric } } - } + } + } // Once all metrics are created, get the updated internal report with the metric IDs filled. val updatedInternalReport = diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt index 62231a4bf79..d4ce867f432 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt @@ -667,8 +667,9 @@ abstract class ReportsServiceTest { ) } submitBatchRequests(createMetricsRequests, MAX_BATCH_SIZE, callRpc) { response -> - response.metricsList - }.collect {} + response.metricsList + } + .collect {} val retrievedReport = service.getReport( diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt index bf4eef1b307..adc71611d53 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt @@ -20,6 +20,7 @@ import com.google.common.truth.extensions.proto.ProtoTruth.assertThat import io.grpc.Status import io.grpc.StatusException import kotlin.math.ceil +import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.Before import org.junit.Rule @@ -41,7 +42,6 @@ import org.wfanet.measurement.internal.reporting.v2.ReportingSetsGrpcKt import org.wfanet.measurement.internal.reporting.v2.batchGetReportingSetsRequest import org.wfanet.measurement.internal.reporting.v2.batchGetReportingSetsResponse import org.wfanet.measurement.internal.reporting.v2.reportingSet as internalReportingSet -import kotlinx.coroutines.flow.toList import org.wfanet.measurement.reporting.service.api.v2alpha.ReportingSetsService private const val MEASUREMENT_CONSUMER_ID = "mc_id" @@ -96,11 +96,13 @@ class SubmitBatchRequestsTest { val result = submitBatchRequests( - items, - BATCH_GET_REPORTING_SETS_LIMIT, - ::batchGetReportingSets, - parseResponse, - ).toList().flatten() + items, + BATCH_GET_REPORTING_SETS_LIMIT, + ::batchGetReportingSets, + parseResponse, + ) + .toList() + .flatten() val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() @@ -127,11 +129,13 @@ class SubmitBatchRequestsTest { val result = submitBatchRequests( - items, - BATCH_GET_REPORTING_SETS_LIMIT, - ::batchGetReportingSets, - parseResponse, - ).toList().flatten() + items, + BATCH_GET_REPORTING_SETS_LIMIT, + ::batchGetReportingSets, + parseResponse, + ) + .toList() + .flatten() val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() @@ -150,11 +154,13 @@ class SubmitBatchRequestsTest { val result: List = submitBatchRequests( - emptyList(), - BATCH_GET_REPORTING_SETS_LIMIT, - ::batchGetReportingSets, - parseResponse, - ).toList().flatten() + emptyList(), + BATCH_GET_REPORTING_SETS_LIMIT, + ::batchGetReportingSets, + parseResponse, + ) + .toList() + .flatten() val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() From 2a3392b0d37dc0a8c955533d445397c07ff78aa4 Mon Sep 17 00:00:00 2001 From: Tristan Vuong Date: Mon, 12 Feb 2024 23:36:52 +0000 Subject: [PATCH 06/13] Clarify method comment --- .../reporting/service/api/SubmitBatchRequests.kt | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt index 3abaf303781..2cc1f28505f 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt @@ -49,7 +49,11 @@ fun Flow.chunked(chunkSize: Int): Flow> { } } -/** Submits multiple RPCs by dividing the input items to batches. */ +/** + * Submits multiple RPCs by dividing the input items to batches. + * + * @return [Flow] that emits [List]s containing the results of the multiple RPCs. + */ suspend fun submitBatchRequests( items: Collection, limit: Int, From e013889f37b5e2e576fbaabedccf03c388725476 Mon Sep 17 00:00:00 2001 From: Tristan Vuong Date: Wed, 14 Feb 2024 17:49:31 +0000 Subject: [PATCH 07/13] Use flow as input --- .../service/api/SubmitBatchRequests.kt | 11 +- .../service/api/v2alpha/MetricsService.kt | 2215 ++++++++--------- .../api/v2alpha/ReportSchedulesService.kt | 3 +- .../service/api/v2alpha/ReportsService.kt | 85 +- .../internal/testing/v2/ReportsServiceTest.kt | 3 +- .../service/api/SubmitBatchRequestsTest.kt | 14 +- 6 files changed, 1184 insertions(+), 1147 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt index 2cc1f28505f..61aee48ad3d 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt @@ -55,7 +55,7 @@ fun Flow.chunked(chunkSize: Int): Flow> { * @return [Flow] that emits [List]s containing the results of the multiple RPCs. */ suspend fun submitBatchRequests( - items: Collection, + items: Flow, limit: Int, callRpc: suspend (List) -> RESP, parseResponse: (RESP) -> List, @@ -73,8 +73,13 @@ suspend fun submitBatchRequests( return flow { coroutineScope { val deferred: List>> = - items.chunked(limit).map { batch: List -> - async { batchSemaphore.withPermit { parseResponse(callRpc(batch)) } } + buildList { + items.chunked(limit).collect { batch: List -> + // The batch reference is reused for every collect call. To ensure async works, a copy + // of the contents needs to be saved in a new reference. + val tempBatch = batch.toList() + add(async { batchSemaphore.withPermit { parseResponse(callRpc(tempBatch)) } }) + } } deferred.forEach { emit(it.await()) } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 8cf40689bbe..47997cfe753 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -43,10 +43,16 @@ import kotlin.math.sqrt import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Deferred import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.asExecutor import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.flatMapMerge +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList import kotlinx.coroutines.withContext import org.jetbrains.annotations.BlockingExecutor @@ -152,6 +158,7 @@ import org.wfanet.measurement.measurementconsumer.stats.LiquidLegionsSketchMetho import org.wfanet.measurement.measurementconsumer.stats.LiquidLegionsV2Methodology import org.wfanet.measurement.measurementconsumer.stats.Methodology import org.wfanet.measurement.measurementconsumer.stats.NoiseMechanism as StatsNoiseMechanism +import kotlinx.coroutines.flow.count import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementParams import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementVarianceParams import org.wfanet.measurement.measurementconsumer.stats.ReachMetricVarianceParams @@ -207,134 +214,137 @@ private const val BATCH_SET_MEASUREMENT_RESULTS_LIMIT = 1000 private const val BATCH_SET_MEASUREMENT_FAILURES_LIMIT = 1000 class MetricsService( - private val metricSpecConfig: MetricSpecConfig, - private val internalReportingSetsStub: InternalReportingSetsCoroutineStub, - private val internalMetricsStub: InternalMetricsCoroutineStub, - private val variances: Variances, - internalMeasurementsStub: InternalMeasurementsCoroutineStub, - dataProvidersStub: DataProvidersCoroutineStub, - measurementsStub: MeasurementsCoroutineStub, - certificatesStub: CertificatesCoroutineStub, - measurementConsumersStub: MeasurementConsumersCoroutineStub, - encryptionKeyPairStore: EncryptionKeyPairStore, - secureRandom: SecureRandom, - signingPrivateKeyDir: File, - trustedCertificates: Map, - certificateCacheExpirationDuration: Duration = Duration.ofMinutes(60), - dataProviderCacheExpirationDuration: Duration = Duration.ofMinutes(60), - keyReaderContext: @BlockingExecutor CoroutineContext = Dispatchers.IO, - cacheLoaderContext: @NonBlockingExecutor CoroutineContext = Dispatchers.Default, + private val metricSpecConfig: MetricSpecConfig, + private val internalReportingSetsStub: InternalReportingSetsCoroutineStub, + private val internalMetricsStub: InternalMetricsCoroutineStub, + private val variances: Variances, + internalMeasurementsStub: InternalMeasurementsCoroutineStub, + dataProvidersStub: DataProvidersCoroutineStub, + measurementsStub: MeasurementsCoroutineStub, + certificatesStub: CertificatesCoroutineStub, + measurementConsumersStub: MeasurementConsumersCoroutineStub, + encryptionKeyPairStore: EncryptionKeyPairStore, + secureRandom: SecureRandom, + signingPrivateKeyDir: File, + trustedCertificates: Map, + certificateCacheExpirationDuration: Duration = Duration.ofMinutes(60), + dataProviderCacheExpirationDuration: Duration = Duration.ofMinutes(60), + keyReaderContext: @BlockingExecutor CoroutineContext = Dispatchers.IO, + cacheLoaderContext: @NonBlockingExecutor CoroutineContext = Dispatchers.Default, ) : MetricsCoroutineImplBase() { private data class DataProviderInfo( - val dataProviderName: String, - val publicKey: SignedMessage, - val certificateName: String, + val dataProviderName: String, + val publicKey: SignedMessage, + val certificateName: String, ) private val measurementSupplier = - MeasurementSupplier( - internalReportingSetsStub, - internalMeasurementsStub, - measurementsStub, - dataProvidersStub, - certificatesStub, - measurementConsumersStub, - encryptionKeyPairStore, - secureRandom, - signingPrivateKeyDir, - trustedCertificates, - certificateCacheExpirationDuration = certificateCacheExpirationDuration, - dataProviderCacheExpirationDuration = dataProviderCacheExpirationDuration, - keyReaderContext, - cacheLoaderContext, - ) + MeasurementSupplier( + internalReportingSetsStub, + internalMeasurementsStub, + measurementsStub, + dataProvidersStub, + certificatesStub, + measurementConsumersStub, + encryptionKeyPairStore, + secureRandom, + signingPrivateKeyDir, + trustedCertificates, + certificateCacheExpirationDuration = certificateCacheExpirationDuration, + dataProviderCacheExpirationDuration = dataProviderCacheExpirationDuration, + keyReaderContext, + cacheLoaderContext, + ) private class MeasurementSupplier( - private val internalReportingSetsStub: InternalReportingSetsCoroutineStub, - private val internalMeasurementsStub: InternalMeasurementsCoroutineStub, - private val measurementsStub: MeasurementsCoroutineStub, - private val dataProvidersStub: DataProvidersCoroutineStub, - private val certificatesStub: CertificatesCoroutineStub, - private val measurementConsumersStub: MeasurementConsumersCoroutineStub, - private val encryptionKeyPairStore: EncryptionKeyPairStore, - private val secureRandom: SecureRandom, - private val signingPrivateKeyDir: File, - private val trustedCertificates: Map, - certificateCacheExpirationDuration: Duration, - dataProviderCacheExpirationDuration: Duration, - private val keyReaderContext: @BlockingExecutor CoroutineContext = Dispatchers.IO, - cacheLoaderContext: @NonBlockingExecutor CoroutineContext = Dispatchers.Default, + private val internalReportingSetsStub: InternalReportingSetsCoroutineStub, + private val internalMeasurementsStub: InternalMeasurementsCoroutineStub, + private val measurementsStub: MeasurementsCoroutineStub, + private val dataProvidersStub: DataProvidersCoroutineStub, + private val certificatesStub: CertificatesCoroutineStub, + private val measurementConsumersStub: MeasurementConsumersCoroutineStub, + private val encryptionKeyPairStore: EncryptionKeyPairStore, + private val secureRandom: SecureRandom, + private val signingPrivateKeyDir: File, + private val trustedCertificates: Map, + certificateCacheExpirationDuration: Duration, + dataProviderCacheExpirationDuration: Duration, + private val keyReaderContext: @BlockingExecutor CoroutineContext = Dispatchers.IO, + cacheLoaderContext: @NonBlockingExecutor CoroutineContext = Dispatchers.Default, ) { private data class ResourceNameApiAuthenticationKey( - val name: String, - val apiAuthenticationKey: String, + val name: String, + val apiAuthenticationKey: String, ) private val certificateCache: LoadingCache = - LoadingCache( - Caffeine.newBuilder() - .expireAfterWrite(certificateCacheExpirationDuration) - .executor( - (cacheLoaderContext[ContinuationInterceptor] as CoroutineDispatcher).asExecutor() - ) - .buildAsync() - ) { key -> - getCertificate(name = key.name, apiAuthenticationKey = key.apiAuthenticationKey) - } + LoadingCache( + Caffeine.newBuilder() + .expireAfterWrite(certificateCacheExpirationDuration) + .executor( + (cacheLoaderContext[ContinuationInterceptor] as CoroutineDispatcher) + .asExecutor()) + .buildAsync()) { key -> + getCertificate(name = key.name, apiAuthenticationKey = key.apiAuthenticationKey) + } private val dataProviderCache: LoadingCache = - LoadingCache( - Caffeine.newBuilder() - .expireAfterWrite(dataProviderCacheExpirationDuration) - .executor( - (cacheLoaderContext[ContinuationInterceptor] as CoroutineDispatcher).asExecutor() - ) - .buildAsync() - ) { key -> - getDataProvider(name = key.name, apiAuthenticationKey = key.apiAuthenticationKey) - } + LoadingCache( + Caffeine.newBuilder() + .expireAfterWrite(dataProviderCacheExpirationDuration) + .executor( + (cacheLoaderContext[ContinuationInterceptor] as CoroutineDispatcher) + .asExecutor()) + .buildAsync()) { key -> + getDataProvider(name = key.name, apiAuthenticationKey = key.apiAuthenticationKey) + } /** * Creates CMM public [Measurement]s and [InternalMeasurement]s from a list of [InternalMetric]. */ suspend fun createCmmsMeasurements( - internalMetricsList: List, - principal: MeasurementConsumerPrincipal, + internalMetricsList: List, + principal: MeasurementConsumerPrincipal, ) { val measurementConsumer: MeasurementConsumer = getMeasurementConsumer(principal) // Gets all external IDs of primitive reporting sets from the metric list. - val externalPrimitiveReportingSetIds: List = - internalMetricsList - .flatMap { internalMetric -> - internalMetric.weightedMeasurementsList.flatMap { weightedMeasurement -> - weightedMeasurement.measurement.primitiveReportingSetBasesList.map { - it.externalReportingSetId + val externalPrimitiveReportingSetIds: Flow = flow { + buildSet { + for (internalMetric in internalMetricsList) { + for (weightedMeasurement in internalMetric.weightedMeasurementsList) { + for (primitiveReportingSetBasis in + weightedMeasurement.measurement.primitiveReportingSetBasesList) { + if (!contains(primitiveReportingSetBasis.externalReportingSetId)) { + emit(primitiveReportingSetBasis.externalReportingSetId) + add(primitiveReportingSetBasis.externalReportingSetId) + } } } } - .distinct() + } + } val callBatchGetInternalReportingSetsRpc: - suspend (List) -> BatchGetReportingSetsResponse = - { items -> - batchGetInternalReportingSets(principal.resourceKey.measurementConsumerId, items) - } + suspend (List) -> BatchGetReportingSetsResponse = + { items -> + batchGetInternalReportingSets(principal.resourceKey.measurementConsumerId, items) + } val internalPrimitiveReportingSetMap: Map = buildMap { submitBatchRequests( - externalPrimitiveReportingSetIds, - BATCH_GET_REPORTING_SETS_LIMIT, - callBatchGetInternalReportingSetsRpc, - ) { response: BatchGetReportingSetsResponse -> - response.reportingSetsList - } - .collect { reportingSets: List -> - for (reportingSet in reportingSets) { - computeIfAbsent(reportingSet.externalReportingSetId) { reportingSet } + externalPrimitiveReportingSetIds, + BATCH_GET_REPORTING_SETS_LIMIT, + callBatchGetInternalReportingSetsRpc, + ) { response: BatchGetReportingSetsResponse -> + response.reportingSetsList + } + .collect { reportingSets: List -> + for (reportingSet in reportingSets) { + computeIfAbsent(reportingSet.externalReportingSetId) { reportingSet } + } } - } } val dataProviderNames = mutableSetOf() @@ -344,79 +354,80 @@ class MetricsService( } } val dataProviderInfoMap: Map = - buildDataProviderInfoMap(principal.config.apiKey, dataProviderNames) + buildDataProviderInfoMap(principal.config.apiKey, dataProviderNames) val measurementConsumerSigningKey = getMeasurementConsumerSigningKey(principal) - val cmmsCreateMeasurementRequests: List = - internalMetricsList.flatMap { internalMetric -> - internalMetric.weightedMeasurementsList - .filter { it.measurement.cmmsMeasurementId.isBlank() } - .map { - buildCreateMeasurementRequest( - it.measurement, - internalMetric.metricSpec, - internalPrimitiveReportingSetMap, - measurementConsumer, - principal, - dataProviderInfoMap, - measurementConsumerSigningKey, - ) + val cmmsCreateMeasurementRequests: Flow = flow { + for (internalMetric in internalMetricsList) { + for (weightedMeasurement in internalMetric.weightedMeasurementsList) { + if (weightedMeasurement.measurement.cmmsMeasurementId.isBlank()) { + emit( + buildCreateMeasurementRequest( + weightedMeasurement.measurement, + internalMetric.metricSpec, + internalPrimitiveReportingSetMap, + measurementConsumer, + principal, + dataProviderInfoMap, + measurementConsumerSigningKey, + )) } + } } + } // Create CMMS measurements. val callBatchCreateMeasurementsRpc: - suspend (List) -> BatchCreateMeasurementsResponse = - { items -> - batchCreateCmmsMeasurements(principal, items) - } - - val cmmsMeasurements: List = - submitBatchRequests( - cmmsCreateMeasurementRequests, - BATCH_KINGDOM_MEASUREMENTS_LIMIT, - callBatchCreateMeasurementsRpc, - ) { response: BatchCreateMeasurementsResponse -> - response.measurementsList + suspend (List) -> BatchCreateMeasurementsResponse = + { items -> + batchCreateCmmsMeasurements(principal, items) } - .toList() - .flatten() + + @OptIn(ExperimentalCoroutinesApi::class) + val cmmsMeasurements: Flow = + submitBatchRequests( + cmmsCreateMeasurementRequests, + BATCH_KINGDOM_MEASUREMENTS_LIMIT, + callBatchCreateMeasurementsRpc, + ) { response: BatchCreateMeasurementsResponse -> + response.measurementsList + } + .flatMapMerge { it.asFlow() } // Set CMMS measurement IDs. val callBatchSetCmmsMeasurementIdsRpc: - suspend (List) -> BatchSetCmmsMeasurementIdsResponse = - { items -> - batchSetCmmsMeasurementIds(principal.resourceKey.measurementConsumerId, items) - } + suspend (List) -> BatchSetCmmsMeasurementIdsResponse = + { items -> + batchSetCmmsMeasurementIds(principal.resourceKey.measurementConsumerId, items) + } submitBatchRequests( - cmmsMeasurements.map { - measurementIds { - cmmsCreateMeasurementRequestId = it.measurementReferenceId - cmmsMeasurementId = MeasurementKey.fromName(it.name)!!.measurementId - } - }, - BATCH_SET_CMMS_MEASUREMENT_IDS_LIMIT, - callBatchSetCmmsMeasurementIdsRpc, - ) { response: BatchSetCmmsMeasurementIdsResponse -> - response.measurementsList - } - .collect {} + cmmsMeasurements.map { + measurementIds { + cmmsCreateMeasurementRequestId = it.measurementReferenceId + cmmsMeasurementId = MeasurementKey.fromName(it.name)!!.measurementId + } + }, + BATCH_SET_CMMS_MEASUREMENT_IDS_LIMIT, + callBatchSetCmmsMeasurementIdsRpc, + ) { response: BatchSetCmmsMeasurementIdsResponse -> + response.measurementsList + } + .collect {} } /** Sets a batch of CMMS [MeasurementIds] to the [InternalMeasurement] table. */ private suspend fun batchSetCmmsMeasurementIds( - cmmsMeasurementConsumerId: String, - measurementIds: List, + cmmsMeasurementConsumerId: String, + measurementIds: List, ): BatchSetCmmsMeasurementIdsResponse { return try { internalMeasurementsStub.batchSetCmmsMeasurementIds( - batchSetCmmsMeasurementIdsRequest { - this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId - this.measurementIds += measurementIds - } - ) + batchSetCmmsMeasurementIdsRequest { + this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId + this.measurementIds += measurementIds + }) } catch (e: StatusException) { throw Exception("Unable to set the CMMS measurement IDs for the measurements.", e) } @@ -424,49 +435,49 @@ class MetricsService( /** Batch create CMMS measurements. */ private suspend fun batchCreateCmmsMeasurements( - principal: MeasurementConsumerPrincipal, - createMeasurementRequests: List, + principal: MeasurementConsumerPrincipal, + createMeasurementRequests: List, ): BatchCreateMeasurementsResponse { try { return measurementsStub - .withAuthenticationKey(principal.config.apiKey) - .batchCreateMeasurements( - batchCreateMeasurementsRequest { - parent = principal.resourceKey.toName() - requests += createMeasurementRequests - } - ) + .withAuthenticationKey(principal.config.apiKey) + .batchCreateMeasurements( + batchCreateMeasurementsRequest { + parent = principal.resourceKey.toName() + requests += createMeasurementRequests + }) } catch (e: StatusException) { throw when (e.status.code) { - Status.Code.INVALID_ARGUMENT -> - Status.INVALID_ARGUMENT.withDescription("Required field unspecified or invalid.") - Status.Code.PERMISSION_DENIED -> - Status.PERMISSION_DENIED.withDescription( - "Cannot create CMMS Measurements for another MeasurementConsumer." - ) - Status.Code.FAILED_PRECONDITION -> - Status.FAILED_PRECONDITION.withDescription("Failed precondition.") - Status.Code.NOT_FOUND -> - Status.NOT_FOUND.withDescription("${principal.resourceKey.toName()} is not found.") - else -> Status.UNKNOWN.withDescription("Unable to create CMMS Measurements.") - } - .withCause(e) - .asRuntimeException() + Status.Code.INVALID_ARGUMENT -> + Status.INVALID_ARGUMENT.withDescription("Required field unspecified or invalid.") + Status.Code.PERMISSION_DENIED -> + Status.PERMISSION_DENIED.withDescription( + "Cannot create CMMS Measurements for another MeasurementConsumer.") + Status.Code.FAILED_PRECONDITION -> + Status.FAILED_PRECONDITION.withDescription("Failed precondition.") + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription( + "${principal.resourceKey.toName()} is not found.") + else -> Status.UNKNOWN.withDescription("Unable to create CMMS Measurements.") + } + .withCause(e) + .asRuntimeException() } } /** Builds a CMMS [CreateMeasurementRequest]. */ private fun buildCreateMeasurementRequest( - internalMeasurement: InternalMeasurement, - metricSpec: InternalMetricSpec, - internalPrimitiveReportingSetMap: Map, - measurementConsumer: MeasurementConsumer, - principal: MeasurementConsumerPrincipal, - dataProviderInfoMap: Map, - measurementConsumerSigningKey: SigningKeyHandle, + internalMeasurement: InternalMeasurement, + metricSpec: InternalMetricSpec, + internalPrimitiveReportingSetMap: Map, + measurementConsumer: MeasurementConsumer, + principal: MeasurementConsumerPrincipal, + dataProviderInfoMap: Map, + measurementConsumerSigningKey: SigningKeyHandle, ): CreateMeasurementRequest { val eventGroupEntriesByDataProvider = - groupEventGroupEntriesByDataProvider(internalMeasurement, internalPrimitiveReportingSetMap) + groupEventGroupEntriesByDataProvider( + internalMeasurement, internalPrimitiveReportingSetMap) val packedMeasurementEncryptionPublicKey = measurementConsumer.publicKey.message return createMeasurementRequest { @@ -475,22 +486,22 @@ class MetricsService( measurementConsumerCertificate = principal.config.signingCertificateName dataProviders += - buildDataProviderEntries( - eventGroupEntriesByDataProvider, - packedMeasurementEncryptionPublicKey, - measurementConsumerSigningKey, - dataProviderInfoMap, - ) + buildDataProviderEntries( + eventGroupEntriesByDataProvider, + packedMeasurementEncryptionPublicKey, + measurementConsumerSigningKey, + dataProviderInfoMap, + ) val unsignedMeasurementSpec: MeasurementSpec = - buildUnsignedMeasurementSpec( - packedMeasurementEncryptionPublicKey, - dataProviders.map { it.value.nonceHash }, - metricSpec, - ) + buildUnsignedMeasurementSpec( + packedMeasurementEncryptionPublicKey, + dataProviders.map { it.value.nonceHash }, + metricSpec, + ) measurementSpec = - signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) + signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) // To help map reporting measurements to cmms measurements. measurementReferenceId = internalMeasurement.cmmsCreateMeasurementRequestId } @@ -500,26 +511,26 @@ class MetricsService( /** Gets a [SigningKeyHandle] for a [MeasurementConsumerPrincipal]. */ private suspend fun getMeasurementConsumerSigningKey( - principal: MeasurementConsumerPrincipal + principal: MeasurementConsumerPrincipal ): SigningKeyHandle { // TODO: Factor this out to a separate class similar to EncryptionKeyPairStore. val signingPrivateKeyDer: ByteString = - withContext(keyReaderContext) { - signingPrivateKeyDir.resolve(principal.config.signingPrivateKeyPath).readByteString() - } + withContext(keyReaderContext) { + signingPrivateKeyDir.resolve(principal.config.signingPrivateKeyPath).readByteString() + } val measurementConsumerCertificate: X509Certificate = - readCertificate(getSigningCertificateDer(principal)) + readCertificate(getSigningCertificateDer(principal)) val signingPrivateKey: PrivateKey = - readPrivateKey(signingPrivateKeyDer, measurementConsumerCertificate.publicKey.algorithm) + readPrivateKey(signingPrivateKeyDer, measurementConsumerCertificate.publicKey.algorithm) return SigningKeyHandle(measurementConsumerCertificate, signingPrivateKey) } /** Builds an unsigned [MeasurementSpec]. */ private fun buildUnsignedMeasurementSpec( - packedMeasurementEncryptionPublicKey: ProtoAny, - nonceHashes: List, - metricSpec: InternalMetricSpec, + packedMeasurementEncryptionPublicKey: ProtoAny, + nonceHashes: List, + metricSpec: InternalMetricSpec, ): MeasurementSpec { return measurementSpec { measurementPublicKey = packedMeasurementEncryptionPublicKey @@ -543,9 +554,9 @@ class MetricsService( population = MeasurementSpec.Population.getDefaultInstance() } InternalMetricSpec.TypeCase.TYPE_NOT_SET -> - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "Unset metric type should've already raised error." - } + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "Unset metric type should've already raised error." + } } vidSamplingInterval = metricSpec.vidSamplingInterval.toCmmsVidSamplingInterval() // TODO(@jojijac0b): Add modelLine @@ -554,8 +565,8 @@ class MetricsService( /** Builds a [Map] of [DataProvider] name to [DataProviderInfo]. */ private suspend fun buildDataProviderInfoMap( - apiAuthenticationKey: String, - dataProviderNames: Collection, + apiAuthenticationKey: String, + dataProviderNames: Collection, ): Map { val dataProviderInfoMap = mutableMapOf() @@ -567,55 +578,48 @@ class MetricsService( coroutineScope { for (dataProviderName in dataProviderNames) { deferredDataProviderInfoList.add( - async { - val dataProvider: DataProvider = - dataProviderCache.getValue( - ResourceNameApiAuthenticationKey( - name = dataProviderName, - apiAuthenticationKey = apiAuthenticationKey, - ) - ) - - val certificate = - certificateCache.getValue( - ResourceNameApiAuthenticationKey( - name = dataProvider.certificate, - apiAuthenticationKey = apiAuthenticationKey, - ) - ) - - if ( - certificate.revocationState != - Certificate.RevocationState.REVOCATION_STATE_UNSPECIFIED - ) { - throw Status.FAILED_PRECONDITION.withDescription( - "${certificate.name} revocation state is ${certificate.revocationState}" - ) - .asRuntimeException() - } + async { + val dataProvider: DataProvider = + dataProviderCache.getValue( + ResourceNameApiAuthenticationKey( + name = dataProviderName, + apiAuthenticationKey = apiAuthenticationKey, + )) + + val certificate = + certificateCache.getValue( + ResourceNameApiAuthenticationKey( + name = dataProvider.certificate, + apiAuthenticationKey = apiAuthenticationKey, + )) + + if (certificate.revocationState != + Certificate.RevocationState.REVOCATION_STATE_UNSPECIFIED) { + throw Status.FAILED_PRECONDITION.withDescription( + "${certificate.name} revocation state is ${certificate.revocationState}") + .asRuntimeException() + } - val x509Certificate: X509Certificate = readCertificate(certificate.x509Der) - val trustedIssuer: X509Certificate = - trustedCertificates[checkNotNull(x509Certificate.authorityKeyIdentifier)] - ?: throw Status.FAILED_PRECONDITION.withDescription( - "${certificate.name} not issued by trusted CA" - ) - .asRuntimeException() - try { - verifyEncryptionPublicKey(dataProvider.publicKey, x509Certificate, trustedIssuer) - } catch (e: CertPathValidatorException) { - throw Status.FAILED_PRECONDITION.withCause(e) - .withDescription("Certificate path for ${certificate.name} is invalid") - .asRuntimeException() - } catch (e: SignatureException) { - throw Status.FAILED_PRECONDITION.withCause(e) - .withDescription("DataProvider public key signature is invalid") - .asRuntimeException() - } + val x509Certificate: X509Certificate = readCertificate(certificate.x509Der) + val trustedIssuer: X509Certificate = + trustedCertificates[checkNotNull(x509Certificate.authorityKeyIdentifier)] + ?: throw Status.FAILED_PRECONDITION.withDescription( + "${certificate.name} not issued by trusted CA") + .asRuntimeException() + try { + verifyEncryptionPublicKey(dataProvider.publicKey, x509Certificate, trustedIssuer) + } catch (e: CertPathValidatorException) { + throw Status.FAILED_PRECONDITION.withCause(e) + .withDescription("Certificate path for ${certificate.name} is invalid") + .asRuntimeException() + } catch (e: SignatureException) { + throw Status.FAILED_PRECONDITION.withCause(e) + .withDescription("DataProvider public key signature is invalid") + .asRuntimeException() + } - DataProviderInfo(dataProvider.name, dataProvider.publicKey, certificate.name) - } - ) + DataProviderInfo(dataProvider.name, dataProvider.publicKey, certificate.name) + }) } for (deferredDataProviderInfo in deferredDataProviderInfoList.awaitAll()) { @@ -631,10 +635,10 @@ class MetricsService( * [eventGroupEntriesByDataProvider]. */ private fun buildDataProviderEntries( - eventGroupEntriesByDataProvider: Map>, - packedMeasurementEncryptionPublicKey: ProtoAny, - measurementConsumerSigningKey: SigningKeyHandle, - dataProviderInfoMap: Map, + eventGroupEntriesByDataProvider: Map>, + packedMeasurementEncryptionPublicKey: ProtoAny, + measurementConsumerSigningKey: SigningKeyHandle, + dataProviderInfoMap: Map, ): List { return eventGroupEntriesByDataProvider.map { (dataProviderKey, eventGroupEntriesList) -> val dataProviderName: String = dataProviderKey.toName() @@ -646,20 +650,20 @@ class MetricsService( nonce = secureRandom.nextLong() } val encryptRequisitionSpec = - encryptRequisitionSpec( - signRequisitionSpec(requisitionSpec, measurementConsumerSigningKey), - dataProviderInfo.publicKey.unpack(), - ) + encryptRequisitionSpec( + signRequisitionSpec(requisitionSpec, measurementConsumerSigningKey), + dataProviderInfo.publicKey.unpack(), + ) dataProviderEntry { key = dataProviderName value = - MeasurementKt.DataProviderEntryKt.value { - dataProviderCertificate = dataProviderInfo.certificateName - dataProviderPublicKey = dataProviderInfo.publicKey.message - this.encryptedRequisitionSpec = encryptRequisitionSpec - nonceHash = Hashing.hashSha256(requisitionSpec.nonce) - } + MeasurementKt.DataProviderEntryKt.value { + dataProviderCertificate = dataProviderInfo.certificateName + dataProviderPublicKey = dataProviderInfo.publicKey.message + this.encryptedRequisitionSpec = encryptRequisitionSpec + nonceHash = Hashing.hashSha256(requisitionSpec.nonce) + } } } } @@ -669,42 +673,44 @@ class MetricsService( * grouping them by DataProvider. */ private fun groupEventGroupEntriesByDataProvider( - measurement: InternalMeasurement, - internalPrimitiveReportingSetMap: Map, + measurement: InternalMeasurement, + internalPrimitiveReportingSetMap: Map, ): Map> { return measurement.primitiveReportingSetBasesList - .flatMap { primitiveReportingSetBasis -> - val internalPrimitiveReportingSet = - internalPrimitiveReportingSetMap.getValue( - primitiveReportingSetBasis.externalReportingSetId - ) - - internalPrimitiveReportingSet.primitive.eventGroupKeysList.map { internalEventGroupKey -> - val cmmsEventGroupKey = - CmmsEventGroupKey( - internalEventGroupKey.cmmsDataProviderId, - internalEventGroupKey.cmmsEventGroupId, - ) - val filtersList = primitiveReportingSetBasis.filtersList.filter { !it.isNullOrBlank() } - val filter: String? = if (filtersList.isEmpty()) null else buildConjunction(filtersList) - - cmmsEventGroupKey to - RequisitionSpecKt.eventGroupEntry { - key = cmmsEventGroupKey.toName() - value = - RequisitionSpecKt.EventGroupEntryKt.value { - collectionInterval = measurement.timeInterval - if (filter != null) { - this.filter = RequisitionSpecKt.eventFilter { expression = filter } - } + .flatMap { primitiveReportingSetBasis -> + val internalPrimitiveReportingSet = + internalPrimitiveReportingSetMap.getValue( + primitiveReportingSetBasis.externalReportingSetId) + + internalPrimitiveReportingSet.primitive.eventGroupKeysList.map { internalEventGroupKey + -> + val cmmsEventGroupKey = + CmmsEventGroupKey( + internalEventGroupKey.cmmsDataProviderId, + internalEventGroupKey.cmmsEventGroupId, + ) + val filtersList = + primitiveReportingSetBasis.filtersList.filter { !it.isNullOrBlank() } + val filter: String? = + if (filtersList.isEmpty()) null else buildConjunction(filtersList) + + cmmsEventGroupKey to + RequisitionSpecKt.eventGroupEntry { + key = cmmsEventGroupKey.toName() + value = + RequisitionSpecKt.EventGroupEntryKt.value { + collectionInterval = measurement.timeInterval + if (filter != null) { + this.filter = RequisitionSpecKt.eventFilter { expression = filter } + } + } } - } + } } - } - .groupBy( - { (cmmsEventGroupKey, _) -> DataProviderKey(cmmsEventGroupKey.dataProviderId) }, - { (_, eventGroupEntry) -> eventGroupEntry }, - ) + .groupBy( + { (cmmsEventGroupKey, _) -> DataProviderKey(cmmsEventGroupKey.dataProviderId) }, + { (_, eventGroupEntry) -> eventGroupEntry }, + ) } /** Combines event group filters. */ @@ -714,64 +720,59 @@ class MetricsService( /** Gets a [MeasurementConsumer] based on a CMMS ID. */ private suspend fun getMeasurementConsumer( - principal: MeasurementConsumerPrincipal + principal: MeasurementConsumerPrincipal ): MeasurementConsumer { return try { measurementConsumersStub - .withAuthenticationKey(principal.config.apiKey) - .getMeasurementConsumer( - getMeasurementConsumerRequest { name = principal.resourceKey.toName() } - ) + .withAuthenticationKey(principal.config.apiKey) + .getMeasurementConsumer( + getMeasurementConsumerRequest { name = principal.resourceKey.toName() }) } catch (e: StatusException) { throw when (e.status.code) { - Status.Code.NOT_FOUND -> - Status.NOT_FOUND.withDescription("${principal.resourceKey.toName()} not found.") - else -> - Status.UNKNOWN.withDescription( - "Unable to retrieve the measurement consumer [${principal.resourceKey.toName()}]." - ) - } - .withCause(e) - .asRuntimeException() + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("${principal.resourceKey.toName()} not found.") + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve the measurement consumer [${principal.resourceKey.toName()}].") + } + .withCause(e) + .asRuntimeException() } } /** Gets a batch of [InternalReportingSet]s. */ private suspend fun batchGetInternalReportingSets( - cmmsMeasurementConsumerId: String, - externalReportingSetIds: List, + cmmsMeasurementConsumerId: String, + externalReportingSetIds: List, ): BatchGetReportingSetsResponse { return try { internalReportingSetsStub.batchGetReportingSets( - batchGetReportingSetsRequest { - this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId - this.externalReportingSetIds += externalReportingSetIds - } - ) + batchGetReportingSetsRequest { + this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId + this.externalReportingSetIds += externalReportingSetIds + }) } catch (e: StatusException) { throw when (e.status.code) { - Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription("Reporting Set not found.") - else -> - Status.UNKNOWN.withDescription( - "Unable to retrieve ReportingSets used in the requesting metric." - ) - } - .withCause(e) - .asRuntimeException() + Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription("Reporting Set not found.") + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve ReportingSets used in the requesting metric.") + } + .withCause(e) + .asRuntimeException() } } /** Gets a signing certificate x509Der in ByteString. */ private suspend fun getSigningCertificateDer( - principal: MeasurementConsumerPrincipal + principal: MeasurementConsumerPrincipal ): ByteString { val certificate = - certificateCache.getValue( - ResourceNameApiAuthenticationKey( - name = principal.config.signingCertificateName, - apiAuthenticationKey = principal.config.apiKey, - ) - ) + certificateCache.getValue( + ResourceNameApiAuthenticationKey( + name = principal.config.signingCertificateName, + apiAuthenticationKey = principal.config.apiKey, + )) return certificate.x509Der } @@ -782,69 +783,77 @@ class MetricsService( * @return a boolean to indicate whether any [InternalMeasurement] was updated. */ suspend fun syncInternalMeasurements( - internalMeasurements: List, - apiAuthenticationKey: String, - principal: MeasurementConsumerPrincipal, + internalMeasurements: List, + apiAuthenticationKey: String, + principal: MeasurementConsumerPrincipal, ): Boolean { - val newStateToCmmsMeasurements: Map> = - getCmmsMeasurements(internalMeasurements, principal).groupBy { measurement -> - measurement.state + val failedMeasurements: MutableList = mutableListOf() + + // Most Measurements are expected to be SUCCEEDED so SUCCEEDED Measurements will be collected + // via a Flow. + val succeededMeasurements: Flow = flow { + getCmmsMeasurements(internalMeasurements, principal).collect { measurements -> + for (measurement in measurements) { + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields cannot be null. + when (measurement.state) { + Measurement.State.SUCCEEDED -> emit(measurement) + Measurement.State.CANCELLED, + Measurement.State.FAILED -> failedMeasurements.add(measurement) + Measurement.State.COMPUTING, + Measurement.State.AWAITING_REQUISITION_FULFILLMENT -> {} + Measurement.State.STATE_UNSPECIFIED -> + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "The CMMS measurement state should've been set." + } + Measurement.State.UNRECOGNIZED -> { + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "Unrecognized CMMS measurement state." + } + } + } + } } + } var anyUpdate = false - for ((newState, measurementsList) in newStateToCmmsMeasurements) { - when (newState) { - Measurement.State.SUCCEEDED -> { - val callBatchSetInternalMeasurementResultsRpc: - suspend (List) -> BatchSetCmmsMeasurementResultsResponse = - { items -> - batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal) - } - submitBatchRequests( - measurementsList, - BATCH_SET_MEASUREMENT_RESULTS_LIMIT, - callBatchSetInternalMeasurementResultsRpc, - ) { response: BatchSetCmmsMeasurementResultsResponse -> - response.measurementsList - } - .collect {} - - anyUpdate = true + val callBatchSetInternalMeasurementResultsRpc: + suspend (List) -> BatchSetCmmsMeasurementResultsResponse = + { items -> + batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal) } - Measurement.State.AWAITING_REQUISITION_FULFILLMENT, - Measurement.State.COMPUTING -> {} // Do nothing. - Measurement.State.FAILED, - Measurement.State.CANCELLED -> { - val callBatchSetInternalMeasurementFailuresRpc: - suspend (List) -> BatchSetCmmsMeasurementFailuresResponse = - { items -> - batchSetInternalMeasurementFailures( - items, - principal.resourceKey.measurementConsumerId, - ) - } - submitBatchRequests( - measurementsList, - BATCH_SET_MEASUREMENT_FAILURES_LIMIT, - callBatchSetInternalMeasurementFailuresRpc, - ) { response: BatchSetCmmsMeasurementFailuresResponse -> - response.measurementsList - } - .collect {} - - anyUpdate = true + val count = submitBatchRequests( + succeededMeasurements, + BATCH_SET_MEASUREMENT_RESULTS_LIMIT, + callBatchSetInternalMeasurementResultsRpc, + ) { response: BatchSetCmmsMeasurementResultsResponse -> + response.measurementsList } - Measurement.State.STATE_UNSPECIFIED -> - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "The CMMS measurement state should've been set." - } - Measurement.State.UNRECOGNIZED -> { - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "Unrecognized CMMS measurement state." - } + .count() + + if (count > 0) { + anyUpdate = true + } + + if (failedMeasurements.isNotEmpty()) { + val callBatchSetInternalMeasurementFailuresRpc: + suspend (List) -> BatchSetCmmsMeasurementFailuresResponse = + { items -> + batchSetInternalMeasurementFailures( + items, + principal.resourceKey.measurementConsumerId, + ) } + submitBatchRequests( + failedMeasurements.asFlow(), + BATCH_SET_MEASUREMENT_FAILURES_LIMIT, + callBatchSetInternalMeasurementFailuresRpc, + ) { response: BatchSetCmmsMeasurementFailuresResponse -> + response.measurementsList } + .collect {} + + anyUpdate = true } return anyUpdate @@ -855,24 +864,23 @@ class MetricsService( * failed or canceled CMMS [Measurement]s. */ private suspend fun batchSetInternalMeasurementFailures( - failedMeasurementsList: List, - cmmsMeasurementConsumerId: String, + failedMeasurementsList: List, + cmmsMeasurementConsumerId: String, ): BatchSetCmmsMeasurementFailuresResponse { val batchSetInternalMeasurementFailuresRequest = batchSetMeasurementFailuresRequest { this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId measurementFailures += - failedMeasurementsList.map { measurement -> - measurementFailure { - cmmsMeasurementId = MeasurementKey.fromName(measurement.name)!!.measurementId - failure = measurement.failure.toInternal() + failedMeasurementsList.map { measurement -> + measurementFailure { + cmmsMeasurementId = MeasurementKey.fromName(measurement.name)!!.measurementId + failure = measurement.failure.toInternal() + } } - } } return try { internalMeasurementsStub.batchSetMeasurementFailures( - batchSetInternalMeasurementFailuresRequest - ) + batchSetInternalMeasurementFailuresRequest) } catch (e: StatusException) { throw Exception("Unable to set measurement failures for Measurements.", e) } @@ -883,20 +891,20 @@ class MetricsService( * given succeeded CMMS [Measurement]s. */ private suspend fun batchSetInternalMeasurementResults( - succeededMeasurementsList: List, - apiAuthenticationKey: String, - principal: MeasurementConsumerPrincipal, + succeededMeasurementsList: List, + apiAuthenticationKey: String, + principal: MeasurementConsumerPrincipal, ): BatchSetCmmsMeasurementResultsResponse { val batchSetMeasurementResultsRequest = batchSetMeasurementResultsRequest { cmmsMeasurementConsumerId = principal.resourceKey.measurementConsumerId measurementResults += - succeededMeasurementsList.map { measurement -> - buildInternalMeasurementResult( - measurement, - apiAuthenticationKey, - principal.resourceKey.toName(), - ) - } + succeededMeasurementsList.map { measurement -> + buildInternalMeasurementResult( + measurement, + apiAuthenticationKey, + principal.resourceKey.toName(), + ) + } } return try { @@ -908,135 +916,136 @@ class MetricsService( /** Retrieves [Measurement]s from the CMMS. */ private suspend fun getCmmsMeasurements( - internalMeasurements: List, - principal: MeasurementConsumerPrincipal, - ): List { - val measurementNames: List = - internalMeasurements - .map { internalMeasurement -> - MeasurementKey( - principal.resourceKey.measurementConsumerId, - internalMeasurement.cmmsMeasurementId, - ) - .toName() + internalMeasurements: List, + principal: MeasurementConsumerPrincipal, + ): Flow> { + val measurementNames: Flow = flow { + buildSet { + for (internalMeasurement in internalMeasurements) { + val name = + MeasurementKey( + principal.resourceKey.measurementConsumerId, + internalMeasurement.cmmsMeasurementId, + ) + .toName() + + if (!contains(name)) { + emit(name) + add(name) + } } - .distinct() + } + } val callBatchGetMeasurementsRpc: suspend (List) -> BatchGetMeasurementsResponse = - { items -> - batchGetCmmsMeasurements(principal, items) - } + { items -> + batchGetCmmsMeasurements(principal, items) + } return submitBatchRequests( measurementNames, BATCH_KINGDOM_MEASUREMENTS_LIMIT, callBatchGetMeasurementsRpc, - ) { response: BatchGetMeasurementsResponse -> - response.measurementsList - } - .toList() - .flatten() + ) { response: BatchGetMeasurementsResponse -> + response.measurementsList + } } /** Batch get CMMS measurements. */ private suspend fun batchGetCmmsMeasurements( - principal: MeasurementConsumerPrincipal, - measurementNames: List, + principal: MeasurementConsumerPrincipal, + measurementNames: List, ): BatchGetMeasurementsResponse { try { return measurementsStub - .withAuthenticationKey(principal.config.apiKey) - .batchGetMeasurements( - batchGetMeasurementsRequest { - parent = principal.resourceKey.toName() - names += measurementNames - } - ) + .withAuthenticationKey(principal.config.apiKey) + .batchGetMeasurements( + batchGetMeasurementsRequest { + parent = principal.resourceKey.toName() + names += measurementNames + }) } catch (e: StatusException) { throw when (e.status.code) { - Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription("Measurements not found.") - Status.Code.PERMISSION_DENIED -> - Status.PERMISSION_DENIED.withDescription( - "Doesn't have permission to get measurements." - ) - else -> Status.UNKNOWN.withDescription("Unable to retrieve Measurements.") - } - .withCause(e) - .asRuntimeException() + Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription("Measurements not found.") + Status.Code.PERMISSION_DENIED -> + Status.PERMISSION_DENIED.withDescription( + "Doesn't have permission to get measurements.") + else -> Status.UNKNOWN.withDescription("Unable to retrieve Measurements.") + } + .withCause(e) + .asRuntimeException() } } /** Builds an [InternalMeasurement.Result]. */ private suspend fun buildInternalMeasurementResult( - measurement: Measurement, - apiAuthenticationKey: String, - principalName: String, + measurement: Measurement, + apiAuthenticationKey: String, + principalName: String, ): BatchSetMeasurementResultsRequest.MeasurementResult { val measurementSpec: MeasurementSpec = measurement.measurementSpec.unpack() val encryptionPrivateKeyHandle = - encryptionKeyPairStore.getPrivateKeyHandle( - principalName, - measurementSpec.measurementPublicKey.unpack().data, - ) - ?: failGrpc(Status.FAILED_PRECONDITION) { - "Encryption private key not found for the measurement ${measurement.name}." - } + encryptionKeyPairStore.getPrivateKeyHandle( + principalName, + measurementSpec.measurementPublicKey.unpack().data, + ) + ?: failGrpc(Status.FAILED_PRECONDITION) { + "Encryption private key not found for the measurement ${measurement.name}." + } val decryptedMeasurementResults: List = - measurement.resultsList.map { - decryptMeasurementResultOutput(it, encryptionPrivateKeyHandle, apiAuthenticationKey) - } + measurement.resultsList.map { + decryptMeasurementResultOutput(it, encryptionPrivateKeyHandle, apiAuthenticationKey) + } return measurementResult { cmmsMeasurementId = MeasurementKey.fromName(measurement.name)!!.measurementId results += - decryptedMeasurementResults.map { - try { - it.toInternal(measurement.protocolConfig) - } catch (e: NoiseMechanismUnrecognizedException) { - failGrpc(Status.UNKNOWN) { - listOfNotNull("Unrecognized noise mechanism.", e.message, e.cause?.message) - .joinToString(separator = "\n") - } - } catch (e: Throwable) { - failGrpc(Status.UNKNOWN) { - listOfNotNull("Unable to read measurement result.", e.message, e.cause?.message) - .joinToString(separator = "\n") + decryptedMeasurementResults.map { + try { + it.toInternal(measurement.protocolConfig) + } catch (e: NoiseMechanismUnrecognizedException) { + failGrpc(Status.UNKNOWN) { + listOfNotNull("Unrecognized noise mechanism.", e.message, e.cause?.message) + .joinToString(separator = "\n") + } + } catch (e: Throwable) { + failGrpc(Status.UNKNOWN) { + listOfNotNull("Unable to read measurement result.", e.message, e.cause?.message) + .joinToString(separator = "\n") + } } } - } } } /** Decrypts a [Measurement.ResultOutput] to [Measurement.Result] */ private suspend fun decryptMeasurementResultOutput( - measurementResultOutput: Measurement.ResultOutput, - encryptionPrivateKeyHandle: PrivateKeyHandle, - apiAuthenticationKey: String, + measurementResultOutput: Measurement.ResultOutput, + encryptionPrivateKeyHandle: PrivateKeyHandle, + apiAuthenticationKey: String, ): Measurement.Result { val certificate = - certificateCache.getValue( - ResourceNameApiAuthenticationKey( - name = measurementResultOutput.certificate, - apiAuthenticationKey = apiAuthenticationKey, - ) - ) + certificateCache.getValue( + ResourceNameApiAuthenticationKey( + name = measurementResultOutput.certificate, + apiAuthenticationKey = apiAuthenticationKey, + )) val signedResult = - decryptResult(measurementResultOutput.encryptedResult, encryptionPrivateKeyHandle) + decryptResult(measurementResultOutput.encryptedResult, encryptionPrivateKeyHandle) if (certificate.revocationState != Certificate.RevocationState.REVOCATION_STATE_UNSPECIFIED) { throw Status.FAILED_PRECONDITION.withDescription( - "${certificate.name} revocation state is ${certificate.revocationState}" - ) - .asRuntimeException() + "${certificate.name} revocation state is ${certificate.revocationState}") + .asRuntimeException() } val x509Certificate: X509Certificate = readCertificate(certificate.x509Der) val trustedIssuer: X509Certificate = - checkNotNull(trustedCertificates[checkNotNull(x509Certificate.authorityKeyIdentifier)]) { - "${certificate.name} not issued by trusted CA" - } + checkNotNull(trustedCertificates[checkNotNull(x509Certificate.authorityKeyIdentifier)]) { + "${certificate.name} not issued by trusted CA" + } // TODO: Record verification failure in internal Measurement rather than having the RPC fail. try { @@ -1060,16 +1069,16 @@ class MetricsService( private suspend fun getCertificate(name: String, apiAuthenticationKey: String): Certificate { return try { certificatesStub - .withAuthenticationKey(apiAuthenticationKey) - .getCertificate(getCertificateRequest { this.name = name }) + .withAuthenticationKey(apiAuthenticationKey) + .getCertificate(getCertificateRequest { this.name = name }) } catch (e: StatusException) { throw when (e.status.code) { - Status.Code.NOT_FOUND -> - Status.FAILED_PRECONDITION.withDescription("Certificate $name not found.") - else -> Status.UNKNOWN.withDescription("Unable to retrieve Certificate $name.") - } - .withCause(e) - .asRuntimeException() + Status.Code.NOT_FOUND -> + Status.FAILED_PRECONDITION.withDescription("Certificate $name not found.") + else -> Status.UNKNOWN.withDescription("Unable to retrieve Certificate $name.") + } + .withCause(e) + .asRuntimeException() } } @@ -1084,24 +1093,24 @@ class MetricsService( private suspend fun getDataProvider(name: String, apiAuthenticationKey: String): DataProvider { return try { dataProvidersStub - .withAuthenticationKey(apiAuthenticationKey) - .getDataProvider(getDataProviderRequest { this.name = name }) + .withAuthenticationKey(apiAuthenticationKey) + .getDataProvider(getDataProviderRequest { this.name = name }) } catch (e: StatusException) { throw when (e.status.code) { - Status.Code.NOT_FOUND -> Status.FAILED_PRECONDITION.withDescription("$name not found") - else -> Status.UNKNOWN.withDescription("Unable to retrieve $name") - } - .withCause(e) - .asRuntimeException() + Status.Code.NOT_FOUND -> Status.FAILED_PRECONDITION.withDescription("$name not found") + else -> Status.UNKNOWN.withDescription("Unable to retrieve $name") + } + .withCause(e) + .asRuntimeException() } } } override suspend fun getMetric(request: GetMetricRequest): Metric { val metricKey = - grpcRequireNotNull(MetricKey.fromName(request.name)) { - "Metric name is either unspecified or invalid." - } + grpcRequireNotNull(MetricKey.fromName(request.name)) { + "Metric name is either unspecified or invalid." + } val principal: ReportingPrincipal = principalFromCurrentContext when (principal) { @@ -1115,7 +1124,7 @@ class MetricsService( } val internalMetric: InternalMetric = - getInternalMetric(metricKey.cmmsMeasurementConsumerId, metricKey.metricId) + getInternalMetric(metricKey.cmmsMeasurementConsumerId, metricKey.metricId) // Early exit when the metric is at a terminal state. if (internalMetric.state != Metric.State.RUNNING) { @@ -1124,18 +1133,18 @@ class MetricsService( // Only syncs pending measurements which can only be in metrics that are still running. val toBeSyncedInternalMeasurements: List = - internalMetric.weightedMeasurementsList - .map { weightedMeasurement -> weightedMeasurement.measurement } - .filter { internalMeasurement -> - internalMeasurement.state == InternalMeasurement.State.PENDING - } + internalMetric.weightedMeasurementsList + .map { weightedMeasurement -> weightedMeasurement.measurement } + .filter { internalMeasurement -> + internalMeasurement.state == InternalMeasurement.State.PENDING + } val anyMeasurementUpdated: Boolean = - measurementSupplier.syncInternalMeasurements( - toBeSyncedInternalMeasurements, - principal.config.apiKey, - principal, - ) + measurementSupplier.syncInternalMeasurements( + toBeSyncedInternalMeasurements, + principal.config.apiKey, + principal, + ) return if (anyMeasurementUpdated) { getInternalMetric(metricKey.cmmsMeasurementConsumerId, metricKey.metricId).toMetric(variances) @@ -1146,9 +1155,9 @@ class MetricsService( override suspend fun batchGetMetrics(request: BatchGetMetricsRequest): BatchGetMetricsResponse { val parentKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { - "Parent is either unspecified or invalid." - } + grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { + "Parent is either unspecified or invalid." + } val principal: ReportingPrincipal = principalFromCurrentContext @@ -1168,55 +1177,55 @@ class MetricsService( } val metricIds: List = - request.namesList.map { metricName -> - val metricKey = - grpcRequireNotNull(MetricKey.fromName(metricName)) { - "Metric name is either unspecified or invalid." - } - metricKey.metricId - } + request.namesList.map { metricName -> + val metricKey = + grpcRequireNotNull(MetricKey.fromName(metricName)) { + "Metric name is either unspecified or invalid." + } + metricKey.metricId + } val internalMetrics: List = - batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, metricIds) + batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, metricIds) // Only syncs pending measurements which can only be in metrics that are still running. val toBeSyncedInternalMeasurements: List = - internalMetrics - .filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } - .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } - .map { weightedMeasurement -> weightedMeasurement.measurement } - .filter { internalMeasurement -> - internalMeasurement.state == InternalMeasurement.State.PENDING - } + internalMetrics + .filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } + .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } + .map { weightedMeasurement -> weightedMeasurement.measurement } + .filter { internalMeasurement -> + internalMeasurement.state == InternalMeasurement.State.PENDING + } val anyMeasurementUpdated: Boolean = - measurementSupplier.syncInternalMeasurements( - toBeSyncedInternalMeasurements, - principal.config.apiKey, - principal, - ) + measurementSupplier.syncInternalMeasurements( + toBeSyncedInternalMeasurements, + principal.config.apiKey, + principal, + ) return batchGetMetricsResponse { metrics += - /** - * TODO(@riemanli): a potential improvement can be done by only getting the metrics whose - * measurements are updated. Re-evaluate when a load-test is ready after deployment. - */ - if (anyMeasurementUpdated) { - batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, metricIds).map { - it.toMetric(variances) + /** + * TODO(@riemanli): a potential improvement can be done by only getting the metrics whose + * measurements are updated. Re-evaluate when a load-test is ready after deployment. + */ + if (anyMeasurementUpdated) { + batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, metricIds).map { + it.toMetric(variances) + } + } else { + internalMetrics.map { it.toMetric(variances) } } - } else { - internalMetrics.map { it.toMetric(variances) } - } } } override suspend fun listMetrics(request: ListMetricsRequest): ListMetricsResponse { val parentKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { - "Parent is either unspecified or invalid." - } + grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { + "Parent is either unspecified or invalid." + } val principal: ReportingPrincipal = principalFromCurrentContext when (principal) { @@ -1233,50 +1242,50 @@ class MetricsService( val apiAuthenticationKey: String = principal.config.apiKey val streamInternalMetricRequest: StreamMetricsRequest = - listMetricsPageToken.toStreamMetricsRequest() + listMetricsPageToken.toStreamMetricsRequest() val results: List = - try { - internalMetricsStub.streamMetrics(streamInternalMetricRequest).toList() - } catch (e: StatusException) { - throw Exception("Unable to list Metrics.", e) - } + try { + internalMetricsStub.streamMetrics(streamInternalMetricRequest).toList() + } catch (e: StatusException) { + throw Exception("Unable to list Metrics.", e) + } if (results.isEmpty()) { return ListMetricsResponse.getDefaultInstance() } val nextPageToken: ListMetricsPageToken? = - if (results.size > listMetricsPageToken.pageSize) { - listMetricsPageToken.copy { - lastMetric = previousPageEnd { - cmmsMeasurementConsumerId = results[results.lastIndex - 1].cmmsMeasurementConsumerId - externalMetricId = results[results.lastIndex - 1].externalMetricId + if (results.size > listMetricsPageToken.pageSize) { + listMetricsPageToken.copy { + lastMetric = previousPageEnd { + cmmsMeasurementConsumerId = results[results.lastIndex - 1].cmmsMeasurementConsumerId + externalMetricId = results[results.lastIndex - 1].externalMetricId + } } + } else { + null } - } else { - null - } val subResults: List = - results.subList(0, min(results.size, listMetricsPageToken.pageSize)) + results.subList(0, min(results.size, listMetricsPageToken.pageSize)) // Only syncs pending measurements which can only be in metrics that are still running. val toBeSyncedInternalMeasurements: List = - subResults - .filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } - .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } - .map { weightedMeasurement -> weightedMeasurement.measurement } - .filter { internalMeasurement -> - internalMeasurement.state == InternalMeasurement.State.PENDING - } + subResults + .filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } + .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } + .map { weightedMeasurement -> weightedMeasurement.measurement } + .filter { internalMeasurement -> + internalMeasurement.state == InternalMeasurement.State.PENDING + } val anyMeasurementUpdated: Boolean = - measurementSupplier.syncInternalMeasurements( - toBeSyncedInternalMeasurements, - apiAuthenticationKey, - principal, - ) + measurementSupplier.syncInternalMeasurements( + toBeSyncedInternalMeasurements, + apiAuthenticationKey, + principal, + ) /** * If any measurement got updated, pull the list of the up-to-date internal metrics. Otherwise, @@ -1286,14 +1295,14 @@ class MetricsService( * measurements are updated. Re-evaluate when a load-test is ready after deployment. */ val internalMetrics: List = - if (anyMeasurementUpdated) { - batchGetInternalMetrics( - principal.resourceKey.measurementConsumerId, - subResults.map { internalMetric -> internalMetric.externalMetricId }, - ) - } else { - subResults - } + if (anyMeasurementUpdated) { + batchGetInternalMetrics( + principal.resourceKey.measurementConsumerId, + subResults.map { internalMetric -> internalMetric.externalMetricId }, + ) + } else { + subResults + } return listMetricsResponse { metrics += internalMetrics.map { it.toMetric(variances) } @@ -1306,8 +1315,8 @@ class MetricsService( /** Gets a batch of [InternalMetric]s. */ private suspend fun batchGetInternalMetrics( - cmmsMeasurementConsumerId: String, - metricIds: List, + cmmsMeasurementConsumerId: String, + metricIds: List, ): List { val batchGetMetricsRequest = batchGetMetricsRequest { this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId @@ -1323,8 +1332,8 @@ class MetricsService( /** Gets an [InternalMetric]. */ private suspend fun getInternalMetric( - cmmsMeasurementConsumerId: String, - metricId: String, + cmmsMeasurementConsumerId: String, + metricId: String, ): InternalMetric { return try { batchGetInternalMetrics(cmmsMeasurementConsumerId, listOf(metricId)).first() @@ -1336,9 +1345,9 @@ class MetricsService( override suspend fun createMetric(request: CreateMetricRequest): Metric { val parentKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { - "Parent is either unspecified or invalid." - } + grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { + "Parent is either unspecified or invalid." + } val principal: ReportingPrincipal = principalFromCurrentContext @@ -1353,28 +1362,26 @@ class MetricsService( } val internalCreateMetricRequest: InternalCreateMetricRequest = - buildInternalCreateMetricRequest(principal.resourceKey.measurementConsumerId, request) + buildInternalCreateMetricRequest(principal.resourceKey.measurementConsumerId, request) val internalMetric = - try { - internalMetricsStub.createMetric(internalCreateMetricRequest) - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.ALREADY_EXISTS -> - Status.ALREADY_EXISTS.withDescription( - "Metric with ID ${request.metricId} already exists under ${request.parent}" - ) - Status.Code.NOT_FOUND -> - Status.NOT_FOUND.withDescription("Reporting set used in the metric not found.") - Status.Code.FAILED_PRECONDITION -> - Status.FAILED_PRECONDITION.withDescription( - "Unable to create the metric. The measurement consumer not found." - ) - else -> Status.UNKNOWN.withDescription("Unable to create Metric.") - } - .withCause(e) - .asRuntimeException() - } + try { + internalMetricsStub.createMetric(internalCreateMetricRequest) + } catch (e: StatusException) { + throw when (e.status.code) { + Status.Code.ALREADY_EXISTS -> + Status.ALREADY_EXISTS.withDescription( + "Metric with ID ${request.metricId} already exists under ${request.parent}") + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("Reporting set used in the metric not found.") + Status.Code.FAILED_PRECONDITION -> + Status.FAILED_PRECONDITION.withDescription( + "Unable to create the metric. The measurement consumer not found.") + else -> Status.UNKNOWN.withDescription("Unable to create Metric.") + } + .withCause(e) + .asRuntimeException() + } if (internalMetric.state == Metric.State.RUNNING) { measurementSupplier.createCmmsMeasurements(listOf(internalMetric), principal) @@ -1385,12 +1392,12 @@ class MetricsService( } override suspend fun batchCreateMetrics( - request: BatchCreateMetricsRequest + request: BatchCreateMetricsRequest ): BatchCreateMetricsResponse { val parentKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { - "Parent is either unspecified or invalid." - } + grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { + "Parent is either unspecified or invalid." + } val principal: ReportingPrincipal = principalFromCurrentContext @@ -1415,36 +1422,34 @@ class MetricsService( } val internalCreateMetricRequestsList: List = - request.requestsList.map { createMetricRequest -> - buildInternalCreateMetricRequest(parentKey.measurementConsumerId, createMetricRequest) - } + request.requestsList.map { createMetricRequest -> + buildInternalCreateMetricRequest(parentKey.measurementConsumerId, createMetricRequest) + } val internalMetrics = - try { - internalMetricsStub - .batchCreateMetrics( - internalBatchCreateMetricsRequest { - cmmsMeasurementConsumerId = parentKey.measurementConsumerId - requests += internalCreateMetricRequestsList - } - ) - .metricsList - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.NOT_FOUND -> - Status.NOT_FOUND.withDescription("Reporting set used in metrics not found.") - Status.Code.FAILED_PRECONDITION -> - Status.FAILED_PRECONDITION.withDescription( - "Unable to create the metrics. The measurement consumer not found." - ) - else -> Status.UNKNOWN.withDescription("Unable to create Metrics.") - } - .withCause(e) - .asRuntimeException() - } + try { + internalMetricsStub + .batchCreateMetrics( + internalBatchCreateMetricsRequest { + cmmsMeasurementConsumerId = parentKey.measurementConsumerId + requests += internalCreateMetricRequestsList + }) + .metricsList + } catch (e: StatusException) { + throw when (e.status.code) { + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("Reporting set used in metrics not found.") + Status.Code.FAILED_PRECONDITION -> + Status.FAILED_PRECONDITION.withDescription( + "Unable to create the metrics. The measurement consumer not found.") + else -> Status.UNKNOWN.withDescription("Unable to create Metrics.") + } + .withCause(e) + .asRuntimeException() + } val internalRunningMetrics = - internalMetrics.filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } + internalMetrics.filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } if (internalRunningMetrics.isNotEmpty()) { measurementSupplier.createCmmsMeasurements(internalRunningMetrics, principal) } @@ -1455,8 +1460,8 @@ class MetricsService( /** Builds an [InternalCreateMetricRequest]. */ private suspend fun buildInternalCreateMetricRequest( - cmmsMeasurementConsumerId: String, - request: CreateMetricRequest, + cmmsMeasurementConsumerId: String, + request: CreateMetricRequest, ): InternalCreateMetricRequest { grpcRequire(request.hasMetric()) { "Metric is not specified." } @@ -1466,34 +1471,31 @@ class MetricsService( } grpcRequire(request.metric.hasTimeInterval()) { "Time interval in metric is not specified." } grpcRequire( - request.metric.timeInterval.startTime.seconds > 0 || - request.metric.timeInterval.startTime.nanos > 0 - ) { - "TimeInterval startTime is unspecified." - } + request.metric.timeInterval.startTime.seconds > 0 || + request.metric.timeInterval.startTime.nanos > 0) { + "TimeInterval startTime is unspecified." + } grpcRequire( - request.metric.timeInterval.endTime.seconds > 0 || - request.metric.timeInterval.endTime.nanos > 0 - ) { - "TimeInterval endTime is unspecified." - } + request.metric.timeInterval.endTime.seconds > 0 || + request.metric.timeInterval.endTime.nanos > 0) { + "TimeInterval endTime is unspecified." + } grpcRequire( - request.metric.timeInterval.endTime.seconds > request.metric.timeInterval.startTime.seconds || - request.metric.timeInterval.endTime.nanos > request.metric.timeInterval.startTime.nanos - ) { - "TimeInterval endTime is not later than startTime." - } + request.metric.timeInterval.endTime.seconds > + request.metric.timeInterval.startTime.seconds || + request.metric.timeInterval.endTime.nanos > + request.metric.timeInterval.startTime.nanos) { + "TimeInterval endTime is not later than startTime." + } grpcRequire(request.metric.hasMetricSpec()) { "Metric spec in metric is not specified." } val internalReportingSet: InternalReportingSet = - getInternalReportingSet(cmmsMeasurementConsumerId, request.metric.reportingSet) + getInternalReportingSet(cmmsMeasurementConsumerId, request.metric.reportingSet) // Utilizes the property of the set expression compilation result -- If the set expression // contains only union operators, the compilation result has to be a single component. - if ( - request.metric.metricSpec.hasReachAndFrequency() && - internalReportingSet.weightedSubsetUnionsList.size != 1 - ) { + if (request.metric.metricSpec.hasReachAndFrequency() && + internalReportingSet.weightedSubsetUnionsList.size != 1) { failGrpc(Status.INVALID_ARGUMENT) { "Reach-and-frequency metrics can only be computed on union-only set expressions." } @@ -1507,22 +1509,22 @@ class MetricsService( externalReportingSetId = internalReportingSet.externalReportingSetId timeInterval = request.metric.timeInterval metricSpec = - try { - request.metric.metricSpec.withDefaults(metricSpecConfig).toInternal() - } catch (e: MetricSpecDefaultsException) { - failGrpc(Status.INVALID_ARGUMENT) { - listOfNotNull("Invalid metric spec.", e.message, e.cause?.message) - .joinToString(separator = "\n") + try { + request.metric.metricSpec.withDefaults(metricSpecConfig).toInternal() + } catch (e: MetricSpecDefaultsException) { + failGrpc(Status.INVALID_ARGUMENT) { + listOfNotNull("Invalid metric spec.", e.message, e.cause?.message) + .joinToString(separator = "\n") + } + } catch (e: Exception) { + failGrpc(Status.UNKNOWN) { "Failed to read the metric spec." } } - } catch (e: Exception) { - failGrpc(Status.UNKNOWN) { "Failed to read the metric spec." } - } weightedMeasurements += - buildInitialInternalMeasurements( - cmmsMeasurementConsumerId, - request.metric, - internalReportingSet, - ) + buildInitialInternalMeasurements( + cmmsMeasurementConsumerId, + request.metric, + internalReportingSet, + ) details = InternalMetricKt.details { filters += request.metric.filtersList } } } @@ -1530,9 +1532,9 @@ class MetricsService( /** Builds [InternalMeasurement]s for a [Metric] over an [InternalReportingSet]. */ private fun buildInitialInternalMeasurements( - cmmsMeasurementConsumerId: String, - metric: Metric, - internalReportingSet: InternalReportingSet, + cmmsMeasurementConsumerId: String, + metric: Metric, + internalReportingSet: InternalReportingSet, ): List { return internalReportingSet.weightedSubsetUnionsList.map { weightedSubsetUnion -> weightedMeasurement { @@ -1542,9 +1544,9 @@ class MetricsService( this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId timeInterval = metric.timeInterval this.primitiveReportingSetBases += - weightedSubsetUnion.primitiveReportingSetBasesList.map { primitiveReportingSetBasis -> - primitiveReportingSetBasis.copy { filters += metric.filtersList } - } + weightedSubsetUnion.primitiveReportingSetBasesList.map { primitiveReportingSetBasis -> + primitiveReportingSetBasis.copy { filters += metric.filtersList } + } } } } @@ -1552,13 +1554,13 @@ class MetricsService( /** Gets an [InternalReportingSet] based on a reporting set name. */ private suspend fun getInternalReportingSet( - cmmsMeasurementConsumerId: String, - reportingSetName: String, + cmmsMeasurementConsumerId: String, + reportingSetName: String, ): InternalReportingSet { val reportingSetKey = - grpcRequireNotNull(ReportingSetKey.fromName(reportingSetName)) { - "Invalid reporting set name $reportingSetName." - } + grpcRequireNotNull(ReportingSetKey.fromName(reportingSetName)) { + "Invalid reporting set name $reportingSetName." + } if (reportingSetKey.cmmsMeasurementConsumerId != cmmsMeasurementConsumerId) { failGrpc(Status.PERMISSION_DENIED) { "No access to the reporting set [$reportingSetName]." } @@ -1566,18 +1568,17 @@ class MetricsService( return try { internalReportingSetsStub - .batchGetReportingSets( - batchGetReportingSetsRequest { - this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId - this.externalReportingSetIds += reportingSetKey.reportingSetId - } - ) - .reportingSetsList - .first() + .batchGetReportingSets( + batchGetReportingSetsRequest { + this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId + this.externalReportingSetIds += reportingSetKey.reportingSetId + }) + .reportingSetsList + .first() } catch (e: StatusException) { throw Exception( - "Unable to retrieve ReportingSet using the provided name [$reportingSetName].", - e, + "Unable to retrieve ReportingSet using the provided name [$reportingSetName].", + e, ) } } @@ -1594,9 +1595,9 @@ fun ListMetricsRequest.toListMetricsPageToken(): ListMetricsPageToken { grpcRequire(source.pageSize >= 0) { "Page size cannot be less than 0." } val parentKey: MeasurementConsumerKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(source.parent)) { - "Parent is either unspecified or invalid." - } + grpcRequireNotNull(MeasurementConsumerKey.fromName(source.parent)) { + "Parent is either unspecified or invalid." + } val cmmsMeasurementConsumerId = parentKey.measurementConsumerId return if (pageToken.isNotBlank()) { @@ -1612,11 +1613,11 @@ fun ListMetricsRequest.toListMetricsPageToken(): ListMetricsPageToken { } else { listMetricsPageToken { pageSize = - when { - source.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - source.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> source.pageSize - } + when { + source.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE + source.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE + else -> source.pageSize + } this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId } } @@ -1627,13 +1628,13 @@ private fun InternalMetric.toMetric(variances: Variances): Metric { val source = this return metric { name = - MetricKey( - cmmsMeasurementConsumerId = source.cmmsMeasurementConsumerId, - metricId = source.externalMetricId, - ) - .toName() + MetricKey( + cmmsMeasurementConsumerId = source.cmmsMeasurementConsumerId, + metricId = source.externalMetricId, + ) + .toName() reportingSet = - ReportingSetKey(source.cmmsMeasurementConsumerId, source.externalReportingSetId).toName() + ReportingSetKey(source.cmmsMeasurementConsumerId, source.externalReportingSetId).toName() timeInterval = source.timeInterval metricSpec = source.metricSpec.toMetricSpec() filters += source.details.filtersList @@ -1649,49 +1650,50 @@ private fun InternalMetric.toMetric(variances: Variances): Metric { private fun buildMetricResult(metric: InternalMetric, variances: Variances): MetricResult { return metricResult { cmmsMeasurements += - metric.weightedMeasurementsList.map { - MeasurementKey(metric.cmmsMeasurementConsumerId, it.measurement.cmmsMeasurementId).toName() - } + metric.weightedMeasurementsList.map { + MeasurementKey(metric.cmmsMeasurementConsumerId, it.measurement.cmmsMeasurementId) + .toName() + } @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. when (metric.metricSpec.typeCase) { InternalMetricSpec.TypeCase.REACH -> { reach = - calculateReachResult( - metric.weightedMeasurementsList, - metric.metricSpec.vidSamplingInterval, - metric.metricSpec.reach.privacyParams, - variances, - ) + calculateReachResult( + metric.weightedMeasurementsList, + metric.metricSpec.vidSamplingInterval, + metric.metricSpec.reach.privacyParams, + variances, + ) } InternalMetricSpec.TypeCase.REACH_AND_FREQUENCY -> { reachAndFrequency = reachAndFrequencyResult { reach = - calculateReachResult( - metric.weightedMeasurementsList, - metric.metricSpec.vidSamplingInterval, - metric.metricSpec.reachAndFrequency.reachPrivacyParams, - variances, - ) + calculateReachResult( + metric.weightedMeasurementsList, + metric.metricSpec.vidSamplingInterval, + metric.metricSpec.reachAndFrequency.reachPrivacyParams, + variances, + ) frequencyHistogram = - calculateFrequencyHistogramResults( - metric.weightedMeasurementsList, - metric.metricSpec, - variances, - ) + calculateFrequencyHistogramResults( + metric.weightedMeasurementsList, + metric.metricSpec, + variances, + ) } } InternalMetricSpec.TypeCase.IMPRESSION_COUNT -> { impressionCount = - calculateImpressionResult(metric.weightedMeasurementsList, metric.metricSpec, variances) + calculateImpressionResult(metric.weightedMeasurementsList, metric.metricSpec, variances) } InternalMetricSpec.TypeCase.WATCH_DURATION -> { watchDuration = - calculateWatchDurationResult( - metric.weightedMeasurementsList, - metric.metricSpec, - variances, - ) + calculateWatchDurationResult( + metric.weightedMeasurementsList, + metric.metricSpec, + variances, + ) } InternalMetricSpec.TypeCase.POPULATION_COUNT -> { populationCount = calculatePopulationResult(metric.weightedMeasurementsList) @@ -1707,7 +1709,7 @@ private fun buildMetricResult(metric: InternalMetric, variances: Variances): Met /** Aggregates a list of [InternalMeasurement.Result]s to a [InternalMeasurement.Result] */ private fun aggregateResults( - internalMeasurementResults: List + internalMeasurementResults: List ): InternalMeasurement.Result { if (internalMeasurementResults.isEmpty()) { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -1733,10 +1735,10 @@ private fun aggregateResults( } for ((frequency, percentage) in result.frequency.relativeFrequencyDistributionMap) { val previousTotalReachCount = - frequencyDistribution.getOrDefault(frequency, 0.0) * reachValue + frequencyDistribution.getOrDefault(frequency, 0.0) * reachValue val currentReachCount = percentage * result.reach.value frequencyDistribution[frequency] = - (previousTotalReachCount + currentReachCount) / (reachValue + result.reach.value) + (previousTotalReachCount + currentReachCount) / (reachValue + result.reach.value) } } if (result.hasReach()) { @@ -1759,16 +1761,16 @@ private fun aggregateResults( } if (internalMeasurementResults.first().hasFrequency()) { this.frequency = - InternalMeasurementKt.ResultKt.frequency { - relativeFrequencyDistribution.putAll(frequencyDistribution) - } + InternalMeasurementKt.ResultKt.frequency { + relativeFrequencyDistribution.putAll(frequencyDistribution) + } } if (internalMeasurementResults.first().hasImpression()) { this.impression = InternalMeasurementKt.ResultKt.impression { value = impressionValue } } if (internalMeasurementResults.first().hasWatchDuration()) { this.watchDuration = - InternalMeasurementKt.ResultKt.watchDuration { value = watchDurationValue } + InternalMeasurementKt.ResultKt.watchDuration { value = watchDurationValue } } if (internalMeasurementResults.first().hasPopulation()) { this.population = InternalMeasurementKt.ResultKt.population { value = populationValue } @@ -1778,9 +1780,9 @@ private fun aggregateResults( /** Calculates the watch duration result from [WeightedMeasurement]s. */ private fun calculateWatchDurationResult( - weightedMeasurements: List, - metricSpec: InternalMetricSpec, - variances: Variances, + weightedMeasurements: List, + metricSpec: InternalMetricSpec, + variances: Variances, ): MetricResult.WatchDurationResult { for (weightedMeasurement in weightedMeasurements) { if (weightedMeasurement.measurement.details.resultsList.any { !it.hasWatchDuration() }) { @@ -1791,24 +1793,24 @@ private fun calculateWatchDurationResult( } return watchDurationResult { val watchDuration: ProtoDuration = - weightedMeasurements - .map { weightedMeasurement -> - aggregateResults(weightedMeasurement.measurement.details.resultsList) - .watchDuration - .value * weightedMeasurement.weight - } - .reduce { sum, element -> sum + element } + weightedMeasurements + .map { weightedMeasurement -> + aggregateResults(weightedMeasurement.measurement.details.resultsList) + .watchDuration + .value * weightedMeasurement.weight + } + .reduce { sum, element -> sum + element } value = watchDuration.toDoubleSecond() // Only compute univariate statistics for union-only operations, i.e. single source measurement. if (weightedMeasurements.size == 1) { val weightedMeasurement = weightedMeasurements.first() val weightedMeasurementVarianceParamsList: - List = - buildWeightedWatchDurationMeasurementVarianceParamsPerResult( - weightedMeasurement, - metricSpec, - ) + List = + buildWeightedWatchDurationMeasurementVarianceParamsPerResult( + weightedMeasurement, + metricSpec, + ) // If any measurement result contains insufficient data for variance calculation, univariate // statistics won't be computed. @@ -1817,26 +1819,23 @@ private fun calculateWatchDurationResult( // Watch duration results in a measurement are independent to each other. The variance is // the sum of the variances of each result. standardDeviation = - sqrt( - weightedMeasurementVarianceParamsList.sumOf { weightedMeasurementVarianceParams -> - try { - variances.computeMetricVariance( - WatchDurationMetricVarianceParams( - listOf(requireNotNull(weightedMeasurementVarianceParams)) - ) - ) - } catch (e: Throwable) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unable to compute variance of watch duration metric.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") - } - } - } - ) + sqrt( + weightedMeasurementVarianceParamsList.sumOf { weightedMeasurementVarianceParams -> + try { + variances.computeMetricVariance( + WatchDurationMetricVarianceParams( + listOf(requireNotNull(weightedMeasurementVarianceParams)))) + } catch (e: Throwable) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unable to compute variance of watch duration metric.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") + } + } + }) } } } @@ -1845,11 +1844,11 @@ private fun calculateWatchDurationResult( /** Calculates the population result from [WeightedMeasurement]s. */ private fun calculatePopulationResult( - weightedMeasurements: List + weightedMeasurements: List ): MetricResult.PopulationCountResult { // Only take the first measurement because Population measurements will only have one element. val populationResult = - aggregateResults(weightedMeasurements.single().measurement.details.resultsList) + aggregateResults(weightedMeasurements.single().measurement.details.resultsList) return populationCountResult { value = populationResult.population.value } } @@ -1865,11 +1864,11 @@ private fun ProtoDuration.toDoubleSecond(): Double { * @throws io.grpc.StatusRuntimeException when measurement noise mechanism is unrecognized. */ fun buildWeightedWatchDurationMeasurementVarianceParamsPerResult( - weightedMeasurement: WeightedMeasurement, - metricSpec: MetricSpec, + weightedMeasurement: WeightedMeasurement, + metricSpec: MetricSpec, ): List { val watchDurationResults: List = - weightedMeasurement.measurement.details.resultsList.map { it.watchDuration } + weightedMeasurement.measurement.details.resultsList.map { it.watchDuration } if (watchDurationResults.isEmpty()) { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -1879,51 +1878,52 @@ fun buildWeightedWatchDurationMeasurementVarianceParamsPerResult( return watchDurationResults.map { watchDurationResult -> val statsNoiseMechanism: StatsNoiseMechanism = - try { - watchDurationResult.noiseMechanism.toStatsNoiseMechanism() - } catch (e: NoiseMechanismUnspecifiedException) { - return@map null - } catch (e: NoiseMechanismUnrecognizedException) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unrecognized noise mechanism should've been caught earlier.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") + try { + watchDurationResult.noiseMechanism.toStatsNoiseMechanism() + } catch (e: NoiseMechanismUnspecifiedException) { + return@map null + } catch (e: NoiseMechanismUnrecognizedException) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unrecognized noise mechanism should've been caught earlier.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") + } } - } val methodology: Methodology = - try { - buildStatsMethodology(watchDurationResult) - } catch (e: MeasurementVarianceNotComputableException) { - return@map null - } + try { + buildStatsMethodology(watchDurationResult) + } catch (e: MeasurementVarianceNotComputableException) { + return@map null + } WeightedWatchDurationMeasurementVarianceParams( - binaryRepresentation = weightedMeasurement.binaryRepresentation, - weight = weightedMeasurement.weight, - measurementVarianceParams = - WatchDurationMeasurementVarianceParams( - duration = max(0.0, watchDurationResult.value.toDoubleSecond()), - measurementParams = - WatchDurationMeasurementParams( - vidSamplingInterval = metricSpec.vidSamplingInterval.toStatsVidSamplingInterval(), - dpParams = metricSpec.watchDuration.privacyParams.toNoiserDpParams(), - maximumDurationPerUser = - metricSpec.watchDuration.maximumWatchDurationPerUser.toDoubleSecond(), - noiseMechanism = statsNoiseMechanism, + binaryRepresentation = weightedMeasurement.binaryRepresentation, + weight = weightedMeasurement.weight, + measurementVarianceParams = + WatchDurationMeasurementVarianceParams( + duration = max(0.0, watchDurationResult.value.toDoubleSecond()), + measurementParams = + WatchDurationMeasurementParams( + vidSamplingInterval = + metricSpec.vidSamplingInterval.toStatsVidSamplingInterval(), + dpParams = metricSpec.watchDuration.privacyParams.toNoiserDpParams(), + maximumDurationPerUser = + metricSpec.watchDuration.maximumWatchDurationPerUser.toDoubleSecond(), + noiseMechanism = statsNoiseMechanism, + ), ), - ), - methodology = methodology, + methodology = methodology, ) } } /** Builds a [Methodology] from an [InternalMeasurement.Result.WatchDuration]. */ fun buildStatsMethodology( - watchDurationResult: InternalMeasurement.Result.WatchDuration + watchDurationResult: InternalMeasurement.Result.WatchDuration ): Methodology { @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") return when (watchDurationResult.methodologyCase) { @@ -1940,8 +1940,7 @@ fun buildStatsMethodology( } CustomDirectMethodology.Variance.TypeCase.UNAVAILABLE -> { throw MeasurementVarianceNotComputableException( - "Watch duration computed from a custom methodology doesn't have variance." - ) + "Watch duration computed from a custom methodology doesn't have variance.") } CustomDirectMethodology.Variance.TypeCase.TYPE_NOT_SET -> { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -1961,9 +1960,9 @@ fun buildStatsMethodology( /** Calculates the impression result from [WeightedMeasurement]s. */ private fun calculateImpressionResult( - weightedMeasurements: List, - metricSpec: InternalMetricSpec, - variances: Variances, + weightedMeasurements: List, + metricSpec: InternalMetricSpec, + variances: Variances, ): MetricResult.ImpressionCountResult { for (weightedMeasurement in weightedMeasurements) { if (weightedMeasurement.measurement.details.resultsList.any { !it.hasImpression() }) { @@ -1975,17 +1974,17 @@ private fun calculateImpressionResult( return impressionCountResult { value = - weightedMeasurements.sumOf { weightedMeasurement -> - aggregateResults(weightedMeasurement.measurement.details.resultsList).impression.value * - weightedMeasurement.weight - } + weightedMeasurements.sumOf { weightedMeasurement -> + aggregateResults(weightedMeasurement.measurement.details.resultsList).impression.value * + weightedMeasurement.weight + } // Only compute univariate statistics for union-only operations, i.e. single source measurement. if (weightedMeasurements.size == 1) { val weightedMeasurement = weightedMeasurements.first() val weightedMeasurementVarianceParamsList: - List = - buildWeightedImpressionMeasurementVarianceParamsPerResult(weightedMeasurement, metricSpec) + List = + buildWeightedImpressionMeasurementVarianceParamsPerResult(weightedMeasurement, metricSpec) // If any measurement result contains insufficient data for variance calculation, univariate // statistics won't be computed. @@ -1994,26 +1993,23 @@ private fun calculateImpressionResult( // Impression results in a measurement are independent to each other. The variance is the // sum of the variances of each result. standardDeviation = - sqrt( - weightedMeasurementVarianceParamsList.sumOf { weightedMeasurementVarianceParams -> - try { - variances.computeMetricVariance( - ImpressionMetricVarianceParams( - listOf(requireNotNull(weightedMeasurementVarianceParams)) - ) - ) - } catch (e: Throwable) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unable to compute variance of impression metric.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") - } - } - } - ) + sqrt( + weightedMeasurementVarianceParamsList.sumOf { weightedMeasurementVarianceParams -> + try { + variances.computeMetricVariance( + ImpressionMetricVarianceParams( + listOf(requireNotNull(weightedMeasurementVarianceParams)))) + } catch (e: Throwable) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unable to compute variance of impression metric.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") + } + } + }) } } } @@ -2026,11 +2022,11 @@ private fun calculateImpressionResult( * @throws io.grpc.StatusRuntimeException when measurement noise mechanism is unrecognized. */ fun buildWeightedImpressionMeasurementVarianceParamsPerResult( - weightedMeasurement: WeightedMeasurement, - metricSpec: MetricSpec, + weightedMeasurement: WeightedMeasurement, + metricSpec: MetricSpec, ): List { val impressionResults: List = - weightedMeasurement.measurement.details.resultsList.map { it.impression } + weightedMeasurement.measurement.details.resultsList.map { it.impression } if (impressionResults.isEmpty()) { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -2040,43 +2036,45 @@ fun buildWeightedImpressionMeasurementVarianceParamsPerResult( return impressionResults.map { impressionResult -> val statsNoiseMechanism: StatsNoiseMechanism = - try { - impressionResult.noiseMechanism.toStatsNoiseMechanism() - } catch (e: NoiseMechanismUnspecifiedException) { - return@map null - } catch (e: NoiseMechanismUnrecognizedException) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unrecognized noise mechanism should've been caught earlier.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") + try { + impressionResult.noiseMechanism.toStatsNoiseMechanism() + } catch (e: NoiseMechanismUnspecifiedException) { + return@map null + } catch (e: NoiseMechanismUnrecognizedException) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unrecognized noise mechanism should've been caught earlier.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") + } } - } val methodology: Methodology = - try { - buildStatsMethodology(impressionResult) - } catch (e: MeasurementVarianceNotComputableException) { - return@map null - } + try { + buildStatsMethodology(impressionResult) + } catch (e: MeasurementVarianceNotComputableException) { + return@map null + } WeightedImpressionMeasurementVarianceParams( - binaryRepresentation = weightedMeasurement.binaryRepresentation, - weight = weightedMeasurement.weight, - measurementVarianceParams = - ImpressionMeasurementVarianceParams( - impression = max(0L, impressionResult.value), - measurementParams = - ImpressionMeasurementParams( - vidSamplingInterval = metricSpec.vidSamplingInterval.toStatsVidSamplingInterval(), - dpParams = metricSpec.impressionCount.privacyParams.toNoiserDpParams(), - maximumFrequencyPerUser = metricSpec.impressionCount.maximumFrequencyPerUser, - noiseMechanism = statsNoiseMechanism, + binaryRepresentation = weightedMeasurement.binaryRepresentation, + weight = weightedMeasurement.weight, + measurementVarianceParams = + ImpressionMeasurementVarianceParams( + impression = max(0L, impressionResult.value), + measurementParams = + ImpressionMeasurementParams( + vidSamplingInterval = + metricSpec.vidSamplingInterval.toStatsVidSamplingInterval(), + dpParams = metricSpec.impressionCount.privacyParams.toNoiserDpParams(), + maximumFrequencyPerUser = + metricSpec.impressionCount.maximumFrequencyPerUser, + noiseMechanism = statsNoiseMechanism, + ), ), - ), - methodology = methodology, + methodology = methodology, ) } } @@ -2098,8 +2096,7 @@ fun buildStatsMethodology(impressionResult: InternalMeasurement.Result.Impressio } CustomDirectMethodology.Variance.TypeCase.UNAVAILABLE -> { throw MeasurementVarianceNotComputableException( - "Impression computed from a custom methodology doesn't have variance." - ) + "Impression computed from a custom methodology doesn't have variance.") } CustomDirectMethodology.Variance.TypeCase.TYPE_NOT_SET -> { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -2119,37 +2116,35 @@ fun buildStatsMethodology(impressionResult: InternalMeasurement.Result.Impressio /** Calculates the frequency histogram result from [WeightedMeasurement]s. */ private fun calculateFrequencyHistogramResults( - weightedMeasurements: List, - metricSpec: InternalMetricSpec, - variances: Variances, + weightedMeasurements: List, + metricSpec: InternalMetricSpec, + variances: Variances, ): MetricResult.HistogramResult { val aggregatedFrequencyHistogramMap: MutableMap = - weightedMeasurements - .map { weightedMeasurement -> - if ( - weightedMeasurement.measurement.details.resultsList.any { - !it.hasReach() || !it.hasFrequency() + weightedMeasurements + .map { weightedMeasurement -> + if (weightedMeasurement.measurement.details.resultsList.any { + !it.hasReach() || !it.hasFrequency() + }) { + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "Reach-Frequency measurement is missing." + } + } + val result = aggregateResults(weightedMeasurement.measurement.details.resultsList) + val reach = result.reach.value + result.frequency.relativeFrequencyDistributionMap.mapValues { (_, rate) -> + rate * weightedMeasurement.weight * reach + } } - ) { - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "Reach-Frequency measurement is missing." + .fold(mutableMapOf().withDefault { 0.0 }) { + aggregatedFrequencyHistogramMap: MutableMap, + weightedFrequencyHistogramMap -> + for ((frequency, count) in weightedFrequencyHistogramMap) { + aggregatedFrequencyHistogramMap[frequency] = + aggregatedFrequencyHistogramMap.getValue(frequency) + count + } + aggregatedFrequencyHistogramMap } - } - val result = aggregateResults(weightedMeasurement.measurement.details.resultsList) - val reach = result.reach.value - result.frequency.relativeFrequencyDistributionMap.mapValues { (_, rate) -> - rate * weightedMeasurement.weight * reach - } - } - .fold(mutableMapOf().withDefault { 0.0 }) { - aggregatedFrequencyHistogramMap: MutableMap, - weightedFrequencyHistogramMap -> - for ((frequency, count) in weightedFrequencyHistogramMap) { - aggregatedFrequencyHistogramMap[frequency] = - aggregatedFrequencyHistogramMap.getValue(frequency) + count - } - aggregatedFrequencyHistogramMap - } // Fill the buckets that don't have any count with zeros. for (frequency in (1L..metricSpec.reachAndFrequency.maximumFrequency)) { @@ -2159,56 +2154,55 @@ private fun calculateFrequencyHistogramResults( } val weightedMeasurementVarianceParamsList: List = - weightedMeasurements.mapNotNull { weightedMeasurement -> - buildWeightedFrequencyMeasurementVarianceParams(weightedMeasurement, metricSpec, variances) - } + weightedMeasurements.mapNotNull { weightedMeasurement -> + buildWeightedFrequencyMeasurementVarianceParams(weightedMeasurement, metricSpec, variances) + } val frequencyVariances: FrequencyVariances? = - if (weightedMeasurementVarianceParamsList.size == weightedMeasurements.size) { - try { - variances.computeMetricVariance( - FrequencyMetricVarianceParams(weightedMeasurementVarianceParamsList) - ) - } catch (e: Throwable) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unable to compute variance of reach-frequency metric.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") + if (weightedMeasurementVarianceParamsList.size == weightedMeasurements.size) { + try { + variances.computeMetricVariance( + FrequencyMetricVarianceParams(weightedMeasurementVarianceParamsList)) + } catch (e: Throwable) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unable to compute variance of reach-frequency metric.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") + } } + } else { + null } - } else { - null - } return histogramResult { bins += - aggregatedFrequencyHistogramMap.map { (frequency, count) -> - bin { - label = frequency.toString() - binResult = binResult { value = count } - if (frequencyVariances != null) { - resultUnivariateStatistics = univariateStatistics { - standardDeviation = - sqrt(frequencyVariances.countVariances.getValue(frequency.toInt())) - } - relativeUnivariateStatistics = univariateStatistics { - standardDeviation = - sqrt(frequencyVariances.relativeVariances.getValue(frequency.toInt())) - } - kPlusUnivariateStatistics = univariateStatistics { - standardDeviation = - sqrt(frequencyVariances.kPlusCountVariances.getValue(frequency.toInt())) - } - relativeKPlusUnivariateStatistics = univariateStatistics { - standardDeviation = - sqrt(frequencyVariances.kPlusRelativeVariances.getValue(frequency.toInt())) + aggregatedFrequencyHistogramMap.map { (frequency, count) -> + bin { + label = frequency.toString() + binResult = binResult { value = count } + if (frequencyVariances != null) { + resultUnivariateStatistics = univariateStatistics { + standardDeviation = + sqrt(frequencyVariances.countVariances.getValue(frequency.toInt())) + } + relativeUnivariateStatistics = univariateStatistics { + standardDeviation = + sqrt(frequencyVariances.relativeVariances.getValue(frequency.toInt())) + } + kPlusUnivariateStatistics = univariateStatistics { + standardDeviation = + sqrt(frequencyVariances.kPlusCountVariances.getValue(frequency.toInt())) + } + relativeKPlusUnivariateStatistics = univariateStatistics { + standardDeviation = + sqrt(frequencyVariances.kPlusRelativeVariances.getValue(frequency.toInt())) + } } } } - } } } @@ -2220,81 +2214,83 @@ private fun calculateFrequencyHistogramResults( * @throws io.grpc.StatusRuntimeException when measurement noise mechanism is unrecognized. */ fun buildWeightedFrequencyMeasurementVarianceParams( - weightedMeasurement: WeightedMeasurement, - metricSpec: MetricSpec, - variances: Variances, + weightedMeasurement: WeightedMeasurement, + metricSpec: MetricSpec, + variances: Variances, ): WeightedFrequencyMeasurementVarianceParams? { // Get reach measurement variance params val weightedReachMeasurementVarianceParams: WeightedReachMeasurementVarianceParams = - buildWeightedReachMeasurementVarianceParams( - weightedMeasurement, - metricSpec.vidSamplingInterval, - metricSpec.reachAndFrequency.reachPrivacyParams, - ) ?: return null + buildWeightedReachMeasurementVarianceParams( + weightedMeasurement, + metricSpec.vidSamplingInterval, + metricSpec.reachAndFrequency.reachPrivacyParams, + ) ?: return null val reachMeasurementVariance: Double = - variances.computeMeasurementVariance( - weightedReachMeasurementVarianceParams.methodology, - ReachMeasurementVarianceParams( - weightedReachMeasurementVarianceParams.measurementVarianceParams.reach, - weightedReachMeasurementVarianceParams.measurementVarianceParams.measurementParams, - ), - ) + variances.computeMeasurementVariance( + weightedReachMeasurementVarianceParams.methodology, + ReachMeasurementVarianceParams( + weightedReachMeasurementVarianceParams.measurementVarianceParams.reach, + weightedReachMeasurementVarianceParams.measurementVarianceParams.measurementParams, + ), + ) val frequencyResult: InternalMeasurement.Result.Frequency = - if (weightedMeasurement.measurement.details.resultsList.size == 1) { - weightedMeasurement.measurement.details.resultsList.first().frequency - } else if (weightedMeasurement.measurement.details.resultsList.size > 1) { - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "No supported methodology generates more than one frequency result." - } - } else { - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "Frequency measurement should've had frequency results." + if (weightedMeasurement.measurement.details.resultsList.size == 1) { + weightedMeasurement.measurement.details.resultsList.first().frequency + } else if (weightedMeasurement.measurement.details.resultsList.size > 1) { + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "No supported methodology generates more than one frequency result." + } + } else { + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "Frequency measurement should've had frequency results." + } } - } val frequencyStatsNoiseMechanism: StatsNoiseMechanism = - try { - frequencyResult.noiseMechanism.toStatsNoiseMechanism() - } catch (e: NoiseMechanismUnspecifiedException) { - return null - } catch (e: NoiseMechanismUnrecognizedException) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unrecognized noise mechanism should've been caught earlier.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") + try { + frequencyResult.noiseMechanism.toStatsNoiseMechanism() + } catch (e: NoiseMechanismUnspecifiedException) { + return null + } catch (e: NoiseMechanismUnrecognizedException) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unrecognized noise mechanism should've been caught earlier.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") + } } - } val frequencyMethodology: Methodology = - try { - buildStatsMethodology(frequencyResult) - } catch (e: MeasurementVarianceNotComputableException) { - return null - } + try { + buildStatsMethodology(frequencyResult) + } catch (e: MeasurementVarianceNotComputableException) { + return null + } return WeightedFrequencyMeasurementVarianceParams( - binaryRepresentation = weightedMeasurement.binaryRepresentation, - weight = weightedMeasurement.weight, - measurementVarianceParams = - FrequencyMeasurementVarianceParams( - totalReach = weightedReachMeasurementVarianceParams.measurementVarianceParams.reach, - reachMeasurementVariance = reachMeasurementVariance, - relativeFrequencyDistribution = - frequencyResult.relativeFrequencyDistributionMap.mapKeys { it.key.toInt() }, - measurementParams = - FrequencyMeasurementParams( - vidSamplingInterval = metricSpec.vidSamplingInterval.toStatsVidSamplingInterval(), - dpParams = metricSpec.reachAndFrequency.frequencyPrivacyParams.toNoiserDpParams(), - noiseMechanism = frequencyStatsNoiseMechanism, - maximumFrequency = metricSpec.reachAndFrequency.maximumFrequency, + binaryRepresentation = weightedMeasurement.binaryRepresentation, + weight = weightedMeasurement.weight, + measurementVarianceParams = + FrequencyMeasurementVarianceParams( + totalReach = weightedReachMeasurementVarianceParams.measurementVarianceParams.reach, + reachMeasurementVariance = reachMeasurementVariance, + relativeFrequencyDistribution = + frequencyResult.relativeFrequencyDistributionMap.mapKeys { it.key.toInt() }, + measurementParams = + FrequencyMeasurementParams( + vidSamplingInterval = + metricSpec.vidSamplingInterval.toStatsVidSamplingInterval(), + dpParams = + metricSpec.reachAndFrequency.frequencyPrivacyParams.toNoiserDpParams(), + noiseMechanism = frequencyStatsNoiseMechanism, + maximumFrequency = metricSpec.reachAndFrequency.maximumFrequency, + ), ), - ), - methodology = frequencyMethodology, + methodology = frequencyMethodology, ) } @@ -2312,18 +2308,17 @@ fun buildStatsMethodology(frequencyResult: InternalMeasurement.Result.Frequency) } CustomDirectMethodology.Variance.TypeCase.FREQUENCY -> { CustomDirectFrequencyMethodology( - frequencyResult.customDirectMethodology.variance.frequency.variancesMap.mapKeys { - it.key.toInt() - }, - frequencyResult.customDirectMethodology.variance.frequency.kPlusVariancesMap.mapKeys { - it.key.toInt() - }, + frequencyResult.customDirectMethodology.variance.frequency.variancesMap.mapKeys { + it.key.toInt() + }, + frequencyResult.customDirectMethodology.variance.frequency.kPlusVariancesMap.mapKeys { + it.key.toInt() + }, ) } CustomDirectMethodology.Variance.TypeCase.UNAVAILABLE -> { throw MeasurementVarianceNotComputableException( - "Frequency computed from a custom methodology doesn't have variance." - ) + "Frequency computed from a custom methodology doesn't have variance.") } CustomDirectMethodology.Variance.TypeCase.TYPE_NOT_SET -> { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -2337,15 +2332,16 @@ fun buildStatsMethodology(frequencyResult: InternalMeasurement.Result.Frequency) } InternalMeasurement.Result.Frequency.MethodologyCase.LIQUID_LEGIONS_DISTRIBUTION -> { LiquidLegionsSketchMethodology( - decayRate = frequencyResult.liquidLegionsDistribution.decayRate, - sketchSize = frequencyResult.liquidLegionsDistribution.maxSize, + decayRate = frequencyResult.liquidLegionsDistribution.decayRate, + sketchSize = frequencyResult.liquidLegionsDistribution.maxSize, ) } InternalMeasurement.Result.Frequency.MethodologyCase.LIQUID_LEGIONS_V2 -> { LiquidLegionsV2Methodology( - decayRate = frequencyResult.liquidLegionsV2.sketchParams.decayRate, - sketchSize = frequencyResult.liquidLegionsV2.sketchParams.maxSize, - samplingIndicatorSize = frequencyResult.liquidLegionsV2.sketchParams.samplingIndicatorSize, + decayRate = frequencyResult.liquidLegionsV2.sketchParams.decayRate, + sketchSize = frequencyResult.liquidLegionsV2.sketchParams.maxSize, + samplingIndicatorSize = + frequencyResult.liquidLegionsV2.sketchParams.samplingIndicatorSize, ) } InternalMeasurement.Result.Frequency.MethodologyCase.METHODOLOGY_NOT_SET -> { @@ -2356,10 +2352,10 @@ fun buildStatsMethodology(frequencyResult: InternalMeasurement.Result.Frequency) /** Calculates the reach result from [WeightedMeasurement]s. */ private fun calculateReachResult( - weightedMeasurements: List, - vidSamplingInterval: InternalMetricSpec.VidSamplingInterval, - privacyParams: InternalMetricSpec.DifferentialPrivacyParams, - variances: Variances, + weightedMeasurements: List, + vidSamplingInterval: InternalMetricSpec.VidSamplingInterval, + privacyParams: InternalMetricSpec.DifferentialPrivacyParams, + variances: Variances, ): MetricResult.ReachResult { for (weightedMeasurement in weightedMeasurements) { if (weightedMeasurement.measurement.details.resultsList.any { !it.hasReach() }) { @@ -2371,41 +2367,39 @@ private fun calculateReachResult( return reachResult { value = - weightedMeasurements.sumOf { weightedMeasurement -> - aggregateResults(weightedMeasurement.measurement.details.resultsList).reach.value * - weightedMeasurement.weight - } + weightedMeasurements.sumOf { weightedMeasurement -> + aggregateResults(weightedMeasurement.measurement.details.resultsList).reach.value * + weightedMeasurement.weight + } val weightedMeasurementVarianceParamsList: List = - weightedMeasurements.mapNotNull { weightedMeasurement -> - buildWeightedReachMeasurementVarianceParams( - weightedMeasurement, - vidSamplingInterval, - privacyParams, - ) - } + weightedMeasurements.mapNotNull { weightedMeasurement -> + buildWeightedReachMeasurementVarianceParams( + weightedMeasurement, + vidSamplingInterval, + privacyParams, + ) + } // If any measurement contains insufficient data for variance calculation, univariate statistics // won't be computed. if (weightedMeasurementVarianceParamsList.size == weightedMeasurements.size) { univariateStatistics = univariateStatistics { standardDeviation = - sqrt( - try { - variances.computeMetricVariance( - ReachMetricVarianceParams(weightedMeasurementVarianceParamsList) - ) - } catch (e: Throwable) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unable to compute variance of reach metric.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") - } - } - ) + sqrt( + try { + variances.computeMetricVariance( + ReachMetricVarianceParams(weightedMeasurementVarianceParamsList)) + } catch (e: Throwable) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unable to compute variance of reach metric.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") + } + }) } } } @@ -2419,60 +2413,60 @@ private fun calculateReachResult( * @throws io.grpc.StatusRuntimeException when measurement noise mechanism is unrecognized. */ private fun buildWeightedReachMeasurementVarianceParams( - weightedMeasurement: WeightedMeasurement, - vidSamplingInterval: InternalMetricSpec.VidSamplingInterval, - privacyParams: InternalMetricSpec.DifferentialPrivacyParams, + weightedMeasurement: WeightedMeasurement, + vidSamplingInterval: InternalMetricSpec.VidSamplingInterval, + privacyParams: InternalMetricSpec.DifferentialPrivacyParams, ): WeightedReachMeasurementVarianceParams? { val reachResult = - if (weightedMeasurement.measurement.details.resultsList.size == 1) { - weightedMeasurement.measurement.details.resultsList.first().reach - } else if (weightedMeasurement.measurement.details.resultsList.size > 1) { - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "No supported methodology generates more than one reach result." - } - } else { - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "Reach measurement should've had reach results." + if (weightedMeasurement.measurement.details.resultsList.size == 1) { + weightedMeasurement.measurement.details.resultsList.first().reach + } else if (weightedMeasurement.measurement.details.resultsList.size > 1) { + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "No supported methodology generates more than one reach result." + } + } else { + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "Reach measurement should've had reach results." + } } - } val statsNoiseMechanism: StatsNoiseMechanism = - try { - reachResult.noiseMechanism.toStatsNoiseMechanism() - } catch (e: NoiseMechanismUnspecifiedException) { - return null - } catch (e: NoiseMechanismUnrecognizedException) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unrecognized noise mechanism should've been caught earlier.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") + try { + reachResult.noiseMechanism.toStatsNoiseMechanism() + } catch (e: NoiseMechanismUnspecifiedException) { + return null + } catch (e: NoiseMechanismUnrecognizedException) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unrecognized noise mechanism should've been caught earlier.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") + } } - } val methodology: Methodology = - try { - buildStatsMethodology(reachResult) - } catch (e: MeasurementVarianceNotComputableException) { - return null - } + try { + buildStatsMethodology(reachResult) + } catch (e: MeasurementVarianceNotComputableException) { + return null + } return WeightedReachMeasurementVarianceParams( - binaryRepresentation = weightedMeasurement.binaryRepresentation, - weight = weightedMeasurement.weight, - measurementVarianceParams = - ReachMeasurementVarianceParams( - reach = max(0L, reachResult.value), - measurementParams = - ReachMeasurementParams( - vidSamplingInterval = vidSamplingInterval.toStatsVidSamplingInterval(), - dpParams = privacyParams.toNoiserDpParams(), - noiseMechanism = statsNoiseMechanism, + binaryRepresentation = weightedMeasurement.binaryRepresentation, + weight = weightedMeasurement.weight, + measurementVarianceParams = + ReachMeasurementVarianceParams( + reach = max(0L, reachResult.value), + measurementParams = + ReachMeasurementParams( + vidSamplingInterval = vidSamplingInterval.toStatsVidSamplingInterval(), + dpParams = privacyParams.toNoiserDpParams(), + noiseMechanism = statsNoiseMechanism, + ), ), - ), - methodology = methodology, + methodology = methodology, ) } @@ -2493,8 +2487,7 @@ fun buildStatsMethodology(reachResult: InternalMeasurement.Result.Reach): Method } CustomDirectMethodology.Variance.TypeCase.UNAVAILABLE -> { throw MeasurementVarianceNotComputableException( - "Reach computed from a custom methodology doesn't have variance." - ) + "Reach computed from a custom methodology doesn't have variance.") } CustomDirectMethodology.Variance.TypeCase.TYPE_NOT_SET -> { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -2508,22 +2501,22 @@ fun buildStatsMethodology(reachResult: InternalMeasurement.Result.Reach): Method } InternalMeasurement.Result.Reach.MethodologyCase.LIQUID_LEGIONS_COUNT_DISTINCT -> { LiquidLegionsSketchMethodology( - decayRate = reachResult.liquidLegionsCountDistinct.decayRate, - sketchSize = reachResult.liquidLegionsCountDistinct.maxSize, + decayRate = reachResult.liquidLegionsCountDistinct.decayRate, + sketchSize = reachResult.liquidLegionsCountDistinct.maxSize, ) } InternalMeasurement.Result.Reach.MethodologyCase.LIQUID_LEGIONS_V2 -> { LiquidLegionsV2Methodology( - decayRate = reachResult.liquidLegionsV2.sketchParams.decayRate, - sketchSize = reachResult.liquidLegionsV2.sketchParams.maxSize, - samplingIndicatorSize = reachResult.liquidLegionsV2.sketchParams.samplingIndicatorSize, + decayRate = reachResult.liquidLegionsV2.sketchParams.decayRate, + sketchSize = reachResult.liquidLegionsV2.sketchParams.maxSize, + samplingIndicatorSize = reachResult.liquidLegionsV2.sketchParams.samplingIndicatorSize, ) } InternalMeasurement.Result.Reach.MethodologyCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> { LiquidLegionsV2Methodology( - decayRate = reachResult.reachOnlyLiquidLegionsV2.sketchParams.decayRate, - sketchSize = reachResult.reachOnlyLiquidLegionsV2.sketchParams.maxSize, - samplingIndicatorSize = 0L, + decayRate = reachResult.reachOnlyLiquidLegionsV2.sketchParams.decayRate, + sketchSize = reachResult.reachOnlyLiquidLegionsV2.sketchParams.maxSize, + samplingIndicatorSize = 0L, ) } InternalMeasurement.Result.Reach.MethodologyCase.METHODOLOGY_NOT_SET -> { @@ -2536,7 +2529,7 @@ private operator fun ProtoDuration.times(weight: Int): ProtoDuration { val source = this return duration { val weightedTotalNanos: Long = - (TimeUnit.SECONDS.toNanos(source.seconds) + source.nanos) * weight + (TimeUnit.SECONDS.toNanos(source.seconds) + source.nanos) * weight seconds = TimeUnit.NANOSECONDS.toSeconds(weightedTotalNanos) nanos = (weightedTotalNanos % NANOS_PER_SECOND).toInt() } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt index 710be54f309..4eb8d64dfab 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt @@ -67,6 +67,7 @@ import org.wfanet.measurement.internal.reporting.v2.getReportScheduleRequest import org.wfanet.measurement.internal.reporting.v2.listReportSchedulesRequest import org.wfanet.measurement.internal.reporting.v2.report as internalReport import org.wfanet.measurement.internal.reporting.v2.reportSchedule as internalReportSchedule +import kotlinx.coroutines.flow.asFlow import org.wfanet.measurement.internal.reporting.v2.stopReportScheduleRequest import org.wfanet.measurement.reporting.service.api.submitBatchRequests import org.wfanet.measurement.reporting.v2alpha.CreateReportScheduleRequest @@ -672,7 +673,7 @@ class ReportSchedulesService( while (externalReportingSetIdSet.isNotEmpty()) { retrievedExternalReportingSetIdSet.addAll(externalReportingSetIdSet) - submitBatchRequests(externalReportingSetIdSet, BATCH_GET_REPORTING_SETS_LIMIT, callRpc) { + submitBatchRequests(externalReportingSetIdSet.asFlow(), BATCH_GET_REPORTING_SETS_LIMIT, callRpc) { response -> externalReportingSetIdSet.clear() response.reportingSetsList diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt index 78832999473..b677f4f1778 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt @@ -58,6 +58,8 @@ import org.wfanet.measurement.internal.reporting.v2.batchGetMetricCalculationSpe import org.wfanet.measurement.internal.reporting.v2.createReportRequest as internalCreateReportRequest import org.wfanet.measurement.internal.reporting.v2.getReportRequest as internalGetReportRequest import org.wfanet.measurement.internal.reporting.v2.report as internalReport +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow import org.wfanet.measurement.reporting.service.api.submitBatchRequests import org.wfanet.measurement.reporting.service.api.v2alpha.MetadataPrincipalServerInterceptor.Companion.withPrincipalName import org.wfanet.measurement.reporting.service.api.v2alpha.ReportScheduleInfoServerInterceptor.Companion.reportScheduleInfoFromCurrentContext @@ -162,8 +164,28 @@ class ReportsService( results.subList(0, min(results.size, listReportsPageToken.pageSize)) // Get metrics. - val metricNames: List = - subResults.flatMap { internalReport -> internalReport.metricNames }.distinct() + val metricNames: Flow = + flow { + buildSet { + for (internalReport in subResults) { + for (reportingMetricEntry in internalReport.reportingMetricEntriesMap) { + for (metricCalculationSpecReportingMetrics in reportingMetricEntry.value.metricCalculationSpecReportingMetricsList) { + for (reportingMetric in metricCalculationSpecReportingMetrics.reportingMetricsList) { + val name = MetricKey( + internalReport.cmmsMeasurementConsumerId, + reportingMetric.externalMetricId + ).toName() + + if (!contains(name)) { + emit(name) + add(name) + } + } + } + } + } + } + } val callRpc: suspend (List) -> BatchGetMetricsResponse = { items -> batchGetMetrics(principal.resourceKey.toName(), items) @@ -232,7 +254,26 @@ class ReportsService( } // Get metrics. - val metricNames: List = internalReport.metricNames.distinct() + val metricNames: Flow = + flow { + buildSet { + for (reportingMetricEntry in internalReport.reportingMetricEntriesMap) { + for (metricCalculationSpecReportingMetrics in reportingMetricEntry.value.metricCalculationSpecReportingMetricsList) { + for (reportingMetric in metricCalculationSpecReportingMetrics.reportingMetricsList) { + val name = MetricKey( + internalReport.cmmsMeasurementConsumerId, + reportingMetric.externalMetricId + ).toName() + + if (!contains(name)) { + emit(name) + add(name) + } + } + } + } + } + } val callRpc: suspend (List) -> BatchGetMetricsResponse = { items -> batchGetMetrics(principal.resourceKey.toName(), items) @@ -316,7 +357,6 @@ class ReportsService( key.metricCalculationSpecId } } - .distinct() val externalIdToMetricCalculationSpecMap: Map = createExternalIdToMetricCalculationSpecMap( @@ -369,20 +409,22 @@ class ReportsService( } // Create metrics. - val createMetricRequests: List = - internalReport.reportingMetricEntriesMap.flatMap { - (reportingSetId, reportingMetricCalculationSpec) -> - reportingMetricCalculationSpec.metricCalculationSpecReportingMetricsList.flatMap { - metricCalculationSpecReportingMetrics -> - metricCalculationSpecReportingMetrics.reportingMetricsList.map { - it.toCreateMetricRequest( - principal.resourceKey, - reportingSetId, - externalIdToMetricCalculationSpecMap - .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) - .details - .filter, - ) + val createMetricRequests: Flow = + flow { + internalReport.reportingMetricEntriesMap.flatMap { (reportingSetId, reportingMetricCalculationSpec) -> + reportingMetricCalculationSpec.metricCalculationSpecReportingMetricsList.flatMap { metricCalculationSpecReportingMetrics -> + metricCalculationSpecReportingMetrics.reportingMetricsList.map { + emit( + it.toCreateMetricRequest( + principal.resourceKey, + reportingSetId, + externalIdToMetricCalculationSpecMap + .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) + .details + .filter, + ) + ) + } } } } @@ -494,7 +536,6 @@ class ReportsService( it.externalMetricCalculationSpecId } } - .distinct() val externalIdToMetricCalculationMap: Map = createExternalIdToMetricCalculationSpecMap( @@ -848,12 +889,6 @@ private val InternalReport.externalMetricIds: List } } -private val InternalReport.metricNames: List - get() = - externalMetricIds.map { externalMetricId -> - MetricKey(cmmsMeasurementConsumerId, externalMetricId).toName() - } - /** Converts a public [ListReportsRequest] to a [ListReportsPageToken]. */ private fun ListReportsRequest.toListReportsPageToken(): ListReportsPageToken { grpcRequire(pageSize >= 0) { "Page size cannot be less than 0" } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt index d4ce867f432..1a2b6e415c9 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt @@ -30,6 +30,7 @@ import io.grpc.StatusRuntimeException import java.time.Clock import kotlin.random.Random import kotlin.test.assertFailsWith +import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.Before @@ -666,7 +667,7 @@ abstract class ReportsServiceTest { } ) } - submitBatchRequests(createMetricsRequests, MAX_BATCH_SIZE, callRpc) { response -> + submitBatchRequests(createMetricsRequests.asFlow(), MAX_BATCH_SIZE, callRpc) { response -> response.metricsList } .collect {} diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt index adc71611d53..f369519cf3c 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt @@ -20,6 +20,8 @@ import com.google.common.truth.extensions.proto.ProtoTruth.assertThat import io.grpc.Status import io.grpc.StatusException import kotlin.math.ceil +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.emptyFlow import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.Before @@ -87,7 +89,7 @@ class SubmitBatchRequestsTest { ceil(INTERNAL_PRIMITIVE_REPORTING_SETS.size / BATCH_GET_REPORTING_SETS_LIMIT.toFloat()) .toInt() - val items = INTERNAL_PRIMITIVE_REPORTING_SETS.map { it.externalReportingSetId } + val items = INTERNAL_PRIMITIVE_REPORTING_SETS.map { it.externalReportingSetId }.asFlow() val parseResponse: (BatchGetReportingSetsResponse) -> List = { response -> @@ -120,7 +122,7 @@ class SubmitBatchRequestsTest { val expectedReportingSets = INTERNAL_PRIMITIVE_REPORTING_SETS.subList(0, numberTargetReportingSet) val expectedNumberBatches = 1 - val items = expectedReportingSets.map { it.externalReportingSetId } + val items = expectedReportingSets.map { it.externalReportingSetId }.asFlow() val parseResponse: (BatchGetReportingSetsResponse) -> List = { response -> @@ -154,10 +156,10 @@ class SubmitBatchRequestsTest { val result: List = submitBatchRequests( - emptyList(), - BATCH_GET_REPORTING_SETS_LIMIT, - ::batchGetReportingSets, - parseResponse, + emptyFlow(), + BATCH_GET_REPORTING_SETS_LIMIT, + ::batchGetReportingSets, + parseResponse, ) .toList() .flatten() From c4856cd1c7aea84ce2d4c897c534803b88fdf259 Mon Sep 17 00:00:00 2001 From: Tristan Vuong Date: Wed, 14 Feb 2024 17:53:34 +0000 Subject: [PATCH 08/13] lint fix --- .../service/api/SubmitBatchRequests.kt | 15 +- .../service/api/v2alpha/MetricsService.kt | 2089 +++++++++-------- .../api/v2alpha/ReportSchedulesService.kt | 9 +- .../service/api/v2alpha/ReportsService.kt | 123 +- .../service/api/SubmitBatchRequestsTest.kt | 8 +- 5 files changed, 1141 insertions(+), 1103 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt index 61aee48ad3d..c270a7a67d3 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt @@ -72,15 +72,14 @@ suspend fun submitBatchRequests( val batchSemaphore = Semaphore(3) return flow { coroutineScope { - val deferred: List>> = - buildList { - items.chunked(limit).collect { batch: List -> - // The batch reference is reused for every collect call. To ensure async works, a copy - // of the contents needs to be saved in a new reference. - val tempBatch = batch.toList() - add(async { batchSemaphore.withPermit { parseResponse(callRpc(tempBatch)) } }) - } + val deferred: List>> = buildList { + items.chunked(limit).collect { batch: List -> + // The batch reference is reused for every collect call. To ensure async works, a copy + // of the contents needs to be saved in a new reference. + val tempBatch = batch.toList() + add(async { batchSemaphore.withPermit { parseResponse(callRpc(tempBatch)) } }) } + } deferred.forEach { emit(it.await()) } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 47997cfe753..7b95bc5b3ef 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -50,6 +50,7 @@ import kotlinx.coroutines.awaitAll import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.count import kotlinx.coroutines.flow.flatMapMerge import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map @@ -158,7 +159,6 @@ import org.wfanet.measurement.measurementconsumer.stats.LiquidLegionsSketchMetho import org.wfanet.measurement.measurementconsumer.stats.LiquidLegionsV2Methodology import org.wfanet.measurement.measurementconsumer.stats.Methodology import org.wfanet.measurement.measurementconsumer.stats.NoiseMechanism as StatsNoiseMechanism -import kotlinx.coroutines.flow.count import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementParams import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementVarianceParams import org.wfanet.measurement.measurementconsumer.stats.ReachMetricVarianceParams @@ -214,98 +214,100 @@ private const val BATCH_SET_MEASUREMENT_RESULTS_LIMIT = 1000 private const val BATCH_SET_MEASUREMENT_FAILURES_LIMIT = 1000 class MetricsService( - private val metricSpecConfig: MetricSpecConfig, - private val internalReportingSetsStub: InternalReportingSetsCoroutineStub, - private val internalMetricsStub: InternalMetricsCoroutineStub, - private val variances: Variances, - internalMeasurementsStub: InternalMeasurementsCoroutineStub, - dataProvidersStub: DataProvidersCoroutineStub, - measurementsStub: MeasurementsCoroutineStub, - certificatesStub: CertificatesCoroutineStub, - measurementConsumersStub: MeasurementConsumersCoroutineStub, - encryptionKeyPairStore: EncryptionKeyPairStore, - secureRandom: SecureRandom, - signingPrivateKeyDir: File, - trustedCertificates: Map, - certificateCacheExpirationDuration: Duration = Duration.ofMinutes(60), - dataProviderCacheExpirationDuration: Duration = Duration.ofMinutes(60), - keyReaderContext: @BlockingExecutor CoroutineContext = Dispatchers.IO, - cacheLoaderContext: @NonBlockingExecutor CoroutineContext = Dispatchers.Default, + private val metricSpecConfig: MetricSpecConfig, + private val internalReportingSetsStub: InternalReportingSetsCoroutineStub, + private val internalMetricsStub: InternalMetricsCoroutineStub, + private val variances: Variances, + internalMeasurementsStub: InternalMeasurementsCoroutineStub, + dataProvidersStub: DataProvidersCoroutineStub, + measurementsStub: MeasurementsCoroutineStub, + certificatesStub: CertificatesCoroutineStub, + measurementConsumersStub: MeasurementConsumersCoroutineStub, + encryptionKeyPairStore: EncryptionKeyPairStore, + secureRandom: SecureRandom, + signingPrivateKeyDir: File, + trustedCertificates: Map, + certificateCacheExpirationDuration: Duration = Duration.ofMinutes(60), + dataProviderCacheExpirationDuration: Duration = Duration.ofMinutes(60), + keyReaderContext: @BlockingExecutor CoroutineContext = Dispatchers.IO, + cacheLoaderContext: @NonBlockingExecutor CoroutineContext = Dispatchers.Default, ) : MetricsCoroutineImplBase() { private data class DataProviderInfo( - val dataProviderName: String, - val publicKey: SignedMessage, - val certificateName: String, + val dataProviderName: String, + val publicKey: SignedMessage, + val certificateName: String, ) private val measurementSupplier = - MeasurementSupplier( - internalReportingSetsStub, - internalMeasurementsStub, - measurementsStub, - dataProvidersStub, - certificatesStub, - measurementConsumersStub, - encryptionKeyPairStore, - secureRandom, - signingPrivateKeyDir, - trustedCertificates, - certificateCacheExpirationDuration = certificateCacheExpirationDuration, - dataProviderCacheExpirationDuration = dataProviderCacheExpirationDuration, - keyReaderContext, - cacheLoaderContext, - ) + MeasurementSupplier( + internalReportingSetsStub, + internalMeasurementsStub, + measurementsStub, + dataProvidersStub, + certificatesStub, + measurementConsumersStub, + encryptionKeyPairStore, + secureRandom, + signingPrivateKeyDir, + trustedCertificates, + certificateCacheExpirationDuration = certificateCacheExpirationDuration, + dataProviderCacheExpirationDuration = dataProviderCacheExpirationDuration, + keyReaderContext, + cacheLoaderContext, + ) private class MeasurementSupplier( - private val internalReportingSetsStub: InternalReportingSetsCoroutineStub, - private val internalMeasurementsStub: InternalMeasurementsCoroutineStub, - private val measurementsStub: MeasurementsCoroutineStub, - private val dataProvidersStub: DataProvidersCoroutineStub, - private val certificatesStub: CertificatesCoroutineStub, - private val measurementConsumersStub: MeasurementConsumersCoroutineStub, - private val encryptionKeyPairStore: EncryptionKeyPairStore, - private val secureRandom: SecureRandom, - private val signingPrivateKeyDir: File, - private val trustedCertificates: Map, - certificateCacheExpirationDuration: Duration, - dataProviderCacheExpirationDuration: Duration, - private val keyReaderContext: @BlockingExecutor CoroutineContext = Dispatchers.IO, - cacheLoaderContext: @NonBlockingExecutor CoroutineContext = Dispatchers.Default, + private val internalReportingSetsStub: InternalReportingSetsCoroutineStub, + private val internalMeasurementsStub: InternalMeasurementsCoroutineStub, + private val measurementsStub: MeasurementsCoroutineStub, + private val dataProvidersStub: DataProvidersCoroutineStub, + private val certificatesStub: CertificatesCoroutineStub, + private val measurementConsumersStub: MeasurementConsumersCoroutineStub, + private val encryptionKeyPairStore: EncryptionKeyPairStore, + private val secureRandom: SecureRandom, + private val signingPrivateKeyDir: File, + private val trustedCertificates: Map, + certificateCacheExpirationDuration: Duration, + dataProviderCacheExpirationDuration: Duration, + private val keyReaderContext: @BlockingExecutor CoroutineContext = Dispatchers.IO, + cacheLoaderContext: @NonBlockingExecutor CoroutineContext = Dispatchers.Default, ) { private data class ResourceNameApiAuthenticationKey( - val name: String, - val apiAuthenticationKey: String, + val name: String, + val apiAuthenticationKey: String, ) private val certificateCache: LoadingCache = - LoadingCache( - Caffeine.newBuilder() - .expireAfterWrite(certificateCacheExpirationDuration) - .executor( - (cacheLoaderContext[ContinuationInterceptor] as CoroutineDispatcher) - .asExecutor()) - .buildAsync()) { key -> - getCertificate(name = key.name, apiAuthenticationKey = key.apiAuthenticationKey) - } + LoadingCache( + Caffeine.newBuilder() + .expireAfterWrite(certificateCacheExpirationDuration) + .executor( + (cacheLoaderContext[ContinuationInterceptor] as CoroutineDispatcher).asExecutor() + ) + .buildAsync() + ) { key -> + getCertificate(name = key.name, apiAuthenticationKey = key.apiAuthenticationKey) + } private val dataProviderCache: LoadingCache = - LoadingCache( - Caffeine.newBuilder() - .expireAfterWrite(dataProviderCacheExpirationDuration) - .executor( - (cacheLoaderContext[ContinuationInterceptor] as CoroutineDispatcher) - .asExecutor()) - .buildAsync()) { key -> - getDataProvider(name = key.name, apiAuthenticationKey = key.apiAuthenticationKey) - } + LoadingCache( + Caffeine.newBuilder() + .expireAfterWrite(dataProviderCacheExpirationDuration) + .executor( + (cacheLoaderContext[ContinuationInterceptor] as CoroutineDispatcher).asExecutor() + ) + .buildAsync() + ) { key -> + getDataProvider(name = key.name, apiAuthenticationKey = key.apiAuthenticationKey) + } /** * Creates CMM public [Measurement]s and [InternalMeasurement]s from a list of [InternalMetric]. */ suspend fun createCmmsMeasurements( - internalMetricsList: List, - principal: MeasurementConsumerPrincipal, + internalMetricsList: List, + principal: MeasurementConsumerPrincipal, ) { val measurementConsumer: MeasurementConsumer = getMeasurementConsumer(principal) @@ -315,7 +317,7 @@ class MetricsService( for (internalMetric in internalMetricsList) { for (weightedMeasurement in internalMetric.weightedMeasurementsList) { for (primitiveReportingSetBasis in - weightedMeasurement.measurement.primitiveReportingSetBasesList) { + weightedMeasurement.measurement.primitiveReportingSetBasesList) { if (!contains(primitiveReportingSetBasis.externalReportingSetId)) { emit(primitiveReportingSetBasis.externalReportingSetId) add(primitiveReportingSetBasis.externalReportingSetId) @@ -327,24 +329,24 @@ class MetricsService( } val callBatchGetInternalReportingSetsRpc: - suspend (List) -> BatchGetReportingSetsResponse = - { items -> - batchGetInternalReportingSets(principal.resourceKey.measurementConsumerId, items) - } + suspend (List) -> BatchGetReportingSetsResponse = + { items -> + batchGetInternalReportingSets(principal.resourceKey.measurementConsumerId, items) + } val internalPrimitiveReportingSetMap: Map = buildMap { submitBatchRequests( - externalPrimitiveReportingSetIds, - BATCH_GET_REPORTING_SETS_LIMIT, - callBatchGetInternalReportingSetsRpc, - ) { response: BatchGetReportingSetsResponse -> - response.reportingSetsList - } - .collect { reportingSets: List -> - for (reportingSet in reportingSets) { - computeIfAbsent(reportingSet.externalReportingSetId) { reportingSet } - } + externalPrimitiveReportingSetIds, + BATCH_GET_REPORTING_SETS_LIMIT, + callBatchGetInternalReportingSetsRpc, + ) { response: BatchGetReportingSetsResponse -> + response.reportingSetsList + } + .collect { reportingSets: List -> + for (reportingSet in reportingSets) { + computeIfAbsent(reportingSet.externalReportingSetId) { reportingSet } } + } } val dataProviderNames = mutableSetOf() @@ -354,7 +356,7 @@ class MetricsService( } } val dataProviderInfoMap: Map = - buildDataProviderInfoMap(principal.config.apiKey, dataProviderNames) + buildDataProviderInfoMap(principal.config.apiKey, dataProviderNames) val measurementConsumerSigningKey = getMeasurementConsumerSigningKey(principal) @@ -363,15 +365,16 @@ class MetricsService( for (weightedMeasurement in internalMetric.weightedMeasurementsList) { if (weightedMeasurement.measurement.cmmsMeasurementId.isBlank()) { emit( - buildCreateMeasurementRequest( - weightedMeasurement.measurement, - internalMetric.metricSpec, - internalPrimitiveReportingSetMap, - measurementConsumer, - principal, - dataProviderInfoMap, - measurementConsumerSigningKey, - )) + buildCreateMeasurementRequest( + weightedMeasurement.measurement, + internalMetric.metricSpec, + internalPrimitiveReportingSetMap, + measurementConsumer, + principal, + dataProviderInfoMap, + measurementConsumerSigningKey, + ) + ) } } } @@ -379,55 +382,56 @@ class MetricsService( // Create CMMS measurements. val callBatchCreateMeasurementsRpc: - suspend (List) -> BatchCreateMeasurementsResponse = - { items -> - batchCreateCmmsMeasurements(principal, items) - } + suspend (List) -> BatchCreateMeasurementsResponse = + { items -> + batchCreateCmmsMeasurements(principal, items) + } @OptIn(ExperimentalCoroutinesApi::class) val cmmsMeasurements: Flow = - submitBatchRequests( - cmmsCreateMeasurementRequests, - BATCH_KINGDOM_MEASUREMENTS_LIMIT, - callBatchCreateMeasurementsRpc, - ) { response: BatchCreateMeasurementsResponse -> - response.measurementsList - } - .flatMapMerge { it.asFlow() } + submitBatchRequests( + cmmsCreateMeasurementRequests, + BATCH_KINGDOM_MEASUREMENTS_LIMIT, + callBatchCreateMeasurementsRpc, + ) { response: BatchCreateMeasurementsResponse -> + response.measurementsList + } + .flatMapMerge { it.asFlow() } // Set CMMS measurement IDs. val callBatchSetCmmsMeasurementIdsRpc: - suspend (List) -> BatchSetCmmsMeasurementIdsResponse = - { items -> - batchSetCmmsMeasurementIds(principal.resourceKey.measurementConsumerId, items) - } + suspend (List) -> BatchSetCmmsMeasurementIdsResponse = + { items -> + batchSetCmmsMeasurementIds(principal.resourceKey.measurementConsumerId, items) + } submitBatchRequests( - cmmsMeasurements.map { - measurementIds { - cmmsCreateMeasurementRequestId = it.measurementReferenceId - cmmsMeasurementId = MeasurementKey.fromName(it.name)!!.measurementId - } - }, - BATCH_SET_CMMS_MEASUREMENT_IDS_LIMIT, - callBatchSetCmmsMeasurementIdsRpc, - ) { response: BatchSetCmmsMeasurementIdsResponse -> - response.measurementsList - } - .collect {} + cmmsMeasurements.map { + measurementIds { + cmmsCreateMeasurementRequestId = it.measurementReferenceId + cmmsMeasurementId = MeasurementKey.fromName(it.name)!!.measurementId + } + }, + BATCH_SET_CMMS_MEASUREMENT_IDS_LIMIT, + callBatchSetCmmsMeasurementIdsRpc, + ) { response: BatchSetCmmsMeasurementIdsResponse -> + response.measurementsList + } + .collect {} } /** Sets a batch of CMMS [MeasurementIds] to the [InternalMeasurement] table. */ private suspend fun batchSetCmmsMeasurementIds( - cmmsMeasurementConsumerId: String, - measurementIds: List, + cmmsMeasurementConsumerId: String, + measurementIds: List, ): BatchSetCmmsMeasurementIdsResponse { return try { internalMeasurementsStub.batchSetCmmsMeasurementIds( - batchSetCmmsMeasurementIdsRequest { - this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId - this.measurementIds += measurementIds - }) + batchSetCmmsMeasurementIdsRequest { + this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId + this.measurementIds += measurementIds + } + ) } catch (e: StatusException) { throw Exception("Unable to set the CMMS measurement IDs for the measurements.", e) } @@ -435,49 +439,49 @@ class MetricsService( /** Batch create CMMS measurements. */ private suspend fun batchCreateCmmsMeasurements( - principal: MeasurementConsumerPrincipal, - createMeasurementRequests: List, + principal: MeasurementConsumerPrincipal, + createMeasurementRequests: List, ): BatchCreateMeasurementsResponse { try { return measurementsStub - .withAuthenticationKey(principal.config.apiKey) - .batchCreateMeasurements( - batchCreateMeasurementsRequest { - parent = principal.resourceKey.toName() - requests += createMeasurementRequests - }) + .withAuthenticationKey(principal.config.apiKey) + .batchCreateMeasurements( + batchCreateMeasurementsRequest { + parent = principal.resourceKey.toName() + requests += createMeasurementRequests + } + ) } catch (e: StatusException) { throw when (e.status.code) { - Status.Code.INVALID_ARGUMENT -> - Status.INVALID_ARGUMENT.withDescription("Required field unspecified or invalid.") - Status.Code.PERMISSION_DENIED -> - Status.PERMISSION_DENIED.withDescription( - "Cannot create CMMS Measurements for another MeasurementConsumer.") - Status.Code.FAILED_PRECONDITION -> - Status.FAILED_PRECONDITION.withDescription("Failed precondition.") - Status.Code.NOT_FOUND -> - Status.NOT_FOUND.withDescription( - "${principal.resourceKey.toName()} is not found.") - else -> Status.UNKNOWN.withDescription("Unable to create CMMS Measurements.") - } - .withCause(e) - .asRuntimeException() + Status.Code.INVALID_ARGUMENT -> + Status.INVALID_ARGUMENT.withDescription("Required field unspecified or invalid.") + Status.Code.PERMISSION_DENIED -> + Status.PERMISSION_DENIED.withDescription( + "Cannot create CMMS Measurements for another MeasurementConsumer." + ) + Status.Code.FAILED_PRECONDITION -> + Status.FAILED_PRECONDITION.withDescription("Failed precondition.") + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("${principal.resourceKey.toName()} is not found.") + else -> Status.UNKNOWN.withDescription("Unable to create CMMS Measurements.") + } + .withCause(e) + .asRuntimeException() } } /** Builds a CMMS [CreateMeasurementRequest]. */ private fun buildCreateMeasurementRequest( - internalMeasurement: InternalMeasurement, - metricSpec: InternalMetricSpec, - internalPrimitiveReportingSetMap: Map, - measurementConsumer: MeasurementConsumer, - principal: MeasurementConsumerPrincipal, - dataProviderInfoMap: Map, - measurementConsumerSigningKey: SigningKeyHandle, + internalMeasurement: InternalMeasurement, + metricSpec: InternalMetricSpec, + internalPrimitiveReportingSetMap: Map, + measurementConsumer: MeasurementConsumer, + principal: MeasurementConsumerPrincipal, + dataProviderInfoMap: Map, + measurementConsumerSigningKey: SigningKeyHandle, ): CreateMeasurementRequest { val eventGroupEntriesByDataProvider = - groupEventGroupEntriesByDataProvider( - internalMeasurement, internalPrimitiveReportingSetMap) + groupEventGroupEntriesByDataProvider(internalMeasurement, internalPrimitiveReportingSetMap) val packedMeasurementEncryptionPublicKey = measurementConsumer.publicKey.message return createMeasurementRequest { @@ -486,22 +490,22 @@ class MetricsService( measurementConsumerCertificate = principal.config.signingCertificateName dataProviders += - buildDataProviderEntries( - eventGroupEntriesByDataProvider, - packedMeasurementEncryptionPublicKey, - measurementConsumerSigningKey, - dataProviderInfoMap, - ) + buildDataProviderEntries( + eventGroupEntriesByDataProvider, + packedMeasurementEncryptionPublicKey, + measurementConsumerSigningKey, + dataProviderInfoMap, + ) val unsignedMeasurementSpec: MeasurementSpec = - buildUnsignedMeasurementSpec( - packedMeasurementEncryptionPublicKey, - dataProviders.map { it.value.nonceHash }, - metricSpec, - ) + buildUnsignedMeasurementSpec( + packedMeasurementEncryptionPublicKey, + dataProviders.map { it.value.nonceHash }, + metricSpec, + ) measurementSpec = - signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) + signMeasurementSpec(unsignedMeasurementSpec, measurementConsumerSigningKey) // To help map reporting measurements to cmms measurements. measurementReferenceId = internalMeasurement.cmmsCreateMeasurementRequestId } @@ -511,26 +515,26 @@ class MetricsService( /** Gets a [SigningKeyHandle] for a [MeasurementConsumerPrincipal]. */ private suspend fun getMeasurementConsumerSigningKey( - principal: MeasurementConsumerPrincipal + principal: MeasurementConsumerPrincipal ): SigningKeyHandle { // TODO: Factor this out to a separate class similar to EncryptionKeyPairStore. val signingPrivateKeyDer: ByteString = - withContext(keyReaderContext) { - signingPrivateKeyDir.resolve(principal.config.signingPrivateKeyPath).readByteString() - } + withContext(keyReaderContext) { + signingPrivateKeyDir.resolve(principal.config.signingPrivateKeyPath).readByteString() + } val measurementConsumerCertificate: X509Certificate = - readCertificate(getSigningCertificateDer(principal)) + readCertificate(getSigningCertificateDer(principal)) val signingPrivateKey: PrivateKey = - readPrivateKey(signingPrivateKeyDer, measurementConsumerCertificate.publicKey.algorithm) + readPrivateKey(signingPrivateKeyDer, measurementConsumerCertificate.publicKey.algorithm) return SigningKeyHandle(measurementConsumerCertificate, signingPrivateKey) } /** Builds an unsigned [MeasurementSpec]. */ private fun buildUnsignedMeasurementSpec( - packedMeasurementEncryptionPublicKey: ProtoAny, - nonceHashes: List, - metricSpec: InternalMetricSpec, + packedMeasurementEncryptionPublicKey: ProtoAny, + nonceHashes: List, + metricSpec: InternalMetricSpec, ): MeasurementSpec { return measurementSpec { measurementPublicKey = packedMeasurementEncryptionPublicKey @@ -554,9 +558,9 @@ class MetricsService( population = MeasurementSpec.Population.getDefaultInstance() } InternalMetricSpec.TypeCase.TYPE_NOT_SET -> - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "Unset metric type should've already raised error." - } + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "Unset metric type should've already raised error." + } } vidSamplingInterval = metricSpec.vidSamplingInterval.toCmmsVidSamplingInterval() // TODO(@jojijac0b): Add modelLine @@ -565,8 +569,8 @@ class MetricsService( /** Builds a [Map] of [DataProvider] name to [DataProviderInfo]. */ private suspend fun buildDataProviderInfoMap( - apiAuthenticationKey: String, - dataProviderNames: Collection, + apiAuthenticationKey: String, + dataProviderNames: Collection, ): Map { val dataProviderInfoMap = mutableMapOf() @@ -578,48 +582,55 @@ class MetricsService( coroutineScope { for (dataProviderName in dataProviderNames) { deferredDataProviderInfoList.add( - async { - val dataProvider: DataProvider = - dataProviderCache.getValue( - ResourceNameApiAuthenticationKey( - name = dataProviderName, - apiAuthenticationKey = apiAuthenticationKey, - )) - - val certificate = - certificateCache.getValue( - ResourceNameApiAuthenticationKey( - name = dataProvider.certificate, - apiAuthenticationKey = apiAuthenticationKey, - )) - - if (certificate.revocationState != - Certificate.RevocationState.REVOCATION_STATE_UNSPECIFIED) { - throw Status.FAILED_PRECONDITION.withDescription( - "${certificate.name} revocation state is ${certificate.revocationState}") - .asRuntimeException() - } + async { + val dataProvider: DataProvider = + dataProviderCache.getValue( + ResourceNameApiAuthenticationKey( + name = dataProviderName, + apiAuthenticationKey = apiAuthenticationKey, + ) + ) - val x509Certificate: X509Certificate = readCertificate(certificate.x509Der) - val trustedIssuer: X509Certificate = - trustedCertificates[checkNotNull(x509Certificate.authorityKeyIdentifier)] - ?: throw Status.FAILED_PRECONDITION.withDescription( - "${certificate.name} not issued by trusted CA") - .asRuntimeException() - try { - verifyEncryptionPublicKey(dataProvider.publicKey, x509Certificate, trustedIssuer) - } catch (e: CertPathValidatorException) { - throw Status.FAILED_PRECONDITION.withCause(e) - .withDescription("Certificate path for ${certificate.name} is invalid") - .asRuntimeException() - } catch (e: SignatureException) { - throw Status.FAILED_PRECONDITION.withCause(e) - .withDescription("DataProvider public key signature is invalid") - .asRuntimeException() - } + val certificate = + certificateCache.getValue( + ResourceNameApiAuthenticationKey( + name = dataProvider.certificate, + apiAuthenticationKey = apiAuthenticationKey, + ) + ) + + if ( + certificate.revocationState != + Certificate.RevocationState.REVOCATION_STATE_UNSPECIFIED + ) { + throw Status.FAILED_PRECONDITION.withDescription( + "${certificate.name} revocation state is ${certificate.revocationState}" + ) + .asRuntimeException() + } + + val x509Certificate: X509Certificate = readCertificate(certificate.x509Der) + val trustedIssuer: X509Certificate = + trustedCertificates[checkNotNull(x509Certificate.authorityKeyIdentifier)] + ?: throw Status.FAILED_PRECONDITION.withDescription( + "${certificate.name} not issued by trusted CA" + ) + .asRuntimeException() + try { + verifyEncryptionPublicKey(dataProvider.publicKey, x509Certificate, trustedIssuer) + } catch (e: CertPathValidatorException) { + throw Status.FAILED_PRECONDITION.withCause(e) + .withDescription("Certificate path for ${certificate.name} is invalid") + .asRuntimeException() + } catch (e: SignatureException) { + throw Status.FAILED_PRECONDITION.withCause(e) + .withDescription("DataProvider public key signature is invalid") + .asRuntimeException() + } - DataProviderInfo(dataProvider.name, dataProvider.publicKey, certificate.name) - }) + DataProviderInfo(dataProvider.name, dataProvider.publicKey, certificate.name) + } + ) } for (deferredDataProviderInfo in deferredDataProviderInfoList.awaitAll()) { @@ -635,10 +646,10 @@ class MetricsService( * [eventGroupEntriesByDataProvider]. */ private fun buildDataProviderEntries( - eventGroupEntriesByDataProvider: Map>, - packedMeasurementEncryptionPublicKey: ProtoAny, - measurementConsumerSigningKey: SigningKeyHandle, - dataProviderInfoMap: Map, + eventGroupEntriesByDataProvider: Map>, + packedMeasurementEncryptionPublicKey: ProtoAny, + measurementConsumerSigningKey: SigningKeyHandle, + dataProviderInfoMap: Map, ): List { return eventGroupEntriesByDataProvider.map { (dataProviderKey, eventGroupEntriesList) -> val dataProviderName: String = dataProviderKey.toName() @@ -650,20 +661,20 @@ class MetricsService( nonce = secureRandom.nextLong() } val encryptRequisitionSpec = - encryptRequisitionSpec( - signRequisitionSpec(requisitionSpec, measurementConsumerSigningKey), - dataProviderInfo.publicKey.unpack(), - ) + encryptRequisitionSpec( + signRequisitionSpec(requisitionSpec, measurementConsumerSigningKey), + dataProviderInfo.publicKey.unpack(), + ) dataProviderEntry { key = dataProviderName value = - MeasurementKt.DataProviderEntryKt.value { - dataProviderCertificate = dataProviderInfo.certificateName - dataProviderPublicKey = dataProviderInfo.publicKey.message - this.encryptedRequisitionSpec = encryptRequisitionSpec - nonceHash = Hashing.hashSha256(requisitionSpec.nonce) - } + MeasurementKt.DataProviderEntryKt.value { + dataProviderCertificate = dataProviderInfo.certificateName + dataProviderPublicKey = dataProviderInfo.publicKey.message + this.encryptedRequisitionSpec = encryptRequisitionSpec + nonceHash = Hashing.hashSha256(requisitionSpec.nonce) + } } } } @@ -673,44 +684,42 @@ class MetricsService( * grouping them by DataProvider. */ private fun groupEventGroupEntriesByDataProvider( - measurement: InternalMeasurement, - internalPrimitiveReportingSetMap: Map, + measurement: InternalMeasurement, + internalPrimitiveReportingSetMap: Map, ): Map> { return measurement.primitiveReportingSetBasesList - .flatMap { primitiveReportingSetBasis -> - val internalPrimitiveReportingSet = - internalPrimitiveReportingSetMap.getValue( - primitiveReportingSetBasis.externalReportingSetId) - - internalPrimitiveReportingSet.primitive.eventGroupKeysList.map { internalEventGroupKey - -> - val cmmsEventGroupKey = - CmmsEventGroupKey( - internalEventGroupKey.cmmsDataProviderId, - internalEventGroupKey.cmmsEventGroupId, - ) - val filtersList = - primitiveReportingSetBasis.filtersList.filter { !it.isNullOrBlank() } - val filter: String? = - if (filtersList.isEmpty()) null else buildConjunction(filtersList) - - cmmsEventGroupKey to - RequisitionSpecKt.eventGroupEntry { - key = cmmsEventGroupKey.toName() - value = - RequisitionSpecKt.EventGroupEntryKt.value { - collectionInterval = measurement.timeInterval - if (filter != null) { - this.filter = RequisitionSpecKt.eventFilter { expression = filter } - } - } + .flatMap { primitiveReportingSetBasis -> + val internalPrimitiveReportingSet = + internalPrimitiveReportingSetMap.getValue( + primitiveReportingSetBasis.externalReportingSetId + ) + + internalPrimitiveReportingSet.primitive.eventGroupKeysList.map { internalEventGroupKey -> + val cmmsEventGroupKey = + CmmsEventGroupKey( + internalEventGroupKey.cmmsDataProviderId, + internalEventGroupKey.cmmsEventGroupId, + ) + val filtersList = primitiveReportingSetBasis.filtersList.filter { !it.isNullOrBlank() } + val filter: String? = if (filtersList.isEmpty()) null else buildConjunction(filtersList) + + cmmsEventGroupKey to + RequisitionSpecKt.eventGroupEntry { + key = cmmsEventGroupKey.toName() + value = + RequisitionSpecKt.EventGroupEntryKt.value { + collectionInterval = measurement.timeInterval + if (filter != null) { + this.filter = RequisitionSpecKt.eventFilter { expression = filter } + } } - } + } } - .groupBy( - { (cmmsEventGroupKey, _) -> DataProviderKey(cmmsEventGroupKey.dataProviderId) }, - { (_, eventGroupEntry) -> eventGroupEntry }, - ) + } + .groupBy( + { (cmmsEventGroupKey, _) -> DataProviderKey(cmmsEventGroupKey.dataProviderId) }, + { (_, eventGroupEntry) -> eventGroupEntry }, + ) } /** Combines event group filters. */ @@ -720,59 +729,64 @@ class MetricsService( /** Gets a [MeasurementConsumer] based on a CMMS ID. */ private suspend fun getMeasurementConsumer( - principal: MeasurementConsumerPrincipal + principal: MeasurementConsumerPrincipal ): MeasurementConsumer { return try { measurementConsumersStub - .withAuthenticationKey(principal.config.apiKey) - .getMeasurementConsumer( - getMeasurementConsumerRequest { name = principal.resourceKey.toName() }) + .withAuthenticationKey(principal.config.apiKey) + .getMeasurementConsumer( + getMeasurementConsumerRequest { name = principal.resourceKey.toName() } + ) } catch (e: StatusException) { throw when (e.status.code) { - Status.Code.NOT_FOUND -> - Status.NOT_FOUND.withDescription("${principal.resourceKey.toName()} not found.") - else -> - Status.UNKNOWN.withDescription( - "Unable to retrieve the measurement consumer [${principal.resourceKey.toName()}].") - } - .withCause(e) - .asRuntimeException() + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("${principal.resourceKey.toName()} not found.") + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve the measurement consumer [${principal.resourceKey.toName()}]." + ) + } + .withCause(e) + .asRuntimeException() } } /** Gets a batch of [InternalReportingSet]s. */ private suspend fun batchGetInternalReportingSets( - cmmsMeasurementConsumerId: String, - externalReportingSetIds: List, + cmmsMeasurementConsumerId: String, + externalReportingSetIds: List, ): BatchGetReportingSetsResponse { return try { internalReportingSetsStub.batchGetReportingSets( - batchGetReportingSetsRequest { - this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId - this.externalReportingSetIds += externalReportingSetIds - }) + batchGetReportingSetsRequest { + this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId + this.externalReportingSetIds += externalReportingSetIds + } + ) } catch (e: StatusException) { throw when (e.status.code) { - Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription("Reporting Set not found.") - else -> - Status.UNKNOWN.withDescription( - "Unable to retrieve ReportingSets used in the requesting metric.") - } - .withCause(e) - .asRuntimeException() + Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription("Reporting Set not found.") + else -> + Status.UNKNOWN.withDescription( + "Unable to retrieve ReportingSets used in the requesting metric." + ) + } + .withCause(e) + .asRuntimeException() } } /** Gets a signing certificate x509Der in ByteString. */ private suspend fun getSigningCertificateDer( - principal: MeasurementConsumerPrincipal + principal: MeasurementConsumerPrincipal ): ByteString { val certificate = - certificateCache.getValue( - ResourceNameApiAuthenticationKey( - name = principal.config.signingCertificateName, - apiAuthenticationKey = principal.config.apiKey, - )) + certificateCache.getValue( + ResourceNameApiAuthenticationKey( + name = principal.config.signingCertificateName, + apiAuthenticationKey = principal.config.apiKey, + ) + ) return certificate.x509Der } @@ -783,9 +797,9 @@ class MetricsService( * @return a boolean to indicate whether any [InternalMeasurement] was updated. */ suspend fun syncInternalMeasurements( - internalMeasurements: List, - apiAuthenticationKey: String, - principal: MeasurementConsumerPrincipal, + internalMeasurements: List, + apiAuthenticationKey: String, + principal: MeasurementConsumerPrincipal, ): Boolean { val failedMeasurements: MutableList = mutableListOf() @@ -802,9 +816,9 @@ class MetricsService( Measurement.State.COMPUTING, Measurement.State.AWAITING_REQUISITION_FULFILLMENT -> {} Measurement.State.STATE_UNSPECIFIED -> - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "The CMMS measurement state should've been set." - } + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "The CMMS measurement state should've been set." + } Measurement.State.UNRECOGNIZED -> { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { "Unrecognized CMMS measurement state." @@ -818,14 +832,15 @@ class MetricsService( var anyUpdate = false val callBatchSetInternalMeasurementResultsRpc: - suspend (List) -> BatchSetCmmsMeasurementResultsResponse = - { items -> - batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal) - } - val count = submitBatchRequests( - succeededMeasurements, - BATCH_SET_MEASUREMENT_RESULTS_LIMIT, - callBatchSetInternalMeasurementResultsRpc, + suspend (List) -> BatchSetCmmsMeasurementResultsResponse = + { items -> + batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal) + } + val count = + submitBatchRequests( + succeededMeasurements, + BATCH_SET_MEASUREMENT_RESULTS_LIMIT, + callBatchSetInternalMeasurementResultsRpc, ) { response: BatchSetCmmsMeasurementResultsResponse -> response.measurementsList } @@ -839,18 +854,15 @@ class MetricsService( val callBatchSetInternalMeasurementFailuresRpc: suspend (List) -> BatchSetCmmsMeasurementFailuresResponse = { items -> - batchSetInternalMeasurementFailures( - items, - principal.resourceKey.measurementConsumerId, - ) + batchSetInternalMeasurementFailures(items, principal.resourceKey.measurementConsumerId) } submitBatchRequests( - failedMeasurements.asFlow(), - BATCH_SET_MEASUREMENT_FAILURES_LIMIT, - callBatchSetInternalMeasurementFailuresRpc, - ) { response: BatchSetCmmsMeasurementFailuresResponse -> - response.measurementsList - } + failedMeasurements.asFlow(), + BATCH_SET_MEASUREMENT_FAILURES_LIMIT, + callBatchSetInternalMeasurementFailuresRpc, + ) { response: BatchSetCmmsMeasurementFailuresResponse -> + response.measurementsList + } .collect {} anyUpdate = true @@ -864,23 +876,24 @@ class MetricsService( * failed or canceled CMMS [Measurement]s. */ private suspend fun batchSetInternalMeasurementFailures( - failedMeasurementsList: List, - cmmsMeasurementConsumerId: String, + failedMeasurementsList: List, + cmmsMeasurementConsumerId: String, ): BatchSetCmmsMeasurementFailuresResponse { val batchSetInternalMeasurementFailuresRequest = batchSetMeasurementFailuresRequest { this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId measurementFailures += - failedMeasurementsList.map { measurement -> - measurementFailure { - cmmsMeasurementId = MeasurementKey.fromName(measurement.name)!!.measurementId - failure = measurement.failure.toInternal() - } + failedMeasurementsList.map { measurement -> + measurementFailure { + cmmsMeasurementId = MeasurementKey.fromName(measurement.name)!!.measurementId + failure = measurement.failure.toInternal() } + } } return try { internalMeasurementsStub.batchSetMeasurementFailures( - batchSetInternalMeasurementFailuresRequest) + batchSetInternalMeasurementFailuresRequest + ) } catch (e: StatusException) { throw Exception("Unable to set measurement failures for Measurements.", e) } @@ -891,20 +904,20 @@ class MetricsService( * given succeeded CMMS [Measurement]s. */ private suspend fun batchSetInternalMeasurementResults( - succeededMeasurementsList: List, - apiAuthenticationKey: String, - principal: MeasurementConsumerPrincipal, + succeededMeasurementsList: List, + apiAuthenticationKey: String, + principal: MeasurementConsumerPrincipal, ): BatchSetCmmsMeasurementResultsResponse { val batchSetMeasurementResultsRequest = batchSetMeasurementResultsRequest { cmmsMeasurementConsumerId = principal.resourceKey.measurementConsumerId measurementResults += - succeededMeasurementsList.map { measurement -> - buildInternalMeasurementResult( - measurement, - apiAuthenticationKey, - principal.resourceKey.toName(), - ) - } + succeededMeasurementsList.map { measurement -> + buildInternalMeasurementResult( + measurement, + apiAuthenticationKey, + principal.resourceKey.toName(), + ) + } } return try { @@ -916,18 +929,18 @@ class MetricsService( /** Retrieves [Measurement]s from the CMMS. */ private suspend fun getCmmsMeasurements( - internalMeasurements: List, - principal: MeasurementConsumerPrincipal, + internalMeasurements: List, + principal: MeasurementConsumerPrincipal, ): Flow> { val measurementNames: Flow = flow { buildSet { for (internalMeasurement in internalMeasurements) { val name = - MeasurementKey( - principal.resourceKey.measurementConsumerId, - internalMeasurement.cmmsMeasurementId, - ) - .toName() + MeasurementKey( + principal.resourceKey.measurementConsumerId, + internalMeasurement.cmmsMeasurementId, + ) + .toName() if (!contains(name)) { emit(name) @@ -938,14 +951,14 @@ class MetricsService( } val callBatchGetMeasurementsRpc: suspend (List) -> BatchGetMeasurementsResponse = - { items -> - batchGetCmmsMeasurements(principal, items) - } + { items -> + batchGetCmmsMeasurements(principal, items) + } return submitBatchRequests( - measurementNames, - BATCH_KINGDOM_MEASUREMENTS_LIMIT, - callBatchGetMeasurementsRpc, + measurementNames, + BATCH_KINGDOM_MEASUREMENTS_LIMIT, + callBatchGetMeasurementsRpc, ) { response: BatchGetMeasurementsResponse -> response.measurementsList } @@ -953,99 +966,103 @@ class MetricsService( /** Batch get CMMS measurements. */ private suspend fun batchGetCmmsMeasurements( - principal: MeasurementConsumerPrincipal, - measurementNames: List, + principal: MeasurementConsumerPrincipal, + measurementNames: List, ): BatchGetMeasurementsResponse { try { return measurementsStub - .withAuthenticationKey(principal.config.apiKey) - .batchGetMeasurements( - batchGetMeasurementsRequest { - parent = principal.resourceKey.toName() - names += measurementNames - }) + .withAuthenticationKey(principal.config.apiKey) + .batchGetMeasurements( + batchGetMeasurementsRequest { + parent = principal.resourceKey.toName() + names += measurementNames + } + ) } catch (e: StatusException) { throw when (e.status.code) { - Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription("Measurements not found.") - Status.Code.PERMISSION_DENIED -> - Status.PERMISSION_DENIED.withDescription( - "Doesn't have permission to get measurements.") - else -> Status.UNKNOWN.withDescription("Unable to retrieve Measurements.") - } - .withCause(e) - .asRuntimeException() + Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription("Measurements not found.") + Status.Code.PERMISSION_DENIED -> + Status.PERMISSION_DENIED.withDescription( + "Doesn't have permission to get measurements." + ) + else -> Status.UNKNOWN.withDescription("Unable to retrieve Measurements.") + } + .withCause(e) + .asRuntimeException() } } /** Builds an [InternalMeasurement.Result]. */ private suspend fun buildInternalMeasurementResult( - measurement: Measurement, - apiAuthenticationKey: String, - principalName: String, + measurement: Measurement, + apiAuthenticationKey: String, + principalName: String, ): BatchSetMeasurementResultsRequest.MeasurementResult { val measurementSpec: MeasurementSpec = measurement.measurementSpec.unpack() val encryptionPrivateKeyHandle = - encryptionKeyPairStore.getPrivateKeyHandle( - principalName, - measurementSpec.measurementPublicKey.unpack().data, - ) - ?: failGrpc(Status.FAILED_PRECONDITION) { - "Encryption private key not found for the measurement ${measurement.name}." - } + encryptionKeyPairStore.getPrivateKeyHandle( + principalName, + measurementSpec.measurementPublicKey.unpack().data, + ) + ?: failGrpc(Status.FAILED_PRECONDITION) { + "Encryption private key not found for the measurement ${measurement.name}." + } val decryptedMeasurementResults: List = - measurement.resultsList.map { - decryptMeasurementResultOutput(it, encryptionPrivateKeyHandle, apiAuthenticationKey) - } + measurement.resultsList.map { + decryptMeasurementResultOutput(it, encryptionPrivateKeyHandle, apiAuthenticationKey) + } return measurementResult { cmmsMeasurementId = MeasurementKey.fromName(measurement.name)!!.measurementId results += - decryptedMeasurementResults.map { - try { - it.toInternal(measurement.protocolConfig) - } catch (e: NoiseMechanismUnrecognizedException) { - failGrpc(Status.UNKNOWN) { - listOfNotNull("Unrecognized noise mechanism.", e.message, e.cause?.message) - .joinToString(separator = "\n") - } - } catch (e: Throwable) { - failGrpc(Status.UNKNOWN) { - listOfNotNull("Unable to read measurement result.", e.message, e.cause?.message) - .joinToString(separator = "\n") - } + decryptedMeasurementResults.map { + try { + it.toInternal(measurement.protocolConfig) + } catch (e: NoiseMechanismUnrecognizedException) { + failGrpc(Status.UNKNOWN) { + listOfNotNull("Unrecognized noise mechanism.", e.message, e.cause?.message) + .joinToString(separator = "\n") + } + } catch (e: Throwable) { + failGrpc(Status.UNKNOWN) { + listOfNotNull("Unable to read measurement result.", e.message, e.cause?.message) + .joinToString(separator = "\n") } } + } } } /** Decrypts a [Measurement.ResultOutput] to [Measurement.Result] */ private suspend fun decryptMeasurementResultOutput( - measurementResultOutput: Measurement.ResultOutput, - encryptionPrivateKeyHandle: PrivateKeyHandle, - apiAuthenticationKey: String, + measurementResultOutput: Measurement.ResultOutput, + encryptionPrivateKeyHandle: PrivateKeyHandle, + apiAuthenticationKey: String, ): Measurement.Result { val certificate = - certificateCache.getValue( - ResourceNameApiAuthenticationKey( - name = measurementResultOutput.certificate, - apiAuthenticationKey = apiAuthenticationKey, - )) + certificateCache.getValue( + ResourceNameApiAuthenticationKey( + name = measurementResultOutput.certificate, + apiAuthenticationKey = apiAuthenticationKey, + ) + ) val signedResult = - decryptResult(measurementResultOutput.encryptedResult, encryptionPrivateKeyHandle) + decryptResult(measurementResultOutput.encryptedResult, encryptionPrivateKeyHandle) if (certificate.revocationState != Certificate.RevocationState.REVOCATION_STATE_UNSPECIFIED) { throw Status.FAILED_PRECONDITION.withDescription( - "${certificate.name} revocation state is ${certificate.revocationState}") - .asRuntimeException() + "${certificate.name} revocation state is ${certificate.revocationState}" + ) + .asRuntimeException() } val x509Certificate: X509Certificate = readCertificate(certificate.x509Der) val trustedIssuer: X509Certificate = - checkNotNull(trustedCertificates[checkNotNull(x509Certificate.authorityKeyIdentifier)]) { - "${certificate.name} not issued by trusted CA" - } + checkNotNull(trustedCertificates[checkNotNull(x509Certificate.authorityKeyIdentifier)]) { + "${certificate.name} not issued by trusted CA" + } // TODO: Record verification failure in internal Measurement rather than having the RPC fail. try { @@ -1069,16 +1086,16 @@ class MetricsService( private suspend fun getCertificate(name: String, apiAuthenticationKey: String): Certificate { return try { certificatesStub - .withAuthenticationKey(apiAuthenticationKey) - .getCertificate(getCertificateRequest { this.name = name }) + .withAuthenticationKey(apiAuthenticationKey) + .getCertificate(getCertificateRequest { this.name = name }) } catch (e: StatusException) { throw when (e.status.code) { - Status.Code.NOT_FOUND -> - Status.FAILED_PRECONDITION.withDescription("Certificate $name not found.") - else -> Status.UNKNOWN.withDescription("Unable to retrieve Certificate $name.") - } - .withCause(e) - .asRuntimeException() + Status.Code.NOT_FOUND -> + Status.FAILED_PRECONDITION.withDescription("Certificate $name not found.") + else -> Status.UNKNOWN.withDescription("Unable to retrieve Certificate $name.") + } + .withCause(e) + .asRuntimeException() } } @@ -1093,24 +1110,24 @@ class MetricsService( private suspend fun getDataProvider(name: String, apiAuthenticationKey: String): DataProvider { return try { dataProvidersStub - .withAuthenticationKey(apiAuthenticationKey) - .getDataProvider(getDataProviderRequest { this.name = name }) + .withAuthenticationKey(apiAuthenticationKey) + .getDataProvider(getDataProviderRequest { this.name = name }) } catch (e: StatusException) { throw when (e.status.code) { - Status.Code.NOT_FOUND -> Status.FAILED_PRECONDITION.withDescription("$name not found") - else -> Status.UNKNOWN.withDescription("Unable to retrieve $name") - } - .withCause(e) - .asRuntimeException() + Status.Code.NOT_FOUND -> Status.FAILED_PRECONDITION.withDescription("$name not found") + else -> Status.UNKNOWN.withDescription("Unable to retrieve $name") + } + .withCause(e) + .asRuntimeException() } } } override suspend fun getMetric(request: GetMetricRequest): Metric { val metricKey = - grpcRequireNotNull(MetricKey.fromName(request.name)) { - "Metric name is either unspecified or invalid." - } + grpcRequireNotNull(MetricKey.fromName(request.name)) { + "Metric name is either unspecified or invalid." + } val principal: ReportingPrincipal = principalFromCurrentContext when (principal) { @@ -1124,7 +1141,7 @@ class MetricsService( } val internalMetric: InternalMetric = - getInternalMetric(metricKey.cmmsMeasurementConsumerId, metricKey.metricId) + getInternalMetric(metricKey.cmmsMeasurementConsumerId, metricKey.metricId) // Early exit when the metric is at a terminal state. if (internalMetric.state != Metric.State.RUNNING) { @@ -1133,18 +1150,18 @@ class MetricsService( // Only syncs pending measurements which can only be in metrics that are still running. val toBeSyncedInternalMeasurements: List = - internalMetric.weightedMeasurementsList - .map { weightedMeasurement -> weightedMeasurement.measurement } - .filter { internalMeasurement -> - internalMeasurement.state == InternalMeasurement.State.PENDING - } + internalMetric.weightedMeasurementsList + .map { weightedMeasurement -> weightedMeasurement.measurement } + .filter { internalMeasurement -> + internalMeasurement.state == InternalMeasurement.State.PENDING + } val anyMeasurementUpdated: Boolean = - measurementSupplier.syncInternalMeasurements( - toBeSyncedInternalMeasurements, - principal.config.apiKey, - principal, - ) + measurementSupplier.syncInternalMeasurements( + toBeSyncedInternalMeasurements, + principal.config.apiKey, + principal, + ) return if (anyMeasurementUpdated) { getInternalMetric(metricKey.cmmsMeasurementConsumerId, metricKey.metricId).toMetric(variances) @@ -1155,9 +1172,9 @@ class MetricsService( override suspend fun batchGetMetrics(request: BatchGetMetricsRequest): BatchGetMetricsResponse { val parentKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { - "Parent is either unspecified or invalid." - } + grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { + "Parent is either unspecified or invalid." + } val principal: ReportingPrincipal = principalFromCurrentContext @@ -1177,55 +1194,55 @@ class MetricsService( } val metricIds: List = - request.namesList.map { metricName -> - val metricKey = - grpcRequireNotNull(MetricKey.fromName(metricName)) { - "Metric name is either unspecified or invalid." - } - metricKey.metricId - } + request.namesList.map { metricName -> + val metricKey = + grpcRequireNotNull(MetricKey.fromName(metricName)) { + "Metric name is either unspecified or invalid." + } + metricKey.metricId + } val internalMetrics: List = - batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, metricIds) + batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, metricIds) // Only syncs pending measurements which can only be in metrics that are still running. val toBeSyncedInternalMeasurements: List = - internalMetrics - .filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } - .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } - .map { weightedMeasurement -> weightedMeasurement.measurement } - .filter { internalMeasurement -> - internalMeasurement.state == InternalMeasurement.State.PENDING - } + internalMetrics + .filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } + .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } + .map { weightedMeasurement -> weightedMeasurement.measurement } + .filter { internalMeasurement -> + internalMeasurement.state == InternalMeasurement.State.PENDING + } val anyMeasurementUpdated: Boolean = - measurementSupplier.syncInternalMeasurements( - toBeSyncedInternalMeasurements, - principal.config.apiKey, - principal, - ) + measurementSupplier.syncInternalMeasurements( + toBeSyncedInternalMeasurements, + principal.config.apiKey, + principal, + ) return batchGetMetricsResponse { metrics += - /** - * TODO(@riemanli): a potential improvement can be done by only getting the metrics whose - * measurements are updated. Re-evaluate when a load-test is ready after deployment. - */ - if (anyMeasurementUpdated) { - batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, metricIds).map { - it.toMetric(variances) - } - } else { - internalMetrics.map { it.toMetric(variances) } + /** + * TODO(@riemanli): a potential improvement can be done by only getting the metrics whose + * measurements are updated. Re-evaluate when a load-test is ready after deployment. + */ + if (anyMeasurementUpdated) { + batchGetInternalMetrics(principal.resourceKey.measurementConsumerId, metricIds).map { + it.toMetric(variances) } + } else { + internalMetrics.map { it.toMetric(variances) } + } } } override suspend fun listMetrics(request: ListMetricsRequest): ListMetricsResponse { val parentKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { - "Parent is either unspecified or invalid." - } + grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { + "Parent is either unspecified or invalid." + } val principal: ReportingPrincipal = principalFromCurrentContext when (principal) { @@ -1242,50 +1259,50 @@ class MetricsService( val apiAuthenticationKey: String = principal.config.apiKey val streamInternalMetricRequest: StreamMetricsRequest = - listMetricsPageToken.toStreamMetricsRequest() + listMetricsPageToken.toStreamMetricsRequest() val results: List = - try { - internalMetricsStub.streamMetrics(streamInternalMetricRequest).toList() - } catch (e: StatusException) { - throw Exception("Unable to list Metrics.", e) - } + try { + internalMetricsStub.streamMetrics(streamInternalMetricRequest).toList() + } catch (e: StatusException) { + throw Exception("Unable to list Metrics.", e) + } if (results.isEmpty()) { return ListMetricsResponse.getDefaultInstance() } val nextPageToken: ListMetricsPageToken? = - if (results.size > listMetricsPageToken.pageSize) { - listMetricsPageToken.copy { - lastMetric = previousPageEnd { - cmmsMeasurementConsumerId = results[results.lastIndex - 1].cmmsMeasurementConsumerId - externalMetricId = results[results.lastIndex - 1].externalMetricId - } + if (results.size > listMetricsPageToken.pageSize) { + listMetricsPageToken.copy { + lastMetric = previousPageEnd { + cmmsMeasurementConsumerId = results[results.lastIndex - 1].cmmsMeasurementConsumerId + externalMetricId = results[results.lastIndex - 1].externalMetricId } - } else { - null } + } else { + null + } val subResults: List = - results.subList(0, min(results.size, listMetricsPageToken.pageSize)) + results.subList(0, min(results.size, listMetricsPageToken.pageSize)) // Only syncs pending measurements which can only be in metrics that are still running. val toBeSyncedInternalMeasurements: List = - subResults - .filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } - .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } - .map { weightedMeasurement -> weightedMeasurement.measurement } - .filter { internalMeasurement -> - internalMeasurement.state == InternalMeasurement.State.PENDING - } + subResults + .filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } + .flatMap { internalMetric -> internalMetric.weightedMeasurementsList } + .map { weightedMeasurement -> weightedMeasurement.measurement } + .filter { internalMeasurement -> + internalMeasurement.state == InternalMeasurement.State.PENDING + } val anyMeasurementUpdated: Boolean = - measurementSupplier.syncInternalMeasurements( - toBeSyncedInternalMeasurements, - apiAuthenticationKey, - principal, - ) + measurementSupplier.syncInternalMeasurements( + toBeSyncedInternalMeasurements, + apiAuthenticationKey, + principal, + ) /** * If any measurement got updated, pull the list of the up-to-date internal metrics. Otherwise, @@ -1295,14 +1312,14 @@ class MetricsService( * measurements are updated. Re-evaluate when a load-test is ready after deployment. */ val internalMetrics: List = - if (anyMeasurementUpdated) { - batchGetInternalMetrics( - principal.resourceKey.measurementConsumerId, - subResults.map { internalMetric -> internalMetric.externalMetricId }, - ) - } else { - subResults - } + if (anyMeasurementUpdated) { + batchGetInternalMetrics( + principal.resourceKey.measurementConsumerId, + subResults.map { internalMetric -> internalMetric.externalMetricId }, + ) + } else { + subResults + } return listMetricsResponse { metrics += internalMetrics.map { it.toMetric(variances) } @@ -1315,8 +1332,8 @@ class MetricsService( /** Gets a batch of [InternalMetric]s. */ private suspend fun batchGetInternalMetrics( - cmmsMeasurementConsumerId: String, - metricIds: List, + cmmsMeasurementConsumerId: String, + metricIds: List, ): List { val batchGetMetricsRequest = batchGetMetricsRequest { this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId @@ -1332,8 +1349,8 @@ class MetricsService( /** Gets an [InternalMetric]. */ private suspend fun getInternalMetric( - cmmsMeasurementConsumerId: String, - metricId: String, + cmmsMeasurementConsumerId: String, + metricId: String, ): InternalMetric { return try { batchGetInternalMetrics(cmmsMeasurementConsumerId, listOf(metricId)).first() @@ -1345,9 +1362,9 @@ class MetricsService( override suspend fun createMetric(request: CreateMetricRequest): Metric { val parentKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { - "Parent is either unspecified or invalid." - } + grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { + "Parent is either unspecified or invalid." + } val principal: ReportingPrincipal = principalFromCurrentContext @@ -1362,26 +1379,28 @@ class MetricsService( } val internalCreateMetricRequest: InternalCreateMetricRequest = - buildInternalCreateMetricRequest(principal.resourceKey.measurementConsumerId, request) + buildInternalCreateMetricRequest(principal.resourceKey.measurementConsumerId, request) val internalMetric = - try { - internalMetricsStub.createMetric(internalCreateMetricRequest) - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.ALREADY_EXISTS -> - Status.ALREADY_EXISTS.withDescription( - "Metric with ID ${request.metricId} already exists under ${request.parent}") - Status.Code.NOT_FOUND -> - Status.NOT_FOUND.withDescription("Reporting set used in the metric not found.") - Status.Code.FAILED_PRECONDITION -> - Status.FAILED_PRECONDITION.withDescription( - "Unable to create the metric. The measurement consumer not found.") - else -> Status.UNKNOWN.withDescription("Unable to create Metric.") - } - .withCause(e) - .asRuntimeException() - } + try { + internalMetricsStub.createMetric(internalCreateMetricRequest) + } catch (e: StatusException) { + throw when (e.status.code) { + Status.Code.ALREADY_EXISTS -> + Status.ALREADY_EXISTS.withDescription( + "Metric with ID ${request.metricId} already exists under ${request.parent}" + ) + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("Reporting set used in the metric not found.") + Status.Code.FAILED_PRECONDITION -> + Status.FAILED_PRECONDITION.withDescription( + "Unable to create the metric. The measurement consumer not found." + ) + else -> Status.UNKNOWN.withDescription("Unable to create Metric.") + } + .withCause(e) + .asRuntimeException() + } if (internalMetric.state == Metric.State.RUNNING) { measurementSupplier.createCmmsMeasurements(listOf(internalMetric), principal) @@ -1392,12 +1411,12 @@ class MetricsService( } override suspend fun batchCreateMetrics( - request: BatchCreateMetricsRequest + request: BatchCreateMetricsRequest ): BatchCreateMetricsResponse { val parentKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { - "Parent is either unspecified or invalid." - } + grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { + "Parent is either unspecified or invalid." + } val principal: ReportingPrincipal = principalFromCurrentContext @@ -1422,34 +1441,36 @@ class MetricsService( } val internalCreateMetricRequestsList: List = - request.requestsList.map { createMetricRequest -> - buildInternalCreateMetricRequest(parentKey.measurementConsumerId, createMetricRequest) - } + request.requestsList.map { createMetricRequest -> + buildInternalCreateMetricRequest(parentKey.measurementConsumerId, createMetricRequest) + } val internalMetrics = - try { - internalMetricsStub - .batchCreateMetrics( - internalBatchCreateMetricsRequest { - cmmsMeasurementConsumerId = parentKey.measurementConsumerId - requests += internalCreateMetricRequestsList - }) - .metricsList - } catch (e: StatusException) { - throw when (e.status.code) { - Status.Code.NOT_FOUND -> - Status.NOT_FOUND.withDescription("Reporting set used in metrics not found.") - Status.Code.FAILED_PRECONDITION -> - Status.FAILED_PRECONDITION.withDescription( - "Unable to create the metrics. The measurement consumer not found.") - else -> Status.UNKNOWN.withDescription("Unable to create Metrics.") - } - .withCause(e) - .asRuntimeException() - } + try { + internalMetricsStub + .batchCreateMetrics( + internalBatchCreateMetricsRequest { + cmmsMeasurementConsumerId = parentKey.measurementConsumerId + requests += internalCreateMetricRequestsList + } + ) + .metricsList + } catch (e: StatusException) { + throw when (e.status.code) { + Status.Code.NOT_FOUND -> + Status.NOT_FOUND.withDescription("Reporting set used in metrics not found.") + Status.Code.FAILED_PRECONDITION -> + Status.FAILED_PRECONDITION.withDescription( + "Unable to create the metrics. The measurement consumer not found." + ) + else -> Status.UNKNOWN.withDescription("Unable to create Metrics.") + } + .withCause(e) + .asRuntimeException() + } val internalRunningMetrics = - internalMetrics.filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } + internalMetrics.filter { internalMetric -> internalMetric.state == Metric.State.RUNNING } if (internalRunningMetrics.isNotEmpty()) { measurementSupplier.createCmmsMeasurements(internalRunningMetrics, principal) } @@ -1460,8 +1481,8 @@ class MetricsService( /** Builds an [InternalCreateMetricRequest]. */ private suspend fun buildInternalCreateMetricRequest( - cmmsMeasurementConsumerId: String, - request: CreateMetricRequest, + cmmsMeasurementConsumerId: String, + request: CreateMetricRequest, ): InternalCreateMetricRequest { grpcRequire(request.hasMetric()) { "Metric is not specified." } @@ -1471,31 +1492,34 @@ class MetricsService( } grpcRequire(request.metric.hasTimeInterval()) { "Time interval in metric is not specified." } grpcRequire( - request.metric.timeInterval.startTime.seconds > 0 || - request.metric.timeInterval.startTime.nanos > 0) { - "TimeInterval startTime is unspecified." - } + request.metric.timeInterval.startTime.seconds > 0 || + request.metric.timeInterval.startTime.nanos > 0 + ) { + "TimeInterval startTime is unspecified." + } grpcRequire( - request.metric.timeInterval.endTime.seconds > 0 || - request.metric.timeInterval.endTime.nanos > 0) { - "TimeInterval endTime is unspecified." - } + request.metric.timeInterval.endTime.seconds > 0 || + request.metric.timeInterval.endTime.nanos > 0 + ) { + "TimeInterval endTime is unspecified." + } grpcRequire( - request.metric.timeInterval.endTime.seconds > - request.metric.timeInterval.startTime.seconds || - request.metric.timeInterval.endTime.nanos > - request.metric.timeInterval.startTime.nanos) { - "TimeInterval endTime is not later than startTime." - } + request.metric.timeInterval.endTime.seconds > request.metric.timeInterval.startTime.seconds || + request.metric.timeInterval.endTime.nanos > request.metric.timeInterval.startTime.nanos + ) { + "TimeInterval endTime is not later than startTime." + } grpcRequire(request.metric.hasMetricSpec()) { "Metric spec in metric is not specified." } val internalReportingSet: InternalReportingSet = - getInternalReportingSet(cmmsMeasurementConsumerId, request.metric.reportingSet) + getInternalReportingSet(cmmsMeasurementConsumerId, request.metric.reportingSet) // Utilizes the property of the set expression compilation result -- If the set expression // contains only union operators, the compilation result has to be a single component. - if (request.metric.metricSpec.hasReachAndFrequency() && - internalReportingSet.weightedSubsetUnionsList.size != 1) { + if ( + request.metric.metricSpec.hasReachAndFrequency() && + internalReportingSet.weightedSubsetUnionsList.size != 1 + ) { failGrpc(Status.INVALID_ARGUMENT) { "Reach-and-frequency metrics can only be computed on union-only set expressions." } @@ -1509,22 +1533,22 @@ class MetricsService( externalReportingSetId = internalReportingSet.externalReportingSetId timeInterval = request.metric.timeInterval metricSpec = - try { - request.metric.metricSpec.withDefaults(metricSpecConfig).toInternal() - } catch (e: MetricSpecDefaultsException) { - failGrpc(Status.INVALID_ARGUMENT) { - listOfNotNull("Invalid metric spec.", e.message, e.cause?.message) - .joinToString(separator = "\n") - } - } catch (e: Exception) { - failGrpc(Status.UNKNOWN) { "Failed to read the metric spec." } + try { + request.metric.metricSpec.withDefaults(metricSpecConfig).toInternal() + } catch (e: MetricSpecDefaultsException) { + failGrpc(Status.INVALID_ARGUMENT) { + listOfNotNull("Invalid metric spec.", e.message, e.cause?.message) + .joinToString(separator = "\n") } + } catch (e: Exception) { + failGrpc(Status.UNKNOWN) { "Failed to read the metric spec." } + } weightedMeasurements += - buildInitialInternalMeasurements( - cmmsMeasurementConsumerId, - request.metric, - internalReportingSet, - ) + buildInitialInternalMeasurements( + cmmsMeasurementConsumerId, + request.metric, + internalReportingSet, + ) details = InternalMetricKt.details { filters += request.metric.filtersList } } } @@ -1532,9 +1556,9 @@ class MetricsService( /** Builds [InternalMeasurement]s for a [Metric] over an [InternalReportingSet]. */ private fun buildInitialInternalMeasurements( - cmmsMeasurementConsumerId: String, - metric: Metric, - internalReportingSet: InternalReportingSet, + cmmsMeasurementConsumerId: String, + metric: Metric, + internalReportingSet: InternalReportingSet, ): List { return internalReportingSet.weightedSubsetUnionsList.map { weightedSubsetUnion -> weightedMeasurement { @@ -1544,9 +1568,9 @@ class MetricsService( this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId timeInterval = metric.timeInterval this.primitiveReportingSetBases += - weightedSubsetUnion.primitiveReportingSetBasesList.map { primitiveReportingSetBasis -> - primitiveReportingSetBasis.copy { filters += metric.filtersList } - } + weightedSubsetUnion.primitiveReportingSetBasesList.map { primitiveReportingSetBasis -> + primitiveReportingSetBasis.copy { filters += metric.filtersList } + } } } } @@ -1554,13 +1578,13 @@ class MetricsService( /** Gets an [InternalReportingSet] based on a reporting set name. */ private suspend fun getInternalReportingSet( - cmmsMeasurementConsumerId: String, - reportingSetName: String, + cmmsMeasurementConsumerId: String, + reportingSetName: String, ): InternalReportingSet { val reportingSetKey = - grpcRequireNotNull(ReportingSetKey.fromName(reportingSetName)) { - "Invalid reporting set name $reportingSetName." - } + grpcRequireNotNull(ReportingSetKey.fromName(reportingSetName)) { + "Invalid reporting set name $reportingSetName." + } if (reportingSetKey.cmmsMeasurementConsumerId != cmmsMeasurementConsumerId) { failGrpc(Status.PERMISSION_DENIED) { "No access to the reporting set [$reportingSetName]." } @@ -1568,17 +1592,18 @@ class MetricsService( return try { internalReportingSetsStub - .batchGetReportingSets( - batchGetReportingSetsRequest { - this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId - this.externalReportingSetIds += reportingSetKey.reportingSetId - }) - .reportingSetsList - .first() + .batchGetReportingSets( + batchGetReportingSetsRequest { + this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId + this.externalReportingSetIds += reportingSetKey.reportingSetId + } + ) + .reportingSetsList + .first() } catch (e: StatusException) { throw Exception( - "Unable to retrieve ReportingSet using the provided name [$reportingSetName].", - e, + "Unable to retrieve ReportingSet using the provided name [$reportingSetName].", + e, ) } } @@ -1595,9 +1620,9 @@ fun ListMetricsRequest.toListMetricsPageToken(): ListMetricsPageToken { grpcRequire(source.pageSize >= 0) { "Page size cannot be less than 0." } val parentKey: MeasurementConsumerKey = - grpcRequireNotNull(MeasurementConsumerKey.fromName(source.parent)) { - "Parent is either unspecified or invalid." - } + grpcRequireNotNull(MeasurementConsumerKey.fromName(source.parent)) { + "Parent is either unspecified or invalid." + } val cmmsMeasurementConsumerId = parentKey.measurementConsumerId return if (pageToken.isNotBlank()) { @@ -1613,11 +1638,11 @@ fun ListMetricsRequest.toListMetricsPageToken(): ListMetricsPageToken { } else { listMetricsPageToken { pageSize = - when { - source.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - source.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> source.pageSize - } + when { + source.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE + source.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE + else -> source.pageSize + } this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId } } @@ -1628,13 +1653,13 @@ private fun InternalMetric.toMetric(variances: Variances): Metric { val source = this return metric { name = - MetricKey( - cmmsMeasurementConsumerId = source.cmmsMeasurementConsumerId, - metricId = source.externalMetricId, - ) - .toName() + MetricKey( + cmmsMeasurementConsumerId = source.cmmsMeasurementConsumerId, + metricId = source.externalMetricId, + ) + .toName() reportingSet = - ReportingSetKey(source.cmmsMeasurementConsumerId, source.externalReportingSetId).toName() + ReportingSetKey(source.cmmsMeasurementConsumerId, source.externalReportingSetId).toName() timeInterval = source.timeInterval metricSpec = source.metricSpec.toMetricSpec() filters += source.details.filtersList @@ -1650,50 +1675,49 @@ private fun InternalMetric.toMetric(variances: Variances): Metric { private fun buildMetricResult(metric: InternalMetric, variances: Variances): MetricResult { return metricResult { cmmsMeasurements += - metric.weightedMeasurementsList.map { - MeasurementKey(metric.cmmsMeasurementConsumerId, it.measurement.cmmsMeasurementId) - .toName() - } + metric.weightedMeasurementsList.map { + MeasurementKey(metric.cmmsMeasurementConsumerId, it.measurement.cmmsMeasurementId).toName() + } @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. when (metric.metricSpec.typeCase) { InternalMetricSpec.TypeCase.REACH -> { reach = - calculateReachResult( - metric.weightedMeasurementsList, - metric.metricSpec.vidSamplingInterval, - metric.metricSpec.reach.privacyParams, - variances, - ) + calculateReachResult( + metric.weightedMeasurementsList, + metric.metricSpec.vidSamplingInterval, + metric.metricSpec.reach.privacyParams, + variances, + ) } InternalMetricSpec.TypeCase.REACH_AND_FREQUENCY -> { reachAndFrequency = reachAndFrequencyResult { reach = - calculateReachResult( - metric.weightedMeasurementsList, - metric.metricSpec.vidSamplingInterval, - metric.metricSpec.reachAndFrequency.reachPrivacyParams, - variances, - ) + calculateReachResult( + metric.weightedMeasurementsList, + metric.metricSpec.vidSamplingInterval, + metric.metricSpec.reachAndFrequency.reachPrivacyParams, + variances, + ) frequencyHistogram = - calculateFrequencyHistogramResults( - metric.weightedMeasurementsList, - metric.metricSpec, - variances, - ) + calculateFrequencyHistogramResults( + metric.weightedMeasurementsList, + metric.metricSpec, + variances, + ) } } InternalMetricSpec.TypeCase.IMPRESSION_COUNT -> { impressionCount = - calculateImpressionResult(metric.weightedMeasurementsList, metric.metricSpec, variances) + calculateImpressionResult(metric.weightedMeasurementsList, metric.metricSpec, variances) } InternalMetricSpec.TypeCase.WATCH_DURATION -> { watchDuration = - calculateWatchDurationResult( - metric.weightedMeasurementsList, - metric.metricSpec, - variances, - ) + calculateWatchDurationResult( + metric.weightedMeasurementsList, + metric.metricSpec, + variances, + ) } InternalMetricSpec.TypeCase.POPULATION_COUNT -> { populationCount = calculatePopulationResult(metric.weightedMeasurementsList) @@ -1709,7 +1733,7 @@ private fun buildMetricResult(metric: InternalMetric, variances: Variances): Met /** Aggregates a list of [InternalMeasurement.Result]s to a [InternalMeasurement.Result] */ private fun aggregateResults( - internalMeasurementResults: List + internalMeasurementResults: List ): InternalMeasurement.Result { if (internalMeasurementResults.isEmpty()) { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -1735,10 +1759,10 @@ private fun aggregateResults( } for ((frequency, percentage) in result.frequency.relativeFrequencyDistributionMap) { val previousTotalReachCount = - frequencyDistribution.getOrDefault(frequency, 0.0) * reachValue + frequencyDistribution.getOrDefault(frequency, 0.0) * reachValue val currentReachCount = percentage * result.reach.value frequencyDistribution[frequency] = - (previousTotalReachCount + currentReachCount) / (reachValue + result.reach.value) + (previousTotalReachCount + currentReachCount) / (reachValue + result.reach.value) } } if (result.hasReach()) { @@ -1761,16 +1785,16 @@ private fun aggregateResults( } if (internalMeasurementResults.first().hasFrequency()) { this.frequency = - InternalMeasurementKt.ResultKt.frequency { - relativeFrequencyDistribution.putAll(frequencyDistribution) - } + InternalMeasurementKt.ResultKt.frequency { + relativeFrequencyDistribution.putAll(frequencyDistribution) + } } if (internalMeasurementResults.first().hasImpression()) { this.impression = InternalMeasurementKt.ResultKt.impression { value = impressionValue } } if (internalMeasurementResults.first().hasWatchDuration()) { this.watchDuration = - InternalMeasurementKt.ResultKt.watchDuration { value = watchDurationValue } + InternalMeasurementKt.ResultKt.watchDuration { value = watchDurationValue } } if (internalMeasurementResults.first().hasPopulation()) { this.population = InternalMeasurementKt.ResultKt.population { value = populationValue } @@ -1780,9 +1804,9 @@ private fun aggregateResults( /** Calculates the watch duration result from [WeightedMeasurement]s. */ private fun calculateWatchDurationResult( - weightedMeasurements: List, - metricSpec: InternalMetricSpec, - variances: Variances, + weightedMeasurements: List, + metricSpec: InternalMetricSpec, + variances: Variances, ): MetricResult.WatchDurationResult { for (weightedMeasurement in weightedMeasurements) { if (weightedMeasurement.measurement.details.resultsList.any { !it.hasWatchDuration() }) { @@ -1793,24 +1817,24 @@ private fun calculateWatchDurationResult( } return watchDurationResult { val watchDuration: ProtoDuration = - weightedMeasurements - .map { weightedMeasurement -> - aggregateResults(weightedMeasurement.measurement.details.resultsList) - .watchDuration - .value * weightedMeasurement.weight - } - .reduce { sum, element -> sum + element } + weightedMeasurements + .map { weightedMeasurement -> + aggregateResults(weightedMeasurement.measurement.details.resultsList) + .watchDuration + .value * weightedMeasurement.weight + } + .reduce { sum, element -> sum + element } value = watchDuration.toDoubleSecond() // Only compute univariate statistics for union-only operations, i.e. single source measurement. if (weightedMeasurements.size == 1) { val weightedMeasurement = weightedMeasurements.first() val weightedMeasurementVarianceParamsList: - List = - buildWeightedWatchDurationMeasurementVarianceParamsPerResult( - weightedMeasurement, - metricSpec, - ) + List = + buildWeightedWatchDurationMeasurementVarianceParamsPerResult( + weightedMeasurement, + metricSpec, + ) // If any measurement result contains insufficient data for variance calculation, univariate // statistics won't be computed. @@ -1819,23 +1843,26 @@ private fun calculateWatchDurationResult( // Watch duration results in a measurement are independent to each other. The variance is // the sum of the variances of each result. standardDeviation = - sqrt( - weightedMeasurementVarianceParamsList.sumOf { weightedMeasurementVarianceParams -> - try { - variances.computeMetricVariance( - WatchDurationMetricVarianceParams( - listOf(requireNotNull(weightedMeasurementVarianceParams)))) - } catch (e: Throwable) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unable to compute variance of watch duration metric.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") - } - } - }) + sqrt( + weightedMeasurementVarianceParamsList.sumOf { weightedMeasurementVarianceParams -> + try { + variances.computeMetricVariance( + WatchDurationMetricVarianceParams( + listOf(requireNotNull(weightedMeasurementVarianceParams)) + ) + ) + } catch (e: Throwable) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unable to compute variance of watch duration metric.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") + } + } + } + ) } } } @@ -1844,11 +1871,11 @@ private fun calculateWatchDurationResult( /** Calculates the population result from [WeightedMeasurement]s. */ private fun calculatePopulationResult( - weightedMeasurements: List + weightedMeasurements: List ): MetricResult.PopulationCountResult { // Only take the first measurement because Population measurements will only have one element. val populationResult = - aggregateResults(weightedMeasurements.single().measurement.details.resultsList) + aggregateResults(weightedMeasurements.single().measurement.details.resultsList) return populationCountResult { value = populationResult.population.value } } @@ -1864,11 +1891,11 @@ private fun ProtoDuration.toDoubleSecond(): Double { * @throws io.grpc.StatusRuntimeException when measurement noise mechanism is unrecognized. */ fun buildWeightedWatchDurationMeasurementVarianceParamsPerResult( - weightedMeasurement: WeightedMeasurement, - metricSpec: MetricSpec, + weightedMeasurement: WeightedMeasurement, + metricSpec: MetricSpec, ): List { val watchDurationResults: List = - weightedMeasurement.measurement.details.resultsList.map { it.watchDuration } + weightedMeasurement.measurement.details.resultsList.map { it.watchDuration } if (watchDurationResults.isEmpty()) { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -1878,52 +1905,51 @@ fun buildWeightedWatchDurationMeasurementVarianceParamsPerResult( return watchDurationResults.map { watchDurationResult -> val statsNoiseMechanism: StatsNoiseMechanism = - try { - watchDurationResult.noiseMechanism.toStatsNoiseMechanism() - } catch (e: NoiseMechanismUnspecifiedException) { - return@map null - } catch (e: NoiseMechanismUnrecognizedException) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unrecognized noise mechanism should've been caught earlier.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") - } + try { + watchDurationResult.noiseMechanism.toStatsNoiseMechanism() + } catch (e: NoiseMechanismUnspecifiedException) { + return@map null + } catch (e: NoiseMechanismUnrecognizedException) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unrecognized noise mechanism should've been caught earlier.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") } + } val methodology: Methodology = - try { - buildStatsMethodology(watchDurationResult) - } catch (e: MeasurementVarianceNotComputableException) { - return@map null - } + try { + buildStatsMethodology(watchDurationResult) + } catch (e: MeasurementVarianceNotComputableException) { + return@map null + } WeightedWatchDurationMeasurementVarianceParams( - binaryRepresentation = weightedMeasurement.binaryRepresentation, - weight = weightedMeasurement.weight, - measurementVarianceParams = - WatchDurationMeasurementVarianceParams( - duration = max(0.0, watchDurationResult.value.toDoubleSecond()), - measurementParams = - WatchDurationMeasurementParams( - vidSamplingInterval = - metricSpec.vidSamplingInterval.toStatsVidSamplingInterval(), - dpParams = metricSpec.watchDuration.privacyParams.toNoiserDpParams(), - maximumDurationPerUser = - metricSpec.watchDuration.maximumWatchDurationPerUser.toDoubleSecond(), - noiseMechanism = statsNoiseMechanism, - ), + binaryRepresentation = weightedMeasurement.binaryRepresentation, + weight = weightedMeasurement.weight, + measurementVarianceParams = + WatchDurationMeasurementVarianceParams( + duration = max(0.0, watchDurationResult.value.toDoubleSecond()), + measurementParams = + WatchDurationMeasurementParams( + vidSamplingInterval = metricSpec.vidSamplingInterval.toStatsVidSamplingInterval(), + dpParams = metricSpec.watchDuration.privacyParams.toNoiserDpParams(), + maximumDurationPerUser = + metricSpec.watchDuration.maximumWatchDurationPerUser.toDoubleSecond(), + noiseMechanism = statsNoiseMechanism, ), - methodology = methodology, + ), + methodology = methodology, ) } } /** Builds a [Methodology] from an [InternalMeasurement.Result.WatchDuration]. */ fun buildStatsMethodology( - watchDurationResult: InternalMeasurement.Result.WatchDuration + watchDurationResult: InternalMeasurement.Result.WatchDuration ): Methodology { @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") return when (watchDurationResult.methodologyCase) { @@ -1940,7 +1966,8 @@ fun buildStatsMethodology( } CustomDirectMethodology.Variance.TypeCase.UNAVAILABLE -> { throw MeasurementVarianceNotComputableException( - "Watch duration computed from a custom methodology doesn't have variance.") + "Watch duration computed from a custom methodology doesn't have variance." + ) } CustomDirectMethodology.Variance.TypeCase.TYPE_NOT_SET -> { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -1960,9 +1987,9 @@ fun buildStatsMethodology( /** Calculates the impression result from [WeightedMeasurement]s. */ private fun calculateImpressionResult( - weightedMeasurements: List, - metricSpec: InternalMetricSpec, - variances: Variances, + weightedMeasurements: List, + metricSpec: InternalMetricSpec, + variances: Variances, ): MetricResult.ImpressionCountResult { for (weightedMeasurement in weightedMeasurements) { if (weightedMeasurement.measurement.details.resultsList.any { !it.hasImpression() }) { @@ -1974,17 +2001,17 @@ private fun calculateImpressionResult( return impressionCountResult { value = - weightedMeasurements.sumOf { weightedMeasurement -> - aggregateResults(weightedMeasurement.measurement.details.resultsList).impression.value * - weightedMeasurement.weight - } + weightedMeasurements.sumOf { weightedMeasurement -> + aggregateResults(weightedMeasurement.measurement.details.resultsList).impression.value * + weightedMeasurement.weight + } // Only compute univariate statistics for union-only operations, i.e. single source measurement. if (weightedMeasurements.size == 1) { val weightedMeasurement = weightedMeasurements.first() val weightedMeasurementVarianceParamsList: - List = - buildWeightedImpressionMeasurementVarianceParamsPerResult(weightedMeasurement, metricSpec) + List = + buildWeightedImpressionMeasurementVarianceParamsPerResult(weightedMeasurement, metricSpec) // If any measurement result contains insufficient data for variance calculation, univariate // statistics won't be computed. @@ -1993,23 +2020,26 @@ private fun calculateImpressionResult( // Impression results in a measurement are independent to each other. The variance is the // sum of the variances of each result. standardDeviation = - sqrt( - weightedMeasurementVarianceParamsList.sumOf { weightedMeasurementVarianceParams -> - try { - variances.computeMetricVariance( - ImpressionMetricVarianceParams( - listOf(requireNotNull(weightedMeasurementVarianceParams)))) - } catch (e: Throwable) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unable to compute variance of impression metric.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") - } - } - }) + sqrt( + weightedMeasurementVarianceParamsList.sumOf { weightedMeasurementVarianceParams -> + try { + variances.computeMetricVariance( + ImpressionMetricVarianceParams( + listOf(requireNotNull(weightedMeasurementVarianceParams)) + ) + ) + } catch (e: Throwable) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unable to compute variance of impression metric.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") + } + } + } + ) } } } @@ -2022,11 +2052,11 @@ private fun calculateImpressionResult( * @throws io.grpc.StatusRuntimeException when measurement noise mechanism is unrecognized. */ fun buildWeightedImpressionMeasurementVarianceParamsPerResult( - weightedMeasurement: WeightedMeasurement, - metricSpec: MetricSpec, + weightedMeasurement: WeightedMeasurement, + metricSpec: MetricSpec, ): List { val impressionResults: List = - weightedMeasurement.measurement.details.resultsList.map { it.impression } + weightedMeasurement.measurement.details.resultsList.map { it.impression } if (impressionResults.isEmpty()) { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -2036,45 +2066,43 @@ fun buildWeightedImpressionMeasurementVarianceParamsPerResult( return impressionResults.map { impressionResult -> val statsNoiseMechanism: StatsNoiseMechanism = - try { - impressionResult.noiseMechanism.toStatsNoiseMechanism() - } catch (e: NoiseMechanismUnspecifiedException) { - return@map null - } catch (e: NoiseMechanismUnrecognizedException) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unrecognized noise mechanism should've been caught earlier.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") - } + try { + impressionResult.noiseMechanism.toStatsNoiseMechanism() + } catch (e: NoiseMechanismUnspecifiedException) { + return@map null + } catch (e: NoiseMechanismUnrecognizedException) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unrecognized noise mechanism should've been caught earlier.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") } + } val methodology: Methodology = - try { - buildStatsMethodology(impressionResult) - } catch (e: MeasurementVarianceNotComputableException) { - return@map null - } + try { + buildStatsMethodology(impressionResult) + } catch (e: MeasurementVarianceNotComputableException) { + return@map null + } WeightedImpressionMeasurementVarianceParams( - binaryRepresentation = weightedMeasurement.binaryRepresentation, - weight = weightedMeasurement.weight, - measurementVarianceParams = - ImpressionMeasurementVarianceParams( - impression = max(0L, impressionResult.value), - measurementParams = - ImpressionMeasurementParams( - vidSamplingInterval = - metricSpec.vidSamplingInterval.toStatsVidSamplingInterval(), - dpParams = metricSpec.impressionCount.privacyParams.toNoiserDpParams(), - maximumFrequencyPerUser = - metricSpec.impressionCount.maximumFrequencyPerUser, - noiseMechanism = statsNoiseMechanism, - ), + binaryRepresentation = weightedMeasurement.binaryRepresentation, + weight = weightedMeasurement.weight, + measurementVarianceParams = + ImpressionMeasurementVarianceParams( + impression = max(0L, impressionResult.value), + measurementParams = + ImpressionMeasurementParams( + vidSamplingInterval = metricSpec.vidSamplingInterval.toStatsVidSamplingInterval(), + dpParams = metricSpec.impressionCount.privacyParams.toNoiserDpParams(), + maximumFrequencyPerUser = metricSpec.impressionCount.maximumFrequencyPerUser, + noiseMechanism = statsNoiseMechanism, ), - methodology = methodology, + ), + methodology = methodology, ) } } @@ -2096,7 +2124,8 @@ fun buildStatsMethodology(impressionResult: InternalMeasurement.Result.Impressio } CustomDirectMethodology.Variance.TypeCase.UNAVAILABLE -> { throw MeasurementVarianceNotComputableException( - "Impression computed from a custom methodology doesn't have variance.") + "Impression computed from a custom methodology doesn't have variance." + ) } CustomDirectMethodology.Variance.TypeCase.TYPE_NOT_SET -> { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -2116,35 +2145,37 @@ fun buildStatsMethodology(impressionResult: InternalMeasurement.Result.Impressio /** Calculates the frequency histogram result from [WeightedMeasurement]s. */ private fun calculateFrequencyHistogramResults( - weightedMeasurements: List, - metricSpec: InternalMetricSpec, - variances: Variances, + weightedMeasurements: List, + metricSpec: InternalMetricSpec, + variances: Variances, ): MetricResult.HistogramResult { val aggregatedFrequencyHistogramMap: MutableMap = - weightedMeasurements - .map { weightedMeasurement -> - if (weightedMeasurement.measurement.details.resultsList.any { - !it.hasReach() || !it.hasFrequency() - }) { - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "Reach-Frequency measurement is missing." - } - } - val result = aggregateResults(weightedMeasurement.measurement.details.resultsList) - val reach = result.reach.value - result.frequency.relativeFrequencyDistributionMap.mapValues { (_, rate) -> - rate * weightedMeasurement.weight * reach - } + weightedMeasurements + .map { weightedMeasurement -> + if ( + weightedMeasurement.measurement.details.resultsList.any { + !it.hasReach() || !it.hasFrequency() } - .fold(mutableMapOf().withDefault { 0.0 }) { - aggregatedFrequencyHistogramMap: MutableMap, - weightedFrequencyHistogramMap -> - for ((frequency, count) in weightedFrequencyHistogramMap) { - aggregatedFrequencyHistogramMap[frequency] = - aggregatedFrequencyHistogramMap.getValue(frequency) + count - } - aggregatedFrequencyHistogramMap + ) { + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "Reach-Frequency measurement is missing." } + } + val result = aggregateResults(weightedMeasurement.measurement.details.resultsList) + val reach = result.reach.value + result.frequency.relativeFrequencyDistributionMap.mapValues { (_, rate) -> + rate * weightedMeasurement.weight * reach + } + } + .fold(mutableMapOf().withDefault { 0.0 }) { + aggregatedFrequencyHistogramMap: MutableMap, + weightedFrequencyHistogramMap -> + for ((frequency, count) in weightedFrequencyHistogramMap) { + aggregatedFrequencyHistogramMap[frequency] = + aggregatedFrequencyHistogramMap.getValue(frequency) + count + } + aggregatedFrequencyHistogramMap + } // Fill the buckets that don't have any count with zeros. for (frequency in (1L..metricSpec.reachAndFrequency.maximumFrequency)) { @@ -2154,55 +2185,56 @@ private fun calculateFrequencyHistogramResults( } val weightedMeasurementVarianceParamsList: List = - weightedMeasurements.mapNotNull { weightedMeasurement -> - buildWeightedFrequencyMeasurementVarianceParams(weightedMeasurement, metricSpec, variances) - } + weightedMeasurements.mapNotNull { weightedMeasurement -> + buildWeightedFrequencyMeasurementVarianceParams(weightedMeasurement, metricSpec, variances) + } val frequencyVariances: FrequencyVariances? = - if (weightedMeasurementVarianceParamsList.size == weightedMeasurements.size) { - try { - variances.computeMetricVariance( - FrequencyMetricVarianceParams(weightedMeasurementVarianceParamsList)) - } catch (e: Throwable) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unable to compute variance of reach-frequency metric.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") - } + if (weightedMeasurementVarianceParamsList.size == weightedMeasurements.size) { + try { + variances.computeMetricVariance( + FrequencyMetricVarianceParams(weightedMeasurementVarianceParamsList) + ) + } catch (e: Throwable) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unable to compute variance of reach-frequency metric.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") } - } else { - null } + } else { + null + } return histogramResult { bins += - aggregatedFrequencyHistogramMap.map { (frequency, count) -> - bin { - label = frequency.toString() - binResult = binResult { value = count } - if (frequencyVariances != null) { - resultUnivariateStatistics = univariateStatistics { - standardDeviation = - sqrt(frequencyVariances.countVariances.getValue(frequency.toInt())) - } - relativeUnivariateStatistics = univariateStatistics { - standardDeviation = - sqrt(frequencyVariances.relativeVariances.getValue(frequency.toInt())) - } - kPlusUnivariateStatistics = univariateStatistics { - standardDeviation = - sqrt(frequencyVariances.kPlusCountVariances.getValue(frequency.toInt())) - } - relativeKPlusUnivariateStatistics = univariateStatistics { - standardDeviation = - sqrt(frequencyVariances.kPlusRelativeVariances.getValue(frequency.toInt())) - } + aggregatedFrequencyHistogramMap.map { (frequency, count) -> + bin { + label = frequency.toString() + binResult = binResult { value = count } + if (frequencyVariances != null) { + resultUnivariateStatistics = univariateStatistics { + standardDeviation = + sqrt(frequencyVariances.countVariances.getValue(frequency.toInt())) + } + relativeUnivariateStatistics = univariateStatistics { + standardDeviation = + sqrt(frequencyVariances.relativeVariances.getValue(frequency.toInt())) + } + kPlusUnivariateStatistics = univariateStatistics { + standardDeviation = + sqrt(frequencyVariances.kPlusCountVariances.getValue(frequency.toInt())) + } + relativeKPlusUnivariateStatistics = univariateStatistics { + standardDeviation = + sqrt(frequencyVariances.kPlusRelativeVariances.getValue(frequency.toInt())) } } } + } } } @@ -2214,83 +2246,81 @@ private fun calculateFrequencyHistogramResults( * @throws io.grpc.StatusRuntimeException when measurement noise mechanism is unrecognized. */ fun buildWeightedFrequencyMeasurementVarianceParams( - weightedMeasurement: WeightedMeasurement, - metricSpec: MetricSpec, - variances: Variances, + weightedMeasurement: WeightedMeasurement, + metricSpec: MetricSpec, + variances: Variances, ): WeightedFrequencyMeasurementVarianceParams? { // Get reach measurement variance params val weightedReachMeasurementVarianceParams: WeightedReachMeasurementVarianceParams = - buildWeightedReachMeasurementVarianceParams( - weightedMeasurement, - metricSpec.vidSamplingInterval, - metricSpec.reachAndFrequency.reachPrivacyParams, - ) ?: return null + buildWeightedReachMeasurementVarianceParams( + weightedMeasurement, + metricSpec.vidSamplingInterval, + metricSpec.reachAndFrequency.reachPrivacyParams, + ) ?: return null val reachMeasurementVariance: Double = - variances.computeMeasurementVariance( - weightedReachMeasurementVarianceParams.methodology, - ReachMeasurementVarianceParams( - weightedReachMeasurementVarianceParams.measurementVarianceParams.reach, - weightedReachMeasurementVarianceParams.measurementVarianceParams.measurementParams, - ), - ) + variances.computeMeasurementVariance( + weightedReachMeasurementVarianceParams.methodology, + ReachMeasurementVarianceParams( + weightedReachMeasurementVarianceParams.measurementVarianceParams.reach, + weightedReachMeasurementVarianceParams.measurementVarianceParams.measurementParams, + ), + ) val frequencyResult: InternalMeasurement.Result.Frequency = - if (weightedMeasurement.measurement.details.resultsList.size == 1) { - weightedMeasurement.measurement.details.resultsList.first().frequency - } else if (weightedMeasurement.measurement.details.resultsList.size > 1) { - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "No supported methodology generates more than one frequency result." - } - } else { - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "Frequency measurement should've had frequency results." - } + if (weightedMeasurement.measurement.details.resultsList.size == 1) { + weightedMeasurement.measurement.details.resultsList.first().frequency + } else if (weightedMeasurement.measurement.details.resultsList.size > 1) { + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "No supported methodology generates more than one frequency result." + } + } else { + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "Frequency measurement should've had frequency results." } + } val frequencyStatsNoiseMechanism: StatsNoiseMechanism = - try { - frequencyResult.noiseMechanism.toStatsNoiseMechanism() - } catch (e: NoiseMechanismUnspecifiedException) { - return null - } catch (e: NoiseMechanismUnrecognizedException) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unrecognized noise mechanism should've been caught earlier.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") - } + try { + frequencyResult.noiseMechanism.toStatsNoiseMechanism() + } catch (e: NoiseMechanismUnspecifiedException) { + return null + } catch (e: NoiseMechanismUnrecognizedException) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unrecognized noise mechanism should've been caught earlier.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") } + } val frequencyMethodology: Methodology = - try { - buildStatsMethodology(frequencyResult) - } catch (e: MeasurementVarianceNotComputableException) { - return null - } + try { + buildStatsMethodology(frequencyResult) + } catch (e: MeasurementVarianceNotComputableException) { + return null + } return WeightedFrequencyMeasurementVarianceParams( - binaryRepresentation = weightedMeasurement.binaryRepresentation, - weight = weightedMeasurement.weight, - measurementVarianceParams = - FrequencyMeasurementVarianceParams( - totalReach = weightedReachMeasurementVarianceParams.measurementVarianceParams.reach, - reachMeasurementVariance = reachMeasurementVariance, - relativeFrequencyDistribution = - frequencyResult.relativeFrequencyDistributionMap.mapKeys { it.key.toInt() }, - measurementParams = - FrequencyMeasurementParams( - vidSamplingInterval = - metricSpec.vidSamplingInterval.toStatsVidSamplingInterval(), - dpParams = - metricSpec.reachAndFrequency.frequencyPrivacyParams.toNoiserDpParams(), - noiseMechanism = frequencyStatsNoiseMechanism, - maximumFrequency = metricSpec.reachAndFrequency.maximumFrequency, - ), + binaryRepresentation = weightedMeasurement.binaryRepresentation, + weight = weightedMeasurement.weight, + measurementVarianceParams = + FrequencyMeasurementVarianceParams( + totalReach = weightedReachMeasurementVarianceParams.measurementVarianceParams.reach, + reachMeasurementVariance = reachMeasurementVariance, + relativeFrequencyDistribution = + frequencyResult.relativeFrequencyDistributionMap.mapKeys { it.key.toInt() }, + measurementParams = + FrequencyMeasurementParams( + vidSamplingInterval = metricSpec.vidSamplingInterval.toStatsVidSamplingInterval(), + dpParams = metricSpec.reachAndFrequency.frequencyPrivacyParams.toNoiserDpParams(), + noiseMechanism = frequencyStatsNoiseMechanism, + maximumFrequency = metricSpec.reachAndFrequency.maximumFrequency, ), - methodology = frequencyMethodology, + ), + methodology = frequencyMethodology, ) } @@ -2308,17 +2338,18 @@ fun buildStatsMethodology(frequencyResult: InternalMeasurement.Result.Frequency) } CustomDirectMethodology.Variance.TypeCase.FREQUENCY -> { CustomDirectFrequencyMethodology( - frequencyResult.customDirectMethodology.variance.frequency.variancesMap.mapKeys { - it.key.toInt() - }, - frequencyResult.customDirectMethodology.variance.frequency.kPlusVariancesMap.mapKeys { - it.key.toInt() - }, + frequencyResult.customDirectMethodology.variance.frequency.variancesMap.mapKeys { + it.key.toInt() + }, + frequencyResult.customDirectMethodology.variance.frequency.kPlusVariancesMap.mapKeys { + it.key.toInt() + }, ) } CustomDirectMethodology.Variance.TypeCase.UNAVAILABLE -> { throw MeasurementVarianceNotComputableException( - "Frequency computed from a custom methodology doesn't have variance.") + "Frequency computed from a custom methodology doesn't have variance." + ) } CustomDirectMethodology.Variance.TypeCase.TYPE_NOT_SET -> { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -2332,16 +2363,15 @@ fun buildStatsMethodology(frequencyResult: InternalMeasurement.Result.Frequency) } InternalMeasurement.Result.Frequency.MethodologyCase.LIQUID_LEGIONS_DISTRIBUTION -> { LiquidLegionsSketchMethodology( - decayRate = frequencyResult.liquidLegionsDistribution.decayRate, - sketchSize = frequencyResult.liquidLegionsDistribution.maxSize, + decayRate = frequencyResult.liquidLegionsDistribution.decayRate, + sketchSize = frequencyResult.liquidLegionsDistribution.maxSize, ) } InternalMeasurement.Result.Frequency.MethodologyCase.LIQUID_LEGIONS_V2 -> { LiquidLegionsV2Methodology( - decayRate = frequencyResult.liquidLegionsV2.sketchParams.decayRate, - sketchSize = frequencyResult.liquidLegionsV2.sketchParams.maxSize, - samplingIndicatorSize = - frequencyResult.liquidLegionsV2.sketchParams.samplingIndicatorSize, + decayRate = frequencyResult.liquidLegionsV2.sketchParams.decayRate, + sketchSize = frequencyResult.liquidLegionsV2.sketchParams.maxSize, + samplingIndicatorSize = frequencyResult.liquidLegionsV2.sketchParams.samplingIndicatorSize, ) } InternalMeasurement.Result.Frequency.MethodologyCase.METHODOLOGY_NOT_SET -> { @@ -2352,10 +2382,10 @@ fun buildStatsMethodology(frequencyResult: InternalMeasurement.Result.Frequency) /** Calculates the reach result from [WeightedMeasurement]s. */ private fun calculateReachResult( - weightedMeasurements: List, - vidSamplingInterval: InternalMetricSpec.VidSamplingInterval, - privacyParams: InternalMetricSpec.DifferentialPrivacyParams, - variances: Variances, + weightedMeasurements: List, + vidSamplingInterval: InternalMetricSpec.VidSamplingInterval, + privacyParams: InternalMetricSpec.DifferentialPrivacyParams, + variances: Variances, ): MetricResult.ReachResult { for (weightedMeasurement in weightedMeasurements) { if (weightedMeasurement.measurement.details.resultsList.any { !it.hasReach() }) { @@ -2367,39 +2397,41 @@ private fun calculateReachResult( return reachResult { value = - weightedMeasurements.sumOf { weightedMeasurement -> - aggregateResults(weightedMeasurement.measurement.details.resultsList).reach.value * - weightedMeasurement.weight - } + weightedMeasurements.sumOf { weightedMeasurement -> + aggregateResults(weightedMeasurement.measurement.details.resultsList).reach.value * + weightedMeasurement.weight + } val weightedMeasurementVarianceParamsList: List = - weightedMeasurements.mapNotNull { weightedMeasurement -> - buildWeightedReachMeasurementVarianceParams( - weightedMeasurement, - vidSamplingInterval, - privacyParams, - ) - } + weightedMeasurements.mapNotNull { weightedMeasurement -> + buildWeightedReachMeasurementVarianceParams( + weightedMeasurement, + vidSamplingInterval, + privacyParams, + ) + } // If any measurement contains insufficient data for variance calculation, univariate statistics // won't be computed. if (weightedMeasurementVarianceParamsList.size == weightedMeasurements.size) { univariateStatistics = univariateStatistics { standardDeviation = - sqrt( - try { - variances.computeMetricVariance( - ReachMetricVarianceParams(weightedMeasurementVarianceParamsList)) - } catch (e: Throwable) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unable to compute variance of reach metric.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") - } - }) + sqrt( + try { + variances.computeMetricVariance( + ReachMetricVarianceParams(weightedMeasurementVarianceParamsList) + ) + } catch (e: Throwable) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unable to compute variance of reach metric.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") + } + } + ) } } } @@ -2413,60 +2445,60 @@ private fun calculateReachResult( * @throws io.grpc.StatusRuntimeException when measurement noise mechanism is unrecognized. */ private fun buildWeightedReachMeasurementVarianceParams( - weightedMeasurement: WeightedMeasurement, - vidSamplingInterval: InternalMetricSpec.VidSamplingInterval, - privacyParams: InternalMetricSpec.DifferentialPrivacyParams, + weightedMeasurement: WeightedMeasurement, + vidSamplingInterval: InternalMetricSpec.VidSamplingInterval, + privacyParams: InternalMetricSpec.DifferentialPrivacyParams, ): WeightedReachMeasurementVarianceParams? { val reachResult = - if (weightedMeasurement.measurement.details.resultsList.size == 1) { - weightedMeasurement.measurement.details.resultsList.first().reach - } else if (weightedMeasurement.measurement.details.resultsList.size > 1) { - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "No supported methodology generates more than one reach result." - } - } else { - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "Reach measurement should've had reach results." - } + if (weightedMeasurement.measurement.details.resultsList.size == 1) { + weightedMeasurement.measurement.details.resultsList.first().reach + } else if (weightedMeasurement.measurement.details.resultsList.size > 1) { + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "No supported methodology generates more than one reach result." } + } else { + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "Reach measurement should've had reach results." + } + } val statsNoiseMechanism: StatsNoiseMechanism = - try { - reachResult.noiseMechanism.toStatsNoiseMechanism() - } catch (e: NoiseMechanismUnspecifiedException) { - return null - } catch (e: NoiseMechanismUnrecognizedException) { - failGrpc(Status.UNKNOWN) { - listOfNotNull( - "Unrecognized noise mechanism should've been caught earlier.", - e.message, - e.cause?.message, - ) - .joinToString(separator = "\n") - } + try { + reachResult.noiseMechanism.toStatsNoiseMechanism() + } catch (e: NoiseMechanismUnspecifiedException) { + return null + } catch (e: NoiseMechanismUnrecognizedException) { + failGrpc(Status.UNKNOWN) { + listOfNotNull( + "Unrecognized noise mechanism should've been caught earlier.", + e.message, + e.cause?.message, + ) + .joinToString(separator = "\n") } + } val methodology: Methodology = - try { - buildStatsMethodology(reachResult) - } catch (e: MeasurementVarianceNotComputableException) { - return null - } + try { + buildStatsMethodology(reachResult) + } catch (e: MeasurementVarianceNotComputableException) { + return null + } return WeightedReachMeasurementVarianceParams( - binaryRepresentation = weightedMeasurement.binaryRepresentation, - weight = weightedMeasurement.weight, - measurementVarianceParams = - ReachMeasurementVarianceParams( - reach = max(0L, reachResult.value), - measurementParams = - ReachMeasurementParams( - vidSamplingInterval = vidSamplingInterval.toStatsVidSamplingInterval(), - dpParams = privacyParams.toNoiserDpParams(), - noiseMechanism = statsNoiseMechanism, - ), + binaryRepresentation = weightedMeasurement.binaryRepresentation, + weight = weightedMeasurement.weight, + measurementVarianceParams = + ReachMeasurementVarianceParams( + reach = max(0L, reachResult.value), + measurementParams = + ReachMeasurementParams( + vidSamplingInterval = vidSamplingInterval.toStatsVidSamplingInterval(), + dpParams = privacyParams.toNoiserDpParams(), + noiseMechanism = statsNoiseMechanism, ), - methodology = methodology, + ), + methodology = methodology, ) } @@ -2487,7 +2519,8 @@ fun buildStatsMethodology(reachResult: InternalMeasurement.Result.Reach): Method } CustomDirectMethodology.Variance.TypeCase.UNAVAILABLE -> { throw MeasurementVarianceNotComputableException( - "Reach computed from a custom methodology doesn't have variance.") + "Reach computed from a custom methodology doesn't have variance." + ) } CustomDirectMethodology.Variance.TypeCase.TYPE_NOT_SET -> { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { @@ -2501,22 +2534,22 @@ fun buildStatsMethodology(reachResult: InternalMeasurement.Result.Reach): Method } InternalMeasurement.Result.Reach.MethodologyCase.LIQUID_LEGIONS_COUNT_DISTINCT -> { LiquidLegionsSketchMethodology( - decayRate = reachResult.liquidLegionsCountDistinct.decayRate, - sketchSize = reachResult.liquidLegionsCountDistinct.maxSize, + decayRate = reachResult.liquidLegionsCountDistinct.decayRate, + sketchSize = reachResult.liquidLegionsCountDistinct.maxSize, ) } InternalMeasurement.Result.Reach.MethodologyCase.LIQUID_LEGIONS_V2 -> { LiquidLegionsV2Methodology( - decayRate = reachResult.liquidLegionsV2.sketchParams.decayRate, - sketchSize = reachResult.liquidLegionsV2.sketchParams.maxSize, - samplingIndicatorSize = reachResult.liquidLegionsV2.sketchParams.samplingIndicatorSize, + decayRate = reachResult.liquidLegionsV2.sketchParams.decayRate, + sketchSize = reachResult.liquidLegionsV2.sketchParams.maxSize, + samplingIndicatorSize = reachResult.liquidLegionsV2.sketchParams.samplingIndicatorSize, ) } InternalMeasurement.Result.Reach.MethodologyCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> { LiquidLegionsV2Methodology( - decayRate = reachResult.reachOnlyLiquidLegionsV2.sketchParams.decayRate, - sketchSize = reachResult.reachOnlyLiquidLegionsV2.sketchParams.maxSize, - samplingIndicatorSize = 0L, + decayRate = reachResult.reachOnlyLiquidLegionsV2.sketchParams.decayRate, + sketchSize = reachResult.reachOnlyLiquidLegionsV2.sketchParams.maxSize, + samplingIndicatorSize = 0L, ) } InternalMeasurement.Result.Reach.MethodologyCase.METHODOLOGY_NOT_SET -> { @@ -2529,7 +2562,7 @@ private operator fun ProtoDuration.times(weight: Int): ProtoDuration { val source = this return duration { val weightedTotalNanos: Long = - (TimeUnit.SECONDS.toNanos(source.seconds) + source.nanos) * weight + (TimeUnit.SECONDS.toNanos(source.seconds) + source.nanos) * weight seconds = TimeUnit.NANOSECONDS.toSeconds(weightedTotalNanos) nanos = (weightedTotalNanos % NANOS_PER_SECOND).toInt() } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt index 4eb8d64dfab..648e048fe19 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt @@ -33,6 +33,7 @@ import java.time.ZonedDateTime import java.time.temporal.TemporalAdjusters import java.time.zone.ZoneRulesException import kotlin.math.min +import kotlinx.coroutines.flow.asFlow import org.projectnessie.cel.Env import org.wfanet.measurement.api.v2alpha.DataProvider import org.wfanet.measurement.api.v2alpha.DataProviderKey @@ -67,7 +68,6 @@ import org.wfanet.measurement.internal.reporting.v2.getReportScheduleRequest import org.wfanet.measurement.internal.reporting.v2.listReportSchedulesRequest import org.wfanet.measurement.internal.reporting.v2.report as internalReport import org.wfanet.measurement.internal.reporting.v2.reportSchedule as internalReportSchedule -import kotlinx.coroutines.flow.asFlow import org.wfanet.measurement.internal.reporting.v2.stopReportScheduleRequest import org.wfanet.measurement.reporting.service.api.submitBatchRequests import org.wfanet.measurement.reporting.v2alpha.CreateReportScheduleRequest @@ -673,8 +673,11 @@ class ReportSchedulesService( while (externalReportingSetIdSet.isNotEmpty()) { retrievedExternalReportingSetIdSet.addAll(externalReportingSetIdSet) - submitBatchRequests(externalReportingSetIdSet.asFlow(), BATCH_GET_REPORTING_SETS_LIMIT, callRpc) { - response -> + submitBatchRequests( + externalReportingSetIdSet.asFlow(), + BATCH_GET_REPORTING_SETS_LIMIT, + callRpc, + ) { response -> externalReportingSetIdSet.clear() response.reportingSetsList } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt index b677f4f1778..a6b7d7813ff 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt @@ -35,6 +35,8 @@ import java.time.temporal.Temporal import java.time.temporal.TemporalAdjusters import java.time.zone.ZoneRulesException import kotlin.math.min +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.toList import org.projectnessie.cel.Env import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey @@ -58,8 +60,6 @@ import org.wfanet.measurement.internal.reporting.v2.batchGetMetricCalculationSpe import org.wfanet.measurement.internal.reporting.v2.createReportRequest as internalCreateReportRequest import org.wfanet.measurement.internal.reporting.v2.getReportRequest as internalGetReportRequest import org.wfanet.measurement.internal.reporting.v2.report as internalReport -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.flow import org.wfanet.measurement.reporting.service.api.submitBatchRequests import org.wfanet.measurement.reporting.service.api.v2alpha.MetadataPrincipalServerInterceptor.Companion.withPrincipalName import org.wfanet.measurement.reporting.service.api.v2alpha.ReportScheduleInfoServerInterceptor.Companion.reportScheduleInfoFromCurrentContext @@ -164,28 +164,30 @@ class ReportsService( results.subList(0, min(results.size, listReportsPageToken.pageSize)) // Get metrics. - val metricNames: Flow = - flow { - buildSet { - for (internalReport in subResults) { - for (reportingMetricEntry in internalReport.reportingMetricEntriesMap) { - for (metricCalculationSpecReportingMetrics in reportingMetricEntry.value.metricCalculationSpecReportingMetricsList) { - for (reportingMetric in metricCalculationSpecReportingMetrics.reportingMetricsList) { - val name = MetricKey( - internalReport.cmmsMeasurementConsumerId, - reportingMetric.externalMetricId - ).toName() + val metricNames: Flow = flow { + buildSet { + for (internalReport in subResults) { + for (reportingMetricEntry in internalReport.reportingMetricEntriesMap) { + for (metricCalculationSpecReportingMetrics in + reportingMetricEntry.value.metricCalculationSpecReportingMetricsList) { + for (reportingMetric in metricCalculationSpecReportingMetrics.reportingMetricsList) { + val name = + MetricKey( + internalReport.cmmsMeasurementConsumerId, + reportingMetric.externalMetricId, + ) + .toName() - if (!contains(name)) { - emit(name) - add(name) - } + if (!contains(name)) { + emit(name) + add(name) } } } } } } + } val callRpc: suspend (List) -> BatchGetMetricsResponse = { items -> batchGetMetrics(principal.resourceKey.toName(), items) @@ -254,26 +256,28 @@ class ReportsService( } // Get metrics. - val metricNames: Flow = - flow { - buildSet { - for (reportingMetricEntry in internalReport.reportingMetricEntriesMap) { - for (metricCalculationSpecReportingMetrics in reportingMetricEntry.value.metricCalculationSpecReportingMetricsList) { - for (reportingMetric in metricCalculationSpecReportingMetrics.reportingMetricsList) { - val name = MetricKey( - internalReport.cmmsMeasurementConsumerId, - reportingMetric.externalMetricId - ).toName() + val metricNames: Flow = flow { + buildSet { + for (reportingMetricEntry in internalReport.reportingMetricEntriesMap) { + for (metricCalculationSpecReportingMetrics in + reportingMetricEntry.value.metricCalculationSpecReportingMetricsList) { + for (reportingMetric in metricCalculationSpecReportingMetrics.reportingMetricsList) { + val name = + MetricKey( + internalReport.cmmsMeasurementConsumerId, + reportingMetric.externalMetricId, + ) + .toName() - if (!contains(name)) { - emit(name) - add(name) - } + if (!contains(name)) { + emit(name) + add(name) } } } } } + } val callRpc: suspend (List) -> BatchGetMetricsResponse = { items -> batchGetMetrics(principal.resourceKey.toName(), items) @@ -347,16 +351,15 @@ class ReportsService( validateTime(request.report) val externalMetricCalculationSpecIds: List = - request.report.reportingMetricEntriesList - .flatMap { reportingMetricEntry -> - reportingMetricEntry.value.metricCalculationSpecsList.map { - val key = - grpcRequireNotNull(MetricCalculationSpecKey.fromName(it)) { - "MetricCalculationSpec name $it is invalid." - } - key.metricCalculationSpecId - } + request.report.reportingMetricEntriesList.flatMap { reportingMetricEntry -> + reportingMetricEntry.value.metricCalculationSpecsList.map { + val key = + grpcRequireNotNull(MetricCalculationSpecKey.fromName(it)) { + "MetricCalculationSpec name $it is invalid." + } + key.metricCalculationSpecId } + } val externalIdToMetricCalculationSpecMap: Map = createExternalIdToMetricCalculationSpecMap( @@ -409,25 +412,26 @@ class ReportsService( } // Create metrics. - val createMetricRequests: Flow = - flow { - internalReport.reportingMetricEntriesMap.flatMap { (reportingSetId, reportingMetricCalculationSpec) -> - reportingMetricCalculationSpec.metricCalculationSpecReportingMetricsList.flatMap { metricCalculationSpecReportingMetrics -> - metricCalculationSpecReportingMetrics.reportingMetricsList.map { - emit( - it.toCreateMetricRequest( - principal.resourceKey, - reportingSetId, - externalIdToMetricCalculationSpecMap - .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) - .details - .filter, - ) + val createMetricRequests: Flow = flow { + internalReport.reportingMetricEntriesMap.flatMap { + (reportingSetId, reportingMetricCalculationSpec) -> + reportingMetricCalculationSpec.metricCalculationSpecReportingMetricsList.flatMap { + metricCalculationSpecReportingMetrics -> + metricCalculationSpecReportingMetrics.reportingMetricsList.map { + emit( + it.toCreateMetricRequest( + principal.resourceKey, + reportingSetId, + externalIdToMetricCalculationSpecMap + .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) + .details + .filter, ) - } + ) } } } + } val callRpc: suspend (List) -> BatchCreateMetricsResponse = { items -> batchCreateMetrics(request.parent, items) @@ -530,12 +534,11 @@ class ReportsService( if (state == Report.State.SUCCEEDED || state == Report.State.FAILED) { val externalMetricCalculationSpecIds = - internalReport.reportingMetricEntriesMap - .flatMap { reportingMetricCalculationSpec -> - reportingMetricCalculationSpec.value.metricCalculationSpecReportingMetricsList.map { - it.externalMetricCalculationSpecId - } + internalReport.reportingMetricEntriesMap.flatMap { reportingMetricCalculationSpec -> + reportingMetricCalculationSpec.value.metricCalculationSpecReportingMetricsList.map { + it.externalMetricCalculationSpecId } + } val externalIdToMetricCalculationMap: Map = createExternalIdToMetricCalculationSpecMap( diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt index f369519cf3c..e1fcb0bc774 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt @@ -156,10 +156,10 @@ class SubmitBatchRequestsTest { val result: List = submitBatchRequests( - emptyFlow(), - BATCH_GET_REPORTING_SETS_LIMIT, - ::batchGetReportingSets, - parseResponse, + emptyFlow(), + BATCH_GET_REPORTING_SETS_LIMIT, + ::batchGetReportingSets, + parseResponse, ) .toList() .flatten() From dac02eda757546b5708b9505a9cad756ee6be1e2 Mon Sep 17 00:00:00 2001 From: Tristan Vuong Date: Mon, 26 Feb 2024 19:47:34 +0000 Subject: [PATCH 09/13] Replace collect with transform --- .../service/api/v2alpha/MetricsService.kt | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 7b95bc5b3ef..c987310df36 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -159,6 +159,7 @@ import org.wfanet.measurement.measurementconsumer.stats.LiquidLegionsSketchMetho import org.wfanet.measurement.measurementconsumer.stats.LiquidLegionsV2Methodology import org.wfanet.measurement.measurementconsumer.stats.Methodology import org.wfanet.measurement.measurementconsumer.stats.NoiseMechanism as StatsNoiseMechanism +import kotlinx.coroutines.flow.transform import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementParams import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementVarianceParams import org.wfanet.measurement.measurementconsumer.stats.ReachMetricVarianceParams @@ -805,20 +806,26 @@ class MetricsService( // Most Measurements are expected to be SUCCEEDED so SUCCEEDED Measurements will be collected // via a Flow. - val succeededMeasurements: Flow = flow { - getCmmsMeasurements(internalMeasurements, principal).collect { measurements -> + val succeededMeasurements: Flow = + getCmmsMeasurements(internalMeasurements, principal).transform { measurements -> for (measurement in measurements) { @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields cannot be null. when (measurement.state) { Measurement.State.SUCCEEDED -> emit(measurement) Measurement.State.CANCELLED, - Measurement.State.FAILED -> failedMeasurements.add(measurement) + Measurement.State.FAILED + -> failedMeasurements.add(measurement) + Measurement.State.COMPUTING, - Measurement.State.AWAITING_REQUISITION_FULFILLMENT -> {} + Measurement.State.AWAITING_REQUISITION_FULFILLMENT + -> { + } + Measurement.State.STATE_UNSPECIFIED -> failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { "The CMMS measurement state should've been set." } + Measurement.State.UNRECOGNIZED -> { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { "Unrecognized CMMS measurement state." @@ -827,7 +834,6 @@ class MetricsService( } } } - } var anyUpdate = false From 265ac8037ddafe2323b44090f3715c04d2886c8c Mon Sep 17 00:00:00 2001 From: Tristan Vuong Date: Mon, 26 Feb 2024 20:13:40 +0000 Subject: [PATCH 10/13] lint fix --- .../reporting/service/api/v2alpha/MetricsService.kt | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index c987310df36..ee61773987a 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -55,6 +55,7 @@ import kotlinx.coroutines.flow.flatMapMerge import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.flow.transform import kotlinx.coroutines.withContext import org.jetbrains.annotations.BlockingExecutor import org.jetbrains.annotations.NonBlockingExecutor @@ -159,7 +160,6 @@ import org.wfanet.measurement.measurementconsumer.stats.LiquidLegionsSketchMetho import org.wfanet.measurement.measurementconsumer.stats.LiquidLegionsV2Methodology import org.wfanet.measurement.measurementconsumer.stats.Methodology import org.wfanet.measurement.measurementconsumer.stats.NoiseMechanism as StatsNoiseMechanism -import kotlinx.coroutines.flow.transform import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementParams import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementVarianceParams import org.wfanet.measurement.measurementconsumer.stats.ReachMetricVarianceParams @@ -813,19 +813,13 @@ class MetricsService( when (measurement.state) { Measurement.State.SUCCEEDED -> emit(measurement) Measurement.State.CANCELLED, - Measurement.State.FAILED - -> failedMeasurements.add(measurement) - + Measurement.State.FAILED -> failedMeasurements.add(measurement) Measurement.State.COMPUTING, - Measurement.State.AWAITING_REQUISITION_FULFILLMENT - -> { - } - + Measurement.State.AWAITING_REQUISITION_FULFILLMENT -> {} Measurement.State.STATE_UNSPECIFIED -> failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { "The CMMS measurement state should've been set." } - Measurement.State.UNRECOGNIZED -> { failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { "Unrecognized CMMS measurement state." From e1bb6211c300c812cce96151d1e312bce2671128 Mon Sep 17 00:00:00 2001 From: Tristan Vuong Date: Mon, 26 Feb 2024 20:32:33 +0000 Subject: [PATCH 11/13] refactor --- .../service/api/v2alpha/MetricsService.kt | 14 +++++--- .../service/api/v2alpha/ReportsService.kt | 32 +++++++++---------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index ee61773987a..41cfe25d8f7 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -51,7 +51,6 @@ import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.count -import kotlinx.coroutines.flow.flatMapMerge import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList @@ -160,6 +159,7 @@ import org.wfanet.measurement.measurementconsumer.stats.LiquidLegionsSketchMetho import org.wfanet.measurement.measurementconsumer.stats.LiquidLegionsV2Methodology import org.wfanet.measurement.measurementconsumer.stats.Methodology import org.wfanet.measurement.measurementconsumer.stats.NoiseMechanism as StatsNoiseMechanism +import kotlinx.coroutines.flow.flattenMerge import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementParams import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementVarianceParams import org.wfanet.measurement.measurementconsumer.stats.ReachMetricVarianceParams @@ -314,12 +314,15 @@ class MetricsService( // Gets all external IDs of primitive reporting sets from the metric list. val externalPrimitiveReportingSetIds: Flow = flow { - buildSet { + buildSet { for (internalMetric in internalMetricsList) { for (weightedMeasurement in internalMetric.weightedMeasurementsList) { for (primitiveReportingSetBasis in weightedMeasurement.measurement.primitiveReportingSetBasesList) { + // Checks if the set already contains the ID if (!contains(primitiveReportingSetBasis.externalReportingSetId)) { + // If the set doesn't contain the ID, emit it and add it to the set so it won't + // get emitted again. emit(primitiveReportingSetBasis.externalReportingSetId) add(primitiveReportingSetBasis.externalReportingSetId) } @@ -397,7 +400,8 @@ class MetricsService( ) { response: BatchCreateMeasurementsResponse -> response.measurementsList } - .flatMapMerge { it.asFlow() } + .map { it.asFlow() } + .flattenMerge() // Set CMMS measurement IDs. val callBatchSetCmmsMeasurementIdsRpc: @@ -941,8 +945,10 @@ class MetricsService( internalMeasurement.cmmsMeasurementId, ) .toName() - + // Checks if the set already contains the name if (!contains(name)) { + // If the set doesn't contain the name, emit it and add it to the set so it won't + // get emitted again. emit(name) add(name) } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt index a6b7d7813ff..3a543f08b7f 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt @@ -60,6 +60,10 @@ import org.wfanet.measurement.internal.reporting.v2.batchGetMetricCalculationSpe import org.wfanet.measurement.internal.reporting.v2.createReportRequest as internalCreateReportRequest import org.wfanet.measurement.internal.reporting.v2.getReportRequest as internalGetReportRequest import org.wfanet.measurement.internal.reporting.v2.report as internalReport +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.flatMapMerge +import kotlinx.coroutines.flow.map import org.wfanet.measurement.reporting.service.api.submitBatchRequests import org.wfanet.measurement.reporting.service.api.v2alpha.MetadataPrincipalServerInterceptor.Companion.withPrincipalName import org.wfanet.measurement.reporting.service.api.v2alpha.ReportScheduleInfoServerInterceptor.Companion.reportScheduleInfoFromCurrentContext @@ -412,26 +416,22 @@ class ReportsService( } // Create metrics. - val createMetricRequests: Flow = flow { - internalReport.reportingMetricEntriesMap.flatMap { - (reportingSetId, reportingMetricCalculationSpec) -> - reportingMetricCalculationSpec.metricCalculationSpecReportingMetricsList.flatMap { - metricCalculationSpecReportingMetrics -> - metricCalculationSpecReportingMetrics.reportingMetricsList.map { - emit( - it.toCreateMetricRequest( - principal.resourceKey, - reportingSetId, - externalIdToMetricCalculationSpecMap - .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) - .details - .filter, - ) + val createMetricRequests: Flow = + @OptIn(ExperimentalCoroutinesApi::class) + internalReport.reportingMetricEntriesMap.entries.asFlow().flatMapMerge { entry -> + entry.value.metricCalculationSpecReportingMetricsList.asFlow().flatMapMerge { metricCalculationSpecReportingMetrics -> + metricCalculationSpecReportingMetrics.reportingMetricsList.asFlow().map { + it.toCreateMetricRequest( + principal.resourceKey, + entry.key, + externalIdToMetricCalculationSpecMap + .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) + .details + .filter, ) } } } - } val callRpc: suspend (List) -> BatchCreateMetricsResponse = { items -> batchCreateMetrics(request.parent, items) From 72bc1315ae4e468c21fe3636433ae1e73ca7b968 Mon Sep 17 00:00:00 2001 From: Tristan Vuong Date: Mon, 26 Feb 2024 20:34:16 +0000 Subject: [PATCH 12/13] lint fix --- .../reporting/service/api/v2alpha/ReportsService.kt | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt index 3a543f08b7f..2639b0f8f34 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt @@ -35,8 +35,12 @@ import java.time.temporal.Temporal import java.time.temporal.TemporalAdjusters import java.time.zone.ZoneRulesException import kotlin.math.min +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.flatMapMerge import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList import org.projectnessie.cel.Env import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey @@ -60,10 +64,6 @@ import org.wfanet.measurement.internal.reporting.v2.batchGetMetricCalculationSpe import org.wfanet.measurement.internal.reporting.v2.createReportRequest as internalCreateReportRequest import org.wfanet.measurement.internal.reporting.v2.getReportRequest as internalGetReportRequest import org.wfanet.measurement.internal.reporting.v2.report as internalReport -import kotlinx.coroutines.ExperimentalCoroutinesApi -import kotlinx.coroutines.flow.asFlow -import kotlinx.coroutines.flow.flatMapMerge -import kotlinx.coroutines.flow.map import org.wfanet.measurement.reporting.service.api.submitBatchRequests import org.wfanet.measurement.reporting.service.api.v2alpha.MetadataPrincipalServerInterceptor.Companion.withPrincipalName import org.wfanet.measurement.reporting.service.api.v2alpha.ReportScheduleInfoServerInterceptor.Companion.reportScheduleInfoFromCurrentContext @@ -419,7 +419,8 @@ class ReportsService( val createMetricRequests: Flow = @OptIn(ExperimentalCoroutinesApi::class) internalReport.reportingMetricEntriesMap.entries.asFlow().flatMapMerge { entry -> - entry.value.metricCalculationSpecReportingMetricsList.asFlow().flatMapMerge { metricCalculationSpecReportingMetrics -> + entry.value.metricCalculationSpecReportingMetricsList.asFlow().flatMapMerge { + metricCalculationSpecReportingMetrics -> metricCalculationSpecReportingMetrics.reportingMetricsList.asFlow().map { it.toCreateMetricRequest( principal.resourceKey, From 77b11203935bde78ff70855a04978afcb30e3ce3 Mon Sep 17 00:00:00 2001 From: Tristan Vuong Date: Mon, 26 Feb 2024 20:35:53 +0000 Subject: [PATCH 13/13] lint fix --- .../measurement/reporting/service/api/v2alpha/MetricsService.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 41cfe25d8f7..87bb6876b6d 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -51,6 +51,7 @@ import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.count +import kotlinx.coroutines.flow.flattenMerge import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList @@ -159,7 +160,6 @@ import org.wfanet.measurement.measurementconsumer.stats.LiquidLegionsSketchMetho import org.wfanet.measurement.measurementconsumer.stats.LiquidLegionsV2Methodology import org.wfanet.measurement.measurementconsumer.stats.Methodology import org.wfanet.measurement.measurementconsumer.stats.NoiseMechanism as StatsNoiseMechanism -import kotlinx.coroutines.flow.flattenMerge import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementParams import org.wfanet.measurement.measurementconsumer.stats.ReachMeasurementVarianceParams import org.wfanet.measurement.measurementconsumer.stats.ReachMetricVarianceParams