Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SubmitBatchRequests to use Coroutines #1467

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
9898ebf
Refactor SubmitBatchRequests to use Coroutines
tristanvuong2021 Feb 8, 2024
ae095fb
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 8, 2024
51faade
lint fix
tristanvuong2021 Feb 8, 2024
aaa1d70
lint fix
tristanvuong2021 Feb 8, 2024
ecf2ca1
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 12, 2024
200bb1b
Flow of lists is returned now.
tristanvuong2021 Feb 12, 2024
896ee81
lint fix
tristanvuong2021 Feb 12, 2024
18ba7ce
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 12, 2024
2a3392b
Clarify method comment
tristanvuong2021 Feb 12, 2024
e013889
Use flow as input
tristanvuong2021 Feb 14, 2024
c4856cd
lint fix
tristanvuong2021 Feb 14, 2024
a65dc67
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 14, 2024
d18475f
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 14, 2024
03a63aa
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 15, 2024
c45e7f9
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 16, 2024
990ae00
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 16, 2024
79442cb
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 16, 2024
7b1387e
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 20, 2024
d659037
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 22, 2024
0162ff0
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 26, 2024
dac02ed
Replace collect with transform
tristanvuong2021 Feb 26, 2024
265ac80
lint fix
tristanvuong2021 Feb 26, 2024
e1bb621
refactor
tristanvuong2021 Feb 26, 2024
72bc131
lint fix
tristanvuong2021 Feb 26, 2024
77b1120
lint fix
tristanvuong2021 Feb 26, 2024
f5b929c
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 27, 2024
a5349ca
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 28, 2024
fde15f8
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@

package org.wfanet.measurement.reporting.service.api

import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.flatMapConcat
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit

class BatchRequestException(message: String? = null, cause: Throwable? = null) :
Exception(message, cause)
Expand Down Expand Up @@ -48,19 +51,28 @@ fun <T> Flow<T>.chunked(chunkSize: Int): Flow<List<T>> {
}

/**
 * Submits multiple RPCs by dividing the input items into batches.
 *
 * NOTE(review): this listing is a rendered diff without +/- markers, so both sides of every
 * changed line appear below: two `items` parameters, two return types and two bodies. Judging by
 * the hunk headers, the `Flow` lines look like the removed side and the `Collection`/`List` lines
 * like the added side -- confirm against the repository before relying on either.
 *
 * @param items the items to send; they are split with `chunked(limit)` before any RPC is made
 * @param limit the size passed to `chunked`; must be greater than 0
 * @param callRpc suspending call that receives one batch and returns its response
 * @param parseResponse converts one response into its list of results
 * @return the parsed results of all batches, combined into a single `Flow` or `List` depending
 *   on the variant
 * @throws BatchRequestException if [limit] is less than or equal to 0; the cause is an
 *   [IllegalArgumentException] describing the constraint
 */
