Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SubmitBatchRequests to use Coroutines #1467

Merged
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
9898ebf
Refactor SubmitBatchRequests to use Coroutines
tristanvuong2021 Feb 8, 2024
ae095fb
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 8, 2024
51faade
lint fix
tristanvuong2021 Feb 8, 2024
aaa1d70
lint fix
tristanvuong2021 Feb 8, 2024
ecf2ca1
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 12, 2024
200bb1b
Flow of lists is returned now.
tristanvuong2021 Feb 12, 2024
896ee81
lint fix
tristanvuong2021 Feb 12, 2024
18ba7ce
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 12, 2024
2a3392b
Clarify method comment
tristanvuong2021 Feb 12, 2024
e013889
Use flow as input
tristanvuong2021 Feb 14, 2024
c4856cd
lint fix
tristanvuong2021 Feb 14, 2024
a65dc67
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 14, 2024
d18475f
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 14, 2024
03a63aa
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 15, 2024
c45e7f9
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 16, 2024
990ae00
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 16, 2024
79442cb
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 16, 2024
7b1387e
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 20, 2024
d659037
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 22, 2024
0162ff0
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 26, 2024
dac02ed
Replace collect with transform
tristanvuong2021 Feb 26, 2024
265ac80
lint fix
tristanvuong2021 Feb 26, 2024
e1bb621
refactor
tristanvuong2021 Feb 26, 2024
72bc131
lint fix
tristanvuong2021 Feb 26, 2024
77b1120
lint fix
tristanvuong2021 Feb 26, 2024
f5b929c
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 27, 2024
a5349ca
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 28, 2024
fde15f8
Merge branch 'main' into tristanvuong-change-submit-batch-requests-to…
tristanvuong2021 Feb 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

package org.wfanet.measurement.reporting.service.api

import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.flatMapConcat
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit

class BatchRequestException(message: String? = null, cause: Throwable? = null) :
Exception(message, cause)
Expand All @@ -47,20 +49,39 @@ fun <T> Flow<T>.chunked(chunkSize: Int): Flow<List<T>> {
}
}

