diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt index 12fdae3c352..c270a7a67d3 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequests.kt @@ -16,11 +16,13 @@ package org.wfanet.measurement.reporting.service.api -import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.async +import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.asFlow -import kotlinx.coroutines.flow.flatMapConcat import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.sync.Semaphore +import kotlinx.coroutines.sync.withPermit class BatchRequestException(message: String? = null, cause: Throwable? = null) : Exception(message, cause) @@ -47,14 +49,17 @@ fun Flow.chunked(chunkSize: Int): Flow> { } } -/** Submits multiple RPCs by dividing the input items to batches. */ -@OptIn(ExperimentalCoroutinesApi::class) // For `flatMapConcat`. +/** + * Submits multiple RPCs by dividing the input items to batches. + * + * @return [Flow] that emits [List]s containing the results of the multiple RPCs. + */ suspend fun submitBatchRequests( items: Flow, limit: Int, callRpc: suspend (List) -> RESP, parseResponse: (RESP) -> List, -): Flow { +): Flow> { if (limit <= 0) { throw BatchRequestException( "Invalid limit", @@ -62,5 +67,21 @@ suspend fun submitBatchRequests( ) } - return items.chunked(limit).flatMapConcat { batch -> parseResponse(callRpc(batch)).asFlow() } + // For network requests, the number of concurrent coroutines needs to be capped. To be on the safe + // side, a low number is chosen. + val batchSemaphore = Semaphore(3) + return flow { + coroutineScope { + val deferred: List>> = buildList { + items.chunked(limit).collect { batch: List -> + // The batch reference is reused for every collect call. To ensure async works, a copy + // of the contents needs to be saved in a new reference. + val tempBatch = batch.toList() + add(async { batchSemaphore.withPermit { parseResponse(callRpc(tempBatch)) } }) + } + } + + deferred.forEach { emit(it.await()) } + } + } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index fb02e53faf5..87bb6876b6d 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -43,14 +43,19 @@ import kotlin.math.sqrt import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.Deferred import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.asExecutor import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.count +import kotlinx.coroutines.flow.flattenMerge +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.flow.transform import kotlinx.coroutines.withContext import org.jetbrains.annotations.BlockingExecutor import org.jetbrains.annotations.NonBlockingExecutor @@ -308,17 +313,24 @@ class MetricsService( val measurementConsumer: MeasurementConsumer = getMeasurementConsumer(principal) // Gets all external IDs of primitive reporting sets from the metric list. - val externalPrimitiveReportingSetIds: Flow = - internalMetricsList - .flatMap { internalMetric -> - internalMetric.weightedMeasurementsList.flatMap { weightedMeasurement -> - weightedMeasurement.measurement.primitiveReportingSetBasesList.map { - it.externalReportingSetId + val externalPrimitiveReportingSetIds: Flow = flow { + buildSet { + for (internalMetric in internalMetricsList) { + for (weightedMeasurement in internalMetric.weightedMeasurementsList) { + for (primitiveReportingSetBasis in + weightedMeasurement.measurement.primitiveReportingSetBasesList) { + // Checks if the set already contains the ID + if (!contains(primitiveReportingSetBasis.externalReportingSetId)) { + // If the set doesn't contain the ID, emit it and add it to the set so it won't + // get emitted again. + emit(primitiveReportingSetBasis.externalReportingSetId) + add(primitiveReportingSetBasis.externalReportingSetId) + } } } } - .distinct() - .asFlow() + } + } val callBatchGetInternalReportingSetsRpc: suspend (List) -> BatchGetReportingSetsResponse = @@ -326,7 +338,7 @@ class MetricsService( batchGetInternalReportingSets(principal.resourceKey.measurementConsumerId, items) } - val internalPrimitiveReportingSetMap: Map = + val internalPrimitiveReportingSetMap: Map = buildMap { submitBatchRequests( externalPrimitiveReportingSetIds, BATCH_GET_REPORTING_SETS_LIMIT, @@ -334,8 +346,12 @@ class MetricsService( ) { response: BatchGetReportingSetsResponse -> response.reportingSetsList } - .toList() - .associateBy { it.externalReportingSetId } + .collect { reportingSets: List -> + for (reportingSet in reportingSets) { + computeIfAbsent(reportingSet.externalReportingSetId) { reportingSet } + } + } + } val dataProviderNames = mutableSetOf() for (internalPrimitiveReportingSet in internalPrimitiveReportingSetMap.values) { @@ -348,22 +364,25 @@ class MetricsService( val measurementConsumerSigningKey = getMeasurementConsumerSigningKey(principal) - val cmmsCreateMeasurementRequests: List = - internalMetricsList.flatMap { internalMetric -> - internalMetric.weightedMeasurementsList - .filter { it.measurement.cmmsMeasurementId.isBlank() } - .map { - buildCreateMeasurementRequest( - it.measurement, - internalMetric.metricSpec, - internalPrimitiveReportingSetMap, - measurementConsumer, - principal, - dataProviderInfoMap, - measurementConsumerSigningKey, + val cmmsCreateMeasurementRequests: Flow = flow { + for (internalMetric in internalMetricsList) { + for (weightedMeasurement in internalMetric.weightedMeasurementsList) { + if (weightedMeasurement.measurement.cmmsMeasurementId.isBlank()) { + emit( + buildCreateMeasurementRequest( + weightedMeasurement.measurement, + internalMetric.metricSpec, + internalPrimitiveReportingSetMap, + measurementConsumer, + principal, + dataProviderInfoMap, + measurementConsumerSigningKey, + ) ) } + } } + } // Create CMMS measurements. val callBatchCreateMeasurementsRpc: @@ -372,14 +391,17 @@ class MetricsService( batchCreateCmmsMeasurements(principal, items) } + @OptIn(ExperimentalCoroutinesApi::class) val cmmsMeasurements: Flow = submitBatchRequests( - cmmsCreateMeasurementRequests.asFlow(), - BATCH_KINGDOM_MEASUREMENTS_LIMIT, - callBatchCreateMeasurementsRpc, - ) { response: BatchCreateMeasurementsResponse -> - response.measurementsList - } + cmmsCreateMeasurementRequests, + BATCH_KINGDOM_MEASUREMENTS_LIMIT, + callBatchCreateMeasurementsRpc, + ) { response: BatchCreateMeasurementsResponse -> + response.measurementsList + } + .map { it.asFlow() } + .flattenMerge() // Set CMMS measurement IDs. val callBatchSetCmmsMeasurementIdsRpc: @@ -400,7 +422,7 @@ class MetricsService( ) { response: BatchSetCmmsMeasurementIdsResponse -> response.measurementsList } - .toList() + .collect {} } /** Sets a batch of CMMS [MeasurementIds] to the [InternalMeasurement] table. */ @@ -784,65 +806,70 @@ class MetricsService( apiAuthenticationKey: String, principal: MeasurementConsumerPrincipal, ): Boolean { - val newStateToCmmsMeasurements: Map> = - getCmmsMeasurements(internalMeasurements, principal).groupBy { measurement -> - measurement.state + val failedMeasurements: MutableList = mutableListOf() + + // Most Measurements are expected to be SUCCEEDED so SUCCEEDED Measurements will be collected + // via a Flow. + val succeededMeasurements: Flow = + getCmmsMeasurements(internalMeasurements, principal).transform { measurements -> + for (measurement in measurements) { + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields cannot be null. + when (measurement.state) { + Measurement.State.SUCCEEDED -> emit(measurement) + Measurement.State.CANCELLED, + Measurement.State.FAILED -> failedMeasurements.add(measurement) + Measurement.State.COMPUTING, + Measurement.State.AWAITING_REQUISITION_FULFILLMENT -> {} + Measurement.State.STATE_UNSPECIFIED -> + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "The CMMS measurement state should've been set." + } + Measurement.State.UNRECOGNIZED -> { + failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { + "Unrecognized CMMS measurement state." + } + } + } + } } var anyUpdate = false - for ((newState, measurementsList) in newStateToCmmsMeasurements) { - when (newState) { - Measurement.State.SUCCEEDED -> { - val callBatchSetInternalMeasurementResultsRpc: - suspend (List) -> BatchSetCmmsMeasurementResultsResponse = - { items -> - batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal) - } - submitBatchRequests( - measurementsList.asFlow(), - BATCH_SET_MEASUREMENT_RESULTS_LIMIT, - callBatchSetInternalMeasurementResultsRpc, - ) { response: BatchSetCmmsMeasurementResultsResponse -> - response.measurementsList - } - .toList() - - anyUpdate = true + val callBatchSetInternalMeasurementResultsRpc: + suspend (List) -> BatchSetCmmsMeasurementResultsResponse = + { items -> + batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal) + } + val count = + submitBatchRequests( + succeededMeasurements, + BATCH_SET_MEASUREMENT_RESULTS_LIMIT, + callBatchSetInternalMeasurementResultsRpc, + ) { response: BatchSetCmmsMeasurementResultsResponse -> + response.measurementsList } - Measurement.State.AWAITING_REQUISITION_FULFILLMENT, - Measurement.State.COMPUTING -> {} // Do nothing. - Measurement.State.FAILED, - Measurement.State.CANCELLED -> { - val callBatchSetInternalMeasurementFailuresRpc: - suspend (List) -> BatchSetCmmsMeasurementFailuresResponse = - { items -> - batchSetInternalMeasurementFailures( - items, - principal.resourceKey.measurementConsumerId, - ) - } - submitBatchRequests( - measurementsList.asFlow(), - BATCH_SET_MEASUREMENT_FAILURES_LIMIT, - callBatchSetInternalMeasurementFailuresRpc, - ) { response: BatchSetCmmsMeasurementFailuresResponse -> - response.measurementsList - } - .toList() + .count() + + if (count > 0) { + anyUpdate = true + } - anyUpdate = true + if (failedMeasurements.isNotEmpty()) { + val callBatchSetInternalMeasurementFailuresRpc: + suspend (List) -> BatchSetCmmsMeasurementFailuresResponse = + { items -> + batchSetInternalMeasurementFailures(items, principal.resourceKey.measurementConsumerId) } - Measurement.State.STATE_UNSPECIFIED -> - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "The CMMS measurement state should've been set." - } - Measurement.State.UNRECOGNIZED -> { - failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) { - "Unrecognized CMMS measurement state." - } + submitBatchRequests( + failedMeasurements.asFlow(), + BATCH_SET_MEASUREMENT_FAILURES_LIMIT, + callBatchSetInternalMeasurementFailuresRpc, + ) { response: BatchSetCmmsMeasurementFailuresResponse -> + response.measurementsList } - } + .collect {} + + anyUpdate = true } return anyUpdate @@ -908,17 +935,26 @@ class MetricsService( private suspend fun getCmmsMeasurements( internalMeasurements: List, principal: MeasurementConsumerPrincipal, - ): List { - val measurementNames: List = - internalMeasurements - .map { internalMeasurement -> - MeasurementKey( - principal.resourceKey.measurementConsumerId, - internalMeasurement.cmmsMeasurementId, - ) - .toName() + ): Flow> { + val measurementNames: Flow = flow { + buildSet { + for (internalMeasurement in internalMeasurements) { + val name = + MeasurementKey( + principal.resourceKey.measurementConsumerId, + internalMeasurement.cmmsMeasurementId, + ) + .toName() + // Checks if the set already contains the name + if (!contains(name)) { + // If the set doesn't contain the name, emit it and add it to the set so it won't + // get emitted again. + emit(name) + add(name) + } } - .distinct() + } + } val callBatchGetMeasurementsRpc: suspend (List) -> BatchGetMeasurementsResponse = { items -> @@ -926,13 +962,12 @@ class MetricsService( } return submitBatchRequests( - measurementNames.asFlow(), - BATCH_KINGDOM_MEASUREMENTS_LIMIT, - callBatchGetMeasurementsRpc, - ) { response: BatchGetMeasurementsResponse -> - response.measurementsList - } - .toList() + measurementNames, + BATCH_KINGDOM_MEASUREMENTS_LIMIT, + callBatchGetMeasurementsRpc, + ) { response: BatchGetMeasurementsResponse -> + response.measurementsList + } } /** Batch get CMMS measurements. */ diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt index bdde1aa2af9..2b6391e54cf 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportSchedulesService.kt @@ -691,23 +691,27 @@ class ReportSchedulesService( externalReportingSetIdSet.clear() response.reportingSetsList } - .collect { - if (it.hasComposite()) { - val lhsExternalReportingSetId = it.composite.lhs.externalReportingSetId - if (lhsExternalReportingSetId.isNotEmpty()) { - if (!retrievedExternalReportingSetIdSet.contains(lhsExternalReportingSetId)) { - externalReportingSetIdSet.add(lhsExternalReportingSetId) + .collect { internalReportingSets: List -> + for (internalReportingSet in internalReportingSets) { + if (internalReportingSet.hasComposite()) { + val lhsExternalReportingSetId = + internalReportingSet.composite.lhs.externalReportingSetId + if (lhsExternalReportingSetId.isNotEmpty()) { + if (!retrievedExternalReportingSetIdSet.contains(lhsExternalReportingSetId)) { + externalReportingSetIdSet.add(lhsExternalReportingSetId) + } } - } - val rhsExternalReportingSetId = it.composite.rhs.externalReportingSetId - if (rhsExternalReportingSetId.isNotEmpty()) { - if (!retrievedExternalReportingSetIdSet.contains(rhsExternalReportingSetId)) { - externalReportingSetIdSet.add(rhsExternalReportingSetId) + val rhsExternalReportingSetId = + internalReportingSet.composite.rhs.externalReportingSetId + if (rhsExternalReportingSetId.isNotEmpty()) { + if (!retrievedExternalReportingSetIdSet.contains(rhsExternalReportingSetId)) { + externalReportingSetIdSet.add(rhsExternalReportingSetId) + } } } + reportingSets.add(internalReportingSet) } - reportingSets.add(it) } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt index 702bd952140..2639b0f8f34 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/ReportsService.kt @@ -35,8 +35,12 @@ import java.time.temporal.Temporal import java.time.temporal.TemporalAdjusters import java.time.zone.ZoneRulesException import kotlin.math.min +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.flatMapMerge +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.toList import org.projectnessie.cel.Env import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey @@ -164,18 +168,44 @@ class ReportsService( results.subList(0, min(results.size, listReportsPageToken.pageSize)) // Get metrics. - val metricNames: Flow = - subResults.flatMap { internalReport -> internalReport.metricNames }.distinct().asFlow() + val metricNames: Flow = flow { + buildSet { + for (internalReport in subResults) { + for (reportingMetricEntry in internalReport.reportingMetricEntriesMap) { + for (metricCalculationSpecReportingMetrics in + reportingMetricEntry.value.metricCalculationSpecReportingMetricsList) { + for (reportingMetric in metricCalculationSpecReportingMetrics.reportingMetricsList) { + val name = + MetricKey( + internalReport.cmmsMeasurementConsumerId, + reportingMetric.externalMetricId, + ) + .toName() + + if (!contains(name)) { + emit(name) + add(name) + } + } + } + } + } + } + } val callRpc: suspend (List) -> BatchGetMetricsResponse = { items -> batchGetMetrics(principal.resourceKey.toName(), items) } - val externalIdToMetricMap: Map = + val externalIdToMetricMap: Map = buildMap { submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response -> response.metricsList } - .toList() - .associateBy { checkNotNull(MetricKey.fromName(it.name)).metricId } + .collect { metrics: List -> + for (metric in metrics) { + computeIfAbsent(checkNotNull(MetricKey.fromName(metric.name)).metricId) { metric } + } + } + } return listReportsResponse { reports += @@ -230,17 +260,42 @@ class ReportsService( } // Get metrics. - val metricNames: Flow = internalReport.metricNames.distinct().asFlow() + val metricNames: Flow = flow { + buildSet { + for (reportingMetricEntry in internalReport.reportingMetricEntriesMap) { + for (metricCalculationSpecReportingMetrics in + reportingMetricEntry.value.metricCalculationSpecReportingMetricsList) { + for (reportingMetric in metricCalculationSpecReportingMetrics.reportingMetricsList) { + val name = + MetricKey( + internalReport.cmmsMeasurementConsumerId, + reportingMetric.externalMetricId, + ) + .toName() + + if (!contains(name)) { + emit(name) + add(name) + } + } + } + } + } + } val callRpc: suspend (List) -> BatchGetMetricsResponse = { items -> batchGetMetrics(principal.resourceKey.toName(), items) } - val externalIdToMetricMap: Map = + val externalIdToMetricMap: Map = buildMap { submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response -> response.metricsList } - .toList() - .associateBy { checkNotNull(MetricKey.fromName(it.name)).metricId } + .collect { metrics: List -> + for (metric in metrics) { + computeIfAbsent(checkNotNull(MetricKey.fromName(metric.name)).metricId) { metric } + } + } + } // Convert the internal report to public and return. return convertInternalReportToPublic(internalReport, externalIdToMetricMap) @@ -309,6 +364,7 @@ class ReportsService( key.metricCalculationSpecId } } + val externalIdToMetricCalculationSpecMap: Map = createExternalIdToMetricCalculationSpecMap( parentKey.measurementConsumerId, @@ -361,34 +417,36 @@ class ReportsService( // Create metrics. val createMetricRequests: Flow = - internalReport.reportingMetricEntriesMap - .flatMap { (reportingSetId, reportingMetricCalculationSpec) -> - reportingMetricCalculationSpec.metricCalculationSpecReportingMetricsList.flatMap { - metricCalculationSpecReportingMetrics -> - metricCalculationSpecReportingMetrics.reportingMetricsList.map { - it.toCreateMetricRequest( - principal.resourceKey, - reportingSetId, - externalIdToMetricCalculationSpecMap - .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) - .details - .filter, - ) - } + @OptIn(ExperimentalCoroutinesApi::class) + internalReport.reportingMetricEntriesMap.entries.asFlow().flatMapMerge { entry -> + entry.value.metricCalculationSpecReportingMetricsList.asFlow().flatMapMerge { + metricCalculationSpecReportingMetrics -> + metricCalculationSpecReportingMetrics.reportingMetricsList.asFlow().map { + it.toCreateMetricRequest( + principal.resourceKey, + entry.key, + externalIdToMetricCalculationSpecMap + .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) + .details + .filter, + ) } } - .asFlow() + } val callRpc: suspend (List) -> BatchCreateMetricsResponse = { items -> batchCreateMetrics(request.parent, items) } - val externalIdToMetricMap: Map = - submitBatchRequests(createMetricRequests, BATCH_CREATE_METRICS_LIMIT, callRpc) { - response: BatchCreateMetricsResponse -> + val externalIdToMetricMap: Map = buildMap { + submitBatchRequests(createMetricRequests, BATCH_CREATE_METRICS_LIMIT, callRpc) { response -> response.metricsList } - .toList() - .associateBy { checkNotNull(MetricKey.fromName(it.name)).metricId } + .collect { metrics: List -> + for (metric in metrics) { + computeIfAbsent(checkNotNull(MetricKey.fromName(metric.name)).metricId) { metric } + } + } + } // Once all metrics are created, get the updated internal report with the metric IDs filled. val updatedInternalReport = @@ -482,6 +540,7 @@ class ReportsService( it.externalMetricCalculationSpecId } } + val externalIdToMetricCalculationMap: Map = createExternalIdToMetricCalculationSpecMap( internalReport.cmmsMeasurementConsumerId, @@ -834,12 +893,6 @@ private val InternalReport.externalMetricIds: List } } -private val InternalReport.metricNames: List - get() = - externalMetricIds.map { externalMetricId -> - MetricKey(cmmsMeasurementConsumerId, externalMetricId).toName() - } - /** Converts a public [ListReportsRequest] to a [ListReportsPageToken]. */ private fun ListReportsRequest.toListReportsPageToken(): ListReportsPageToken { grpcRequire(pageSize >= 0) { "Page size cannot be less than 0" } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt index f3bf4853844..1a2b6e415c9 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/ReportsServiceTest.kt @@ -30,7 +30,6 @@ import io.grpc.StatusRuntimeException import java.time.Clock import kotlin.random.Random import kotlin.test.assertFailsWith -import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.asFlow import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking @@ -635,32 +634,31 @@ abstract class ReportsServiceTest { ) var metricIndex = 0 - val createMetricsRequests: Flow = - createdReport.reportingMetricEntriesMap.entries - .flatMap { entry -> - val reportingSet = createdReportingSetsByExternalId.getValue(entry.key) - - entry.value.metricCalculationSpecReportingMetricsList.flatMap { - metricCalculationSpecReportingMetrics -> - val metricCalculationSpecFilter = - createdMetricCalculationSpecsByExternalId - .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) - .details - .filter - metricCalculationSpecReportingMetrics.reportingMetricsList.map { reportingMetric -> - val externalMetricId = "externalMetricId$metricIndex" - metricIndex++ - buildCreateMetricRequest( - createdReport.cmmsMeasurementConsumerId, - externalMetricId, - reportingSet, - reportingMetric, - metricCalculationSpecFilter, - ) - } + val createMetricsRequests: List = + createdReport.reportingMetricEntriesMap.entries.flatMap { entry -> + val reportingSet = createdReportingSetsByExternalId.getValue(entry.key) + + entry.value.metricCalculationSpecReportingMetricsList.flatMap { + metricCalculationSpecReportingMetrics -> + val metricCalculationSpecFilter = + createdMetricCalculationSpecsByExternalId + .getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId) + .details + .filter + metricCalculationSpecReportingMetrics.reportingMetricsList.map { reportingMetric -> + val externalMetricId = "externalMetricId$metricIndex" + metricIndex++ + buildCreateMetricRequest( + createdReport.cmmsMeasurementConsumerId, + externalMetricId, + reportingSet, + reportingMetric, + metricCalculationSpecFilter, + ) } } - .asFlow() + } + val callRpc: suspend (List) -> BatchCreateMetricsResponse = { items -> metricsService.batchCreateMetrics( batchCreateMetricsRequest { @@ -669,10 +667,10 @@ abstract class ReportsServiceTest { } ) } - submitBatchRequests(createMetricsRequests, MAX_BATCH_SIZE, callRpc) { response -> + submitBatchRequests(createMetricsRequests.asFlow(), MAX_BATCH_SIZE, callRpc) { response -> response.metricsList } - .toList() + .collect {} val retrievedReport = service.getReport( diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt index a9537be540f..e1fcb0bc774 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/SubmitBatchRequestsTest.kt @@ -21,7 +21,7 @@ import io.grpc.Status import io.grpc.StatusException import kotlin.math.ceil import kotlinx.coroutines.flow.asFlow -import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.emptyFlow import kotlinx.coroutines.flow.toList import kotlinx.coroutines.runBlocking import org.junit.Before @@ -104,6 +104,7 @@ class SubmitBatchRequestsTest { parseResponse, ) .toList() + .flatten() val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() @@ -136,6 +137,7 @@ class SubmitBatchRequestsTest { parseResponse, ) .toList() + .flatten() val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor() @@ -154,12 +156,13 @@ class SubmitBatchRequestsTest { val result: List = submitBatchRequests( - flow {}, + emptyFlow(), BATCH_GET_REPORTING_SETS_LIMIT, ::batchGetReportingSets, parseResponse, ) .toList() + .flatten() val batchGetReportingSetsCaptor: KArgumentCaptor = argumentCaptor()