@OptIn(ExperimentalCoroutinesApi::class) // For `flatMapConcat`.
suspend fun <ITEM, RESP, RESULT> submitBatchRequests(
items: Flow<ITEM>,
items: Collection<ITEM>,
limit: Int,
callRpc: suspend (List<ITEM>) -> RESP,
parseResponse: (RESP) -> List<RESULT>,
): Flow<RESULT> {
): List<RESULT> {
// Reject a non-positive batch size up front, before any chunking or RPC happens.
if (limit <= 0) {
throw BatchRequestException(
"Invalid limit",
IllegalArgumentException("The size limit of a batch must be greater than 0."),
)
}

// Flow variant: each batch's parsed results are re-emitted as a flow and joined by
// `flatMapConcat`.
return items.chunked(limit).flatMapConcat { batch -> parseResponse(callRpc(batch)).asFlow() }
// For network requests, the number of concurrent coroutines needs to be capped. To be on the safe
// side, a low number is chosen.
val batchSemaphore = Semaphore(3)
// Collection variant: `coroutineScope` keeps every `async` below a child of this call.
return coroutineScope {
val deferred: List<Deferred<List<RESULT>>> =
items.chunked(limit).map { batch: List<ITEM> ->
// One `async` is started per batch; the permit gates the RPC call and the response parsing.
async { batchSemaphore.withPermit { parseResponse(callRpc(batch)) } }
}
// Wait for every batch, then merge the per-batch result lists into one list.
val responses: List<List<RESULT>> = deferred.awaitAll()
responses.flatten()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ import kotlinx.coroutines.asExecutor
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.withContext
import org.jetbrains.annotations.BlockingExecutor
Expand Down Expand Up @@ -293,7 +290,7 @@ class MetricsService(
val measurementConsumer: MeasurementConsumer = getMeasurementConsumer(principal)

// Gets all external IDs of primitive reporting sets from the metric list.
val externalPrimitiveReportingSetIds: Flow<String> =
val externalPrimitiveReportingSetIds: List<String> =
internalMetricsList
.flatMap { internalMetric ->
internalMetric.weightedMeasurementsList.flatMap { weightedMeasurement ->
Expand All @@ -303,7 +300,6 @@ class MetricsService(
}
}
.distinct()
.asFlow()

val callBatchGetInternalReportingSetsRpc:
suspend (List<String>) -> BatchGetReportingSetsResponse =
Expand All @@ -319,7 +315,6 @@ class MetricsService(
) { response: BatchGetReportingSetsResponse ->
response.reportingSetsList
}
.toList()
.associateBy { it.externalReportingSetId }

val dataProviderNames = mutableSetOf<String>()
Expand Down Expand Up @@ -357,9 +352,9 @@ class MetricsService(
batchCreateCmmsMeasurements(principal, items)
}

val cmmsMeasurements: Flow<Measurement> =
val cmmsMeasurements: List<Measurement> =
submitBatchRequests(
cmmsCreateMeasurementRequests.asFlow(),
cmmsCreateMeasurementRequests,
BATCH_KINGDOM_MEASUREMENTS_LIMIT,
callBatchCreateMeasurementsRpc,
) { response: BatchCreateMeasurementsResponse ->
Expand All @@ -374,18 +369,17 @@ class MetricsService(
}

submitBatchRequests(
cmmsMeasurements.map {
measurementIds {
cmmsCreateMeasurementRequestId = it.measurementReferenceId
cmmsMeasurementId = MeasurementKey.fromName(it.name)!!.measurementId
}
},
BATCH_SET_CMMS_MEASUREMENT_IDS_LIMIT,
callBatchSetCmmsMeasurementIdsRpc,
) { response: BatchSetCmmsMeasurementIdsResponse ->
response.measurementsList
}
.toList()
cmmsMeasurements.map {
measurementIds {
cmmsCreateMeasurementRequestId = it.measurementReferenceId
cmmsMeasurementId = MeasurementKey.fromName(it.name)!!.measurementId
}
},
BATCH_SET_CMMS_MEASUREMENT_IDS_LIMIT,
callBatchSetCmmsMeasurementIdsRpc,
) { response: BatchSetCmmsMeasurementIdsResponse ->
response.measurementsList
}
}

/** Sets a batch of CMMS [MeasurementIds] to the [InternalMeasurement] table. */
Expand Down Expand Up @@ -792,13 +786,12 @@ class MetricsService(
batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal)
}
submitBatchRequests(
measurementsList.asFlow(),
BATCH_SET_MEASUREMENT_RESULTS_LIMIT,
callBatchSetInternalMeasurementResultsRpc,
) { response: BatchSetCmmsMeasurementResultsResponse ->
response.measurementsList
}
.toList()
measurementsList,
BATCH_SET_MEASUREMENT_RESULTS_LIMIT,
callBatchSetInternalMeasurementResultsRpc,
) { response: BatchSetCmmsMeasurementResultsResponse ->
response.measurementsList
}

anyUpdate = true
}
Expand All @@ -815,13 +808,12 @@ class MetricsService(
)
}
submitBatchRequests(
measurementsList.asFlow(),
BATCH_SET_MEASUREMENT_FAILURES_LIMIT,
callBatchSetInternalMeasurementFailuresRpc,
) { response: BatchSetCmmsMeasurementFailuresResponse ->
response.measurementsList
}
.toList()
measurementsList,
BATCH_SET_MEASUREMENT_FAILURES_LIMIT,
callBatchSetInternalMeasurementFailuresRpc,
) { response: BatchSetCmmsMeasurementFailuresResponse ->
response.measurementsList
}

anyUpdate = true
}
Expand Down Expand Up @@ -918,13 +910,12 @@ class MetricsService(
}

return submitBatchRequests(
measurementNames.asFlow(),
BATCH_KINGDOM_MEASUREMENTS_LIMIT,
callBatchGetMeasurementsRpc,
) { response: BatchGetMeasurementsResponse ->
response.measurementsList
}
.toList()
measurementNames,
BATCH_KINGDOM_MEASUREMENTS_LIMIT,
callBatchGetMeasurementsRpc,
) { response: BatchGetMeasurementsResponse ->
response.measurementsList
}
}

/** Batch get CMMS measurements. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import java.time.ZonedDateTime
import java.time.temporal.TemporalAdjusters
import java.time.zone.ZoneRulesException
import kotlin.math.min
import kotlinx.coroutines.flow.asFlow
import org.projectnessie.cel.Env
import org.wfanet.measurement.api.v2alpha.DataProvider
import org.wfanet.measurement.api.v2alpha.DataProviderKey
Expand Down Expand Up @@ -673,15 +672,12 @@ class ReportSchedulesService(
while (externalReportingSetIdSet.isNotEmpty()) {
retrievedExternalReportingSetIdSet.addAll(externalReportingSetIdSet)

submitBatchRequests(
externalReportingSetIdSet.asFlow(),
BATCH_GET_REPORTING_SETS_LIMIT,
callRpc,
) { response ->
submitBatchRequests(externalReportingSetIdSet, BATCH_GET_REPORTING_SETS_LIMIT, callRpc) {
response ->
externalReportingSetIdSet.clear()
response.reportingSetsList
}
.collect {
.forEach {
if (it.hasComposite()) {
val lhsExternalReportingSetId = it.composite.lhs.externalReportingSetId
if (lhsExternalReportingSetId.isNotEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ import java.time.temporal.Temporal
import java.time.temporal.TemporalAdjusters
import java.time.zone.ZoneRulesException
import kotlin.math.min
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.toList
import org.projectnessie.cel.Env
import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey
Expand Down Expand Up @@ -164,8 +162,8 @@ class ReportsService(
results.subList(0, min(results.size, listReportsPageToken.pageSize))

// Get metrics.
val metricNames: Flow<String> =
subResults.flatMap { internalReport -> internalReport.metricNames }.distinct().asFlow()
val metricNames: List<String> =
subResults.flatMap { internalReport -> internalReport.metricNames }.distinct()

val callRpc: suspend (List<String>) -> BatchGetMetricsResponse = { items ->
batchGetMetrics(principal.resourceKey.toName(), items)
Expand All @@ -174,7 +172,6 @@ class ReportsService(
submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response ->
response.metricsList
}
.toList()
.associateBy { checkNotNull(MetricKey.fromName(it.name)).metricId }

return listReportsResponse {
Expand Down Expand Up @@ -230,7 +227,7 @@ class ReportsService(
}

// Get metrics.
val metricNames: Flow<String> = internalReport.metricNames.distinct().asFlow()
val metricNames: List<String> = internalReport.metricNames.distinct()

val callRpc: suspend (List<String>) -> BatchGetMetricsResponse = { items ->
batchGetMetrics(principal.resourceKey.toName(), items)
Expand All @@ -239,7 +236,6 @@ class ReportsService(
submitBatchRequests(metricNames, BATCH_GET_METRICS_LIMIT, callRpc) { response ->
response.metricsList
}
.toList()
.associateBy { checkNotNull(MetricKey.fromName(it.name)).metricId }

// Convert the internal report to public and return.
Expand Down Expand Up @@ -300,15 +296,18 @@ class ReportsService(
validateTime(request.report)

val externalMetricCalculationSpecIds: List<String> =
request.report.reportingMetricEntriesList.flatMap { reportingMetricEntry ->
reportingMetricEntry.value.metricCalculationSpecsList.map {
val key =
grpcRequireNotNull(MetricCalculationSpecKey.fromName(it)) {
"MetricCalculationSpec name $it is invalid."
}
key.metricCalculationSpecId
request.report.reportingMetricEntriesList
.flatMap { reportingMetricEntry ->
reportingMetricEntry.value.metricCalculationSpecsList.map {
val key =
grpcRequireNotNull(MetricCalculationSpecKey.fromName(it)) {
"MetricCalculationSpec name $it is invalid."
}
key.metricCalculationSpecId
}
}
}
.distinct()

val externalIdToMetricCalculationSpecMap: Map<String, InternalMetricCalculationSpec> =
createExternalIdToMetricCalculationSpecMap(
parentKey.measurementConsumerId,
Expand Down Expand Up @@ -360,24 +359,23 @@ class ReportsService(
}

// Create metrics.
val createMetricRequests: Flow<CreateMetricRequest> =
internalReport.reportingMetricEntriesMap
.flatMap { (reportingSetId, reportingMetricCalculationSpec) ->
reportingMetricCalculationSpec.metricCalculationSpecReportingMetricsList.flatMap {
metricCalculationSpecReportingMetrics ->
metricCalculationSpecReportingMetrics.reportingMetricsList.map {
it.toCreateMetricRequest(
principal.resourceKey,
reportingSetId,
externalIdToMetricCalculationSpecMap
.getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId)
.details
.filter,
)
}
val createMetricRequests: List<CreateMetricRequest> =
internalReport.reportingMetricEntriesMap.flatMap {
(reportingSetId, reportingMetricCalculationSpec) ->
reportingMetricCalculationSpec.metricCalculationSpecReportingMetricsList.flatMap {
metricCalculationSpecReportingMetrics ->
metricCalculationSpecReportingMetrics.reportingMetricsList.map {
it.toCreateMetricRequest(
principal.resourceKey,
reportingSetId,
externalIdToMetricCalculationSpecMap
.getValue(metricCalculationSpecReportingMetrics.externalMetricCalculationSpecId)
.details
.filter,
)
}
}
.asFlow()
}

val callRpc: suspend (List<CreateMetricRequest>) -> BatchCreateMetricsResponse = { items ->
batchCreateMetrics(request.parent, items)
Expand All @@ -387,7 +385,6 @@ class ReportsService(
response: BatchCreateMetricsResponse ->
response.metricsList
}
.toList()
.associateBy { checkNotNull(MetricKey.fromName(it.name)).metricId }

// Once all metrics are created, get the updated internal report with the metric IDs filled.
Expand Down Expand Up @@ -477,11 +474,14 @@ class ReportsService(

if (state == Report.State.SUCCEEDED || state == Report.State.FAILED) {
val externalMetricCalculationSpecIds =
internalReport.reportingMetricEntriesMap.flatMap { reportingMetricCalculationSpec ->
reportingMetricCalculationSpec.value.metricCalculationSpecReportingMetricsList.map {
it.externalMetricCalculationSpecId
internalReport.reportingMetricEntriesMap
.flatMap { reportingMetricCalculationSpec ->
reportingMetricCalculationSpec.value.metricCalculationSpecReportingMetricsList.map {
it.externalMetricCalculationSpecId
}
}
}
.distinct()

val externalIdToMetricCalculationMap: Map<String, InternalMetricCalculationSpec> =
createExternalIdToMetricCalculationSpecMap(
internalReport.cmmsMeasurementConsumerId,
Expand Down
Loading
Loading