From c9e982c43788d7565343926af9ab41b371e7d5cf Mon Sep 17 00:00:00 2001 From: Tristan Vuong <85768771+tristanvuong2021@users.noreply.github.com> Date: Fri, 8 Mar 2024 13:34:59 -0800 Subject: [PATCH] Move reporting sets call out of build internal create metric request method (#1499) --- .../service/api/v2alpha/MetricsService.kt | 114 ++++++++++++------ 1 file changed, 79 insertions(+), 35 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt index 764d21713b5..6c162266325 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha/MetricsService.kt @@ -1384,8 +1384,20 @@ class MetricsService( } } + grpcRequire(request.hasMetric()) { "Metric is not specified." } + + val batchGetReportingSetsResponse = + batchGetInternalReportingSets( + parentKey.measurementConsumerId, + listOf(request.metric.reportingSet), + ) + val internalCreateMetricRequest: InternalCreateMetricRequest = - buildInternalCreateMetricRequest(principal.resourceKey.measurementConsumerId, request) + buildInternalCreateMetricRequest( + principal.resourceKey.measurementConsumerId, + request, + batchGetReportingSetsResponse.reportingSetsList.first(), + ) val internalMetric = try { @@ -1446,9 +1458,48 @@ class MetricsService( "Duplicate metric IDs in the request." } - val internalCreateMetricRequestsList: List = - request.requestsList.map { createMetricRequest -> - buildInternalCreateMetricRequest(parentKey.measurementConsumerId, createMetricRequest) + val reportingSetNames = + request.requestsList + .map { + grpcRequire(it.hasMetric()) { "Metric is not specified." } + + it.metric.reportingSet + } + .distinct() + + val callRpc: suspend (List) -> BatchGetReportingSetsResponse = { items -> + batchGetInternalReportingSets(parentKey.measurementConsumerId, items) + } + + val reportingSetNameToInternalReportingSetMap: Map = buildMap { + submitBatchRequests(reportingSetNames.asFlow(), BATCH_GET_REPORTING_SETS_LIMIT, callRpc) { + response -> + response.reportingSetsList + } + .collect { reportingSetsList -> + for (reportingSet in reportingSetsList) { + putIfAbsent( + ReportingSetKey(parentKey.measurementConsumerId, reportingSet.externalReportingSetId) + .toName(), + reportingSet, + ) + } + } + } + + val internalCreateMetricRequestsList: List> = + coroutineScope { + request.requestsList.map { createMetricRequest -> + async { + buildInternalCreateMetricRequest( + parentKey.measurementConsumerId, + createMetricRequest, + reportingSetNameToInternalReportingSetMap.getValue( + createMetricRequest.metric.reportingSet + ), + ) + } + } } val internalMetrics = @@ -1457,7 +1508,7 @@ class MetricsService( .batchCreateMetrics( internalBatchCreateMetricsRequest { cmmsMeasurementConsumerId = parentKey.measurementConsumerId - requests += internalCreateMetricRequestsList + requests += internalCreateMetricRequestsList.awaitAll() } ) .metricsList @@ -1486,12 +1537,11 @@ class MetricsService( } /** Builds an [InternalCreateMetricRequest]. */ - private suspend fun buildInternalCreateMetricRequest( + private fun buildInternalCreateMetricRequest( cmmsMeasurementConsumerId: String, request: CreateMetricRequest, + internalReportingSet: InternalReportingSet, ): InternalCreateMetricRequest { - grpcRequire(request.hasMetric()) { "Metric is not specified." } - grpcRequire(request.metricId.matches(RESOURCE_ID_REGEX)) { "Metric ID is invalid." } grpcRequire(request.metric.reportingSet.isNotEmpty()) { "Reporting set in metric is not specified." @@ -1517,9 +1567,6 @@ class MetricsService( } grpcRequire(request.metric.hasMetricSpec()) { "Metric spec in metric is not specified." } - val internalReportingSet: InternalReportingSet = - getInternalReportingSet(cmmsMeasurementConsumerId, request.metric.reportingSet) - // Utilizes the property of the set expression compilation result -- If the set expression // contains only union operators, the compilation result has to be a single component. if ( @@ -1582,35 +1629,32 @@ class MetricsService( } } - /** Gets an [InternalReportingSet] based on a reporting set name. */ - private suspend fun getInternalReportingSet( + /** Batch get [InternalReportingSet]s based on [ReportingSet] names. */ + private suspend fun batchGetInternalReportingSets( cmmsMeasurementConsumerId: String, - reportingSetName: String, - ): InternalReportingSet { - val reportingSetKey = - grpcRequireNotNull(ReportingSetKey.fromName(reportingSetName)) { - "Invalid reporting set name $reportingSetName." - } + reportingSetNames: List, + ): BatchGetReportingSetsResponse { + val externalReportingSetIds: List = + reportingSetNames.map { + val reportingSetKey = + grpcRequireNotNull(ReportingSetKey.fromName(it)) { "Invalid reporting set name $it." } + + if (reportingSetKey.cmmsMeasurementConsumerId != cmmsMeasurementConsumerId) { + failGrpc(Status.PERMISSION_DENIED) { "No access to the reporting set [$it]." } + } - if (reportingSetKey.cmmsMeasurementConsumerId != cmmsMeasurementConsumerId) { - failGrpc(Status.PERMISSION_DENIED) { "No access to the reporting set [$reportingSetName]." } - } + reportingSetKey.reportingSetId + } return try { - internalReportingSetsStub - .batchGetReportingSets( - batchGetReportingSetsRequest { - this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId - this.externalReportingSetIds += reportingSetKey.reportingSetId - } - ) - .reportingSetsList - .first() - } catch (e: StatusException) { - throw Exception( - "Unable to retrieve ReportingSet using the provided name [$reportingSetName].", - e, + internalReportingSetsStub.batchGetReportingSets( + batchGetReportingSetsRequest { + this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId + this.externalReportingSetIds += externalReportingSetIds + } ) + } catch (e: StatusException) { + throw Exception("Unable to retrieve ReportingSets using the provided names.", e) } }