/** Submits multiple RPCs by dividing the input items to batches. */
@OptIn(ExperimentalCoroutinesApi::class) // For `flatMapConcat`.
/**
* Submits multiple RPCs by dividing the input items to batches.
*
* @return [Flow] that emits [List]s containing the results of the multiple RPCs.
*/
suspend fun <ITEM, RESP, RESULT> submitBatchRequests(
items: Flow<ITEM>,
limit: Int,
callRpc: suspend (List<ITEM>) -> RESP,
parseResponse: (RESP) -> List<RESULT>,
): Flow<RESULT> {
): Flow<List<RESULT>> {
if (limit <= 0) {
throw BatchRequestException(
"Invalid limit",
IllegalArgumentException("The size limit of a batch must be greater than 0."),
)
}

return items.chunked(limit).flatMapConcat { batch -> parseResponse(callRpc(batch)).asFlow() }
// For network requests, the number of concurrent coroutines needs to be capped. To be on the safe
// side, a low number is chosen.
val batchSemaphore = Semaphore(3)
return flow {
coroutineScope {
val deferred: List<Deferred<List<RESULT>>> = buildList {
items.chunked(limit).collect { batch: List<ITEM> ->
// The batch reference is reused for every collect call. To ensure async works, a copy
// of the contents needs to be saved in a new reference.
val tempBatch = batch.toList()
add(async { batchSemaphore.withPermit { parseResponse(callRpc(tempBatch)) } })
}
}

deferred.forEach { emit(it.await()) }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,16 @@ import kotlin.math.sqrt
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.asExecutor
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.count
import kotlinx.coroutines.flow.flatMapMerge
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.withContext
Expand Down Expand Up @@ -308,34 +312,42 @@ class MetricsService(
val measurementConsumer: MeasurementConsumer = getMeasurementConsumer(principal)

// Gets all external IDs of primitive reporting sets from the metric list.
val externalPrimitiveReportingSetIds: Flow<String> =
internalMetricsList
.flatMap { internalMetric ->
internalMetric.weightedMeasurementsList.flatMap { weightedMeasurement ->
weightedMeasurement.measurement.primitiveReportingSetBasesList.map {
it.externalReportingSetId
val externalPrimitiveReportingSetIds: Flow<String> = flow {
buildSet<String> {
for (internalMetric in internalMetricsList) {
for (weightedMeasurement in internalMetric.weightedMeasurementsList) {
for (primitiveReportingSetBasis in
weightedMeasurement.measurement.primitiveReportingSetBasesList) {
if (!contains(primitiveReportingSetBasis.externalReportingSetId)) {
emit(primitiveReportingSetBasis.externalReportingSetId)
add(primitiveReportingSetBasis.externalReportingSetId)
}
}
}
}
.distinct()
.asFlow()
}
}

val callBatchGetInternalReportingSetsRpc:
suspend (List<String>) -> BatchGetReportingSetsResponse =
{ items ->
batchGetInternalReportingSets(principal.resourceKey.measurementConsumerId, items)
}

val internalPrimitiveReportingSetMap: Map<String, InternalReportingSet> =
val internalPrimitiveReportingSetMap: Map<String, InternalReportingSet> = buildMap {
submitBatchRequests(
externalPrimitiveReportingSetIds,
BATCH_GET_REPORTING_SETS_LIMIT,
callBatchGetInternalReportingSetsRpc,
) { response: BatchGetReportingSetsResponse ->
response.reportingSetsList
}
.toList()
.associateBy { it.externalReportingSetId }
.collect { reportingSets: List<InternalReportingSet> ->
for (reportingSet in reportingSets) {
computeIfAbsent(reportingSet.externalReportingSetId) { reportingSet }
}
}
}

val dataProviderNames = mutableSetOf<String>()
for (internalPrimitiveReportingSet in internalPrimitiveReportingSetMap.values) {
Expand All @@ -348,22 +360,25 @@ class MetricsService(

val measurementConsumerSigningKey = getMeasurementConsumerSigningKey(principal)

val cmmsCreateMeasurementRequests: List<CreateMeasurementRequest> =
internalMetricsList.flatMap { internalMetric ->
internalMetric.weightedMeasurementsList
.filter { it.measurement.cmmsMeasurementId.isBlank() }
.map {
buildCreateMeasurementRequest(
it.measurement,
internalMetric.metricSpec,
internalPrimitiveReportingSetMap,
measurementConsumer,
principal,
dataProviderInfoMap,
measurementConsumerSigningKey,
val cmmsCreateMeasurementRequests: Flow<CreateMeasurementRequest> = flow {
for (internalMetric in internalMetricsList) {
for (weightedMeasurement in internalMetric.weightedMeasurementsList) {
if (weightedMeasurement.measurement.cmmsMeasurementId.isBlank()) {
emit(
buildCreateMeasurementRequest(
weightedMeasurement.measurement,
internalMetric.metricSpec,
internalPrimitiveReportingSetMap,
measurementConsumer,
principal,
dataProviderInfoMap,
measurementConsumerSigningKey,
)
)
}
}
}
}

// Create CMMS measurements.
val callBatchCreateMeasurementsRpc:
Expand All @@ -372,14 +387,16 @@ class MetricsService(
batchCreateCmmsMeasurements(principal, items)
}

@OptIn(ExperimentalCoroutinesApi::class)
val cmmsMeasurements: Flow<Measurement> =
submitBatchRequests(
cmmsCreateMeasurementRequests.asFlow(),
BATCH_KINGDOM_MEASUREMENTS_LIMIT,
callBatchCreateMeasurementsRpc,
) { response: BatchCreateMeasurementsResponse ->
response.measurementsList
}
cmmsCreateMeasurementRequests,
BATCH_KINGDOM_MEASUREMENTS_LIMIT,
callBatchCreateMeasurementsRpc,
) { response: BatchCreateMeasurementsResponse ->
response.measurementsList
}
.flatMapMerge { it.asFlow() }

// Set CMMS measurement IDs.
val callBatchSetCmmsMeasurementIdsRpc:
Expand All @@ -400,7 +417,7 @@ class MetricsService(
) { response: BatchSetCmmsMeasurementIdsResponse ->
response.measurementsList
}
.toList()
.collect {}
}

/** Sets a batch of CMMS [MeasurementIds] to the [InternalMeasurement] table. */
Expand Down Expand Up @@ -784,65 +801,71 @@ class MetricsService(
apiAuthenticationKey: String,
principal: MeasurementConsumerPrincipal,
): Boolean {
val newStateToCmmsMeasurements: Map<Measurement.State, List<Measurement>> =
getCmmsMeasurements(internalMeasurements, principal).groupBy { measurement ->
measurement.state
val failedMeasurements: MutableList<Measurement> = mutableListOf()

// Most Measurements are expected to be SUCCEEDED so SUCCEEDED Measurements will be collected
// via a Flow.
val succeededMeasurements: Flow<Measurement> = flow {
getCmmsMeasurements(internalMeasurements, principal).collect { measurements ->
for (measurement in measurements) {
@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields cannot be null.
when (measurement.state) {
Measurement.State.SUCCEEDED -> emit(measurement)
Measurement.State.CANCELLED,
Measurement.State.FAILED -> failedMeasurements.add(measurement)
Measurement.State.COMPUTING,
Measurement.State.AWAITING_REQUISITION_FULFILLMENT -> {}
Measurement.State.STATE_UNSPECIFIED ->
failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) {
"The CMMS measurement state should've been set."
}
Measurement.State.UNRECOGNIZED -> {
failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) {
"Unrecognized CMMS measurement state."
}
}
}
}
}
}

var anyUpdate = false

for ((newState, measurementsList) in newStateToCmmsMeasurements) {
when (newState) {
Measurement.State.SUCCEEDED -> {
val callBatchSetInternalMeasurementResultsRpc:
suspend (List<Measurement>) -> BatchSetCmmsMeasurementResultsResponse =
{ items ->
batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal)
}
submitBatchRequests(
measurementsList.asFlow(),
BATCH_SET_MEASUREMENT_RESULTS_LIMIT,
callBatchSetInternalMeasurementResultsRpc,
) { response: BatchSetCmmsMeasurementResultsResponse ->
response.measurementsList
}
.toList()

anyUpdate = true
val callBatchSetInternalMeasurementResultsRpc:
suspend (List<Measurement>) -> BatchSetCmmsMeasurementResultsResponse =
{ items ->
batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal)
}
val count =
submitBatchRequests(
succeededMeasurements,
BATCH_SET_MEASUREMENT_RESULTS_LIMIT,
callBatchSetInternalMeasurementResultsRpc,
) { response: BatchSetCmmsMeasurementResultsResponse ->
response.measurementsList
}
Measurement.State.AWAITING_REQUISITION_FULFILLMENT,
Measurement.State.COMPUTING -> {} // Do nothing.
Measurement.State.FAILED,
Measurement.State.CANCELLED -> {
val callBatchSetInternalMeasurementFailuresRpc:
suspend (List<Measurement>) -> BatchSetCmmsMeasurementFailuresResponse =
{ items ->
batchSetInternalMeasurementFailures(
items,
principal.resourceKey.measurementConsumerId,
)
}
submitBatchRequests(
measurementsList.asFlow(),
BATCH_SET_MEASUREMENT_FAILURES_LIMIT,
callBatchSetInternalMeasurementFailuresRpc,
) { response: BatchSetCmmsMeasurementFailuresResponse ->
response.measurementsList
}
.toList()
.count()

if (count > 0) {
anyUpdate = true
}

anyUpdate = true
if (failedMeasurements.isNotEmpty()) {
val callBatchSetInternalMeasurementFailuresRpc:
suspend (List<Measurement>) -> BatchSetCmmsMeasurementFailuresResponse =
{ items ->
batchSetInternalMeasurementFailures(items, principal.resourceKey.measurementConsumerId)
}
Measurement.State.STATE_UNSPECIFIED ->
failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) {
"The CMMS measurement state should've been set."
}
Measurement.State.UNRECOGNIZED -> {
failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) {
"Unrecognized CMMS measurement state."
}
submitBatchRequests(
failedMeasurements.asFlow(),
BATCH_SET_MEASUREMENT_FAILURES_LIMIT,
callBatchSetInternalMeasurementFailuresRpc,
) { response: BatchSetCmmsMeasurementFailuresResponse ->
response.measurementsList
}
}
.collect {}

anyUpdate = true
}

return anyUpdate
Expand Down Expand Up @@ -908,31 +931,37 @@ class MetricsService(
private suspend fun getCmmsMeasurements(
internalMeasurements: List<InternalMeasurement>,
principal: MeasurementConsumerPrincipal,
): List<Measurement> {
val measurementNames: List<String> =
internalMeasurements
.map { internalMeasurement ->
MeasurementKey(
principal.resourceKey.measurementConsumerId,
internalMeasurement.cmmsMeasurementId,
)
.toName()
): Flow<List<Measurement>> {
val measurementNames: Flow<String> = flow {
buildSet {
for (internalMeasurement in internalMeasurements) {
val name =
MeasurementKey(
principal.resourceKey.measurementConsumerId,
internalMeasurement.cmmsMeasurementId,
)
.toName()

if (!contains(name)) {
emit(name)
add(name)
}
}
.distinct()
}
}

val callBatchGetMeasurementsRpc: suspend (List<String>) -> BatchGetMeasurementsResponse =
{ items ->
batchGetCmmsMeasurements(principal, items)
}

return submitBatchRequests(
measurementNames.asFlow(),
BATCH_KINGDOM_MEASUREMENTS_LIMIT,
callBatchGetMeasurementsRpc,
) { response: BatchGetMeasurementsResponse ->
response.measurementsList
}
.toList()
measurementNames,
BATCH_KINGDOM_MEASUREMENTS_LIMIT,
callBatchGetMeasurementsRpc,
) { response: BatchGetMeasurementsResponse ->
response.measurementsList
}
}

/** Batch get CMMS measurements. */
Expand Down
Loading
Loading