Skip to content

Commit

Permalink
Refactor SubmitBatchRequests to use Coroutines (#1467)
Browse files Browse the repository at this point in the history
  • Loading branch information
tristanvuong2021 authored and ple13 committed Aug 16, 2024
1 parent bc60bad commit 9f112bb
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 182 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

package org.wfanet.measurement.reporting.service.api

import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.async
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.flatMapConcat
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit

class BatchRequestException(message: String? = null, cause: Throwable? = null) :
Exception(message, cause)
Expand All @@ -47,20 +49,39 @@ fun <T> Flow<T>.chunked(chunkSize: Int): Flow<List<T>> {
}
}

/** Submits multiple RPCs by dividing the input items into batches. */
@OptIn(ExperimentalCoroutinesApi::class) // For `flatMapConcat`.
/**
* Submits multiple RPCs by dividing the input items into batches.
*
* @return [Flow] that emits [List]s containing the results of the multiple RPCs.
*/
suspend fun <ITEM, RESP, RESULT> submitBatchRequests(
items: Flow<ITEM>,
limit: Int,
callRpc: suspend (List<ITEM>) -> RESP,
parseResponse: (RESP) -> List<RESULT>,
): Flow<RESULT> {
): Flow<List<RESULT>> {
if (limit <= 0) {
throw BatchRequestException(
"Invalid limit",
IllegalArgumentException("The size limit of a batch must be greater than 0."),
)
}

return items.chunked(limit).flatMapConcat { batch -> parseResponse(callRpc(batch)).asFlow() }
// For network requests, the number of concurrent coroutines needs to be capped. To be on the safe
// side, a low number is chosen.
val batchSemaphore = Semaphore(3)
return flow {
coroutineScope {
val deferred: List<Deferred<List<RESULT>>> = buildList {
items.chunked(limit).collect { batch: List<ITEM> ->
// The batch reference is reused for every collect call. To ensure async works, a copy
// of the contents needs to be saved in a new reference.
val tempBatch = batch.toList()
add(async { batchSemaphore.withPermit { parseResponse(callRpc(tempBatch)) } })
}
}

deferred.forEach { emit(it.await()) }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,19 @@ import kotlin.math.sqrt
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.asExecutor
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.count
import kotlinx.coroutines.flow.flattenMerge
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.flow.transform
import kotlinx.coroutines.withContext
import org.jetbrains.annotations.BlockingExecutor
import org.jetbrains.annotations.NonBlockingExecutor
Expand Down Expand Up @@ -308,34 +313,45 @@ class MetricsService(
val measurementConsumer: MeasurementConsumer = getMeasurementConsumer(principal)

// Gets all external IDs of primitive reporting sets from the metric list.
val externalPrimitiveReportingSetIds: Flow<String> =
internalMetricsList
.flatMap { internalMetric ->
internalMetric.weightedMeasurementsList.flatMap { weightedMeasurement ->
weightedMeasurement.measurement.primitiveReportingSetBasesList.map {
it.externalReportingSetId
val externalPrimitiveReportingSetIds: Flow<String> = flow {
buildSet {
for (internalMetric in internalMetricsList) {
for (weightedMeasurement in internalMetric.weightedMeasurementsList) {
for (primitiveReportingSetBasis in
weightedMeasurement.measurement.primitiveReportingSetBasesList) {
// Checks if the set already contains the ID
if (!contains(primitiveReportingSetBasis.externalReportingSetId)) {
// If the set doesn't contain the ID, emit it and add it to the set so it won't
// get emitted again.
emit(primitiveReportingSetBasis.externalReportingSetId)
add(primitiveReportingSetBasis.externalReportingSetId)
}
}
}
}
.distinct()
.asFlow()
}
}

val callBatchGetInternalReportingSetsRpc:
suspend (List<String>) -> BatchGetReportingSetsResponse =
{ items ->
batchGetInternalReportingSets(principal.resourceKey.measurementConsumerId, items)
}

val internalPrimitiveReportingSetMap: Map<String, InternalReportingSet> =
val internalPrimitiveReportingSetMap: Map<String, InternalReportingSet> = buildMap {
submitBatchRequests(
externalPrimitiveReportingSetIds,
BATCH_GET_REPORTING_SETS_LIMIT,
callBatchGetInternalReportingSetsRpc,
) { response: BatchGetReportingSetsResponse ->
response.reportingSetsList
}
.toList()
.associateBy { it.externalReportingSetId }
.collect { reportingSets: List<InternalReportingSet> ->
for (reportingSet in reportingSets) {
computeIfAbsent(reportingSet.externalReportingSetId) { reportingSet }
}
}
}

val dataProviderNames = mutableSetOf<String>()
for (internalPrimitiveReportingSet in internalPrimitiveReportingSetMap.values) {
Expand All @@ -348,22 +364,25 @@ class MetricsService(

val measurementConsumerSigningKey = getMeasurementConsumerSigningKey(principal)

val cmmsCreateMeasurementRequests: List<CreateMeasurementRequest> =
internalMetricsList.flatMap { internalMetric ->
internalMetric.weightedMeasurementsList
.filter { it.measurement.cmmsMeasurementId.isBlank() }
.map {
buildCreateMeasurementRequest(
it.measurement,
internalMetric.metricSpec,
internalPrimitiveReportingSetMap,
measurementConsumer,
principal,
dataProviderInfoMap,
measurementConsumerSigningKey,
val cmmsCreateMeasurementRequests: Flow<CreateMeasurementRequest> = flow {
for (internalMetric in internalMetricsList) {
for (weightedMeasurement in internalMetric.weightedMeasurementsList) {
if (weightedMeasurement.measurement.cmmsMeasurementId.isBlank()) {
emit(
buildCreateMeasurementRequest(
weightedMeasurement.measurement,
internalMetric.metricSpec,
internalPrimitiveReportingSetMap,
measurementConsumer,
principal,
dataProviderInfoMap,
measurementConsumerSigningKey,
)
)
}
}
}
}

// Create CMMS measurements.
val callBatchCreateMeasurementsRpc:
Expand All @@ -372,14 +391,17 @@ class MetricsService(
batchCreateCmmsMeasurements(principal, items)
}

@OptIn(ExperimentalCoroutinesApi::class)
val cmmsMeasurements: Flow<Measurement> =
submitBatchRequests(
cmmsCreateMeasurementRequests.asFlow(),
BATCH_KINGDOM_MEASUREMENTS_LIMIT,
callBatchCreateMeasurementsRpc,
) { response: BatchCreateMeasurementsResponse ->
response.measurementsList
}
cmmsCreateMeasurementRequests,
BATCH_KINGDOM_MEASUREMENTS_LIMIT,
callBatchCreateMeasurementsRpc,
) { response: BatchCreateMeasurementsResponse ->
response.measurementsList
}
.map { it.asFlow() }
.flattenMerge()

// Set CMMS measurement IDs.
val callBatchSetCmmsMeasurementIdsRpc:
Expand All @@ -400,7 +422,7 @@ class MetricsService(
) { response: BatchSetCmmsMeasurementIdsResponse ->
response.measurementsList
}
.toList()
.collect {}
}

/** Sets a batch of CMMS [MeasurementIds] to the [InternalMeasurement] table. */
Expand Down Expand Up @@ -784,65 +806,70 @@ class MetricsService(
apiAuthenticationKey: String,
principal: MeasurementConsumerPrincipal,
): Boolean {
val newStateToCmmsMeasurements: Map<Measurement.State, List<Measurement>> =
getCmmsMeasurements(internalMeasurements, principal).groupBy { measurement ->
measurement.state
val failedMeasurements: MutableList<Measurement> = mutableListOf()

// Most Measurements are expected to be SUCCEEDED so SUCCEEDED Measurements will be collected
// via a Flow.
val succeededMeasurements: Flow<Measurement> =
getCmmsMeasurements(internalMeasurements, principal).transform { measurements ->
for (measurement in measurements) {
@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields cannot be null.
when (measurement.state) {
Measurement.State.SUCCEEDED -> emit(measurement)
Measurement.State.CANCELLED,
Measurement.State.FAILED -> failedMeasurements.add(measurement)
Measurement.State.COMPUTING,
Measurement.State.AWAITING_REQUISITION_FULFILLMENT -> {}
Measurement.State.STATE_UNSPECIFIED ->
failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) {
"The CMMS measurement state should've been set."
}
Measurement.State.UNRECOGNIZED -> {
failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) {
"Unrecognized CMMS measurement state."
}
}
}
}
}

var anyUpdate = false

for ((newState, measurementsList) in newStateToCmmsMeasurements) {
when (newState) {
Measurement.State.SUCCEEDED -> {
val callBatchSetInternalMeasurementResultsRpc:
suspend (List<Measurement>) -> BatchSetCmmsMeasurementResultsResponse =
{ items ->
batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal)
}
submitBatchRequests(
measurementsList.asFlow(),
BATCH_SET_MEASUREMENT_RESULTS_LIMIT,
callBatchSetInternalMeasurementResultsRpc,
) { response: BatchSetCmmsMeasurementResultsResponse ->
response.measurementsList
}
.toList()

anyUpdate = true
val callBatchSetInternalMeasurementResultsRpc:
suspend (List<Measurement>) -> BatchSetCmmsMeasurementResultsResponse =
{ items ->
batchSetInternalMeasurementResults(items, apiAuthenticationKey, principal)
}
val count =
submitBatchRequests(
succeededMeasurements,
BATCH_SET_MEASUREMENT_RESULTS_LIMIT,
callBatchSetInternalMeasurementResultsRpc,
) { response: BatchSetCmmsMeasurementResultsResponse ->
response.measurementsList
}
Measurement.State.AWAITING_REQUISITION_FULFILLMENT,
Measurement.State.COMPUTING -> {} // Do nothing.
Measurement.State.FAILED,
Measurement.State.CANCELLED -> {
val callBatchSetInternalMeasurementFailuresRpc:
suspend (List<Measurement>) -> BatchSetCmmsMeasurementFailuresResponse =
{ items ->
batchSetInternalMeasurementFailures(
items,
principal.resourceKey.measurementConsumerId,
)
}
submitBatchRequests(
measurementsList.asFlow(),
BATCH_SET_MEASUREMENT_FAILURES_LIMIT,
callBatchSetInternalMeasurementFailuresRpc,
) { response: BatchSetCmmsMeasurementFailuresResponse ->
response.measurementsList
}
.toList()
.count()

if (count > 0) {
anyUpdate = true
}

anyUpdate = true
if (failedMeasurements.isNotEmpty()) {
val callBatchSetInternalMeasurementFailuresRpc:
suspend (List<Measurement>) -> BatchSetCmmsMeasurementFailuresResponse =
{ items ->
batchSetInternalMeasurementFailures(items, principal.resourceKey.measurementConsumerId)
}
Measurement.State.STATE_UNSPECIFIED ->
failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) {
"The CMMS measurement state should've been set."
}
Measurement.State.UNRECOGNIZED -> {
failGrpc(status = Status.FAILED_PRECONDITION, cause = IllegalStateException()) {
"Unrecognized CMMS measurement state."
}
submitBatchRequests(
failedMeasurements.asFlow(),
BATCH_SET_MEASUREMENT_FAILURES_LIMIT,
callBatchSetInternalMeasurementFailuresRpc,
) { response: BatchSetCmmsMeasurementFailuresResponse ->
response.measurementsList
}
}
.collect {}

anyUpdate = true
}

return anyUpdate
Expand Down Expand Up @@ -908,31 +935,39 @@ class MetricsService(
private suspend fun getCmmsMeasurements(
internalMeasurements: List<InternalMeasurement>,
principal: MeasurementConsumerPrincipal,
): List<Measurement> {
val measurementNames: List<String> =
internalMeasurements
.map { internalMeasurement ->
MeasurementKey(
principal.resourceKey.measurementConsumerId,
internalMeasurement.cmmsMeasurementId,
)
.toName()
): Flow<List<Measurement>> {
val measurementNames: Flow<String> = flow {
buildSet {
for (internalMeasurement in internalMeasurements) {
val name =
MeasurementKey(
principal.resourceKey.measurementConsumerId,
internalMeasurement.cmmsMeasurementId,
)
.toName()
// Checks if the set already contains the name
if (!contains(name)) {
// If the set doesn't contain the name, emit it and add it to the set so it won't
// get emitted again.
emit(name)
add(name)
}
}
.distinct()
}
}

val callBatchGetMeasurementsRpc: suspend (List<String>) -> BatchGetMeasurementsResponse =
{ items ->
batchGetCmmsMeasurements(principal, items)
}

return submitBatchRequests(
measurementNames.asFlow(),
BATCH_KINGDOM_MEASUREMENTS_LIMIT,
callBatchGetMeasurementsRpc,
) { response: BatchGetMeasurementsResponse ->
response.measurementsList
}
.toList()
measurementNames,
BATCH_KINGDOM_MEASUREMENTS_LIMIT,
callBatchGetMeasurementsRpc,
) { response: BatchGetMeasurementsResponse ->
response.measurementsList
}
}

/** Batch get CMMS measurements. */
Expand Down
Loading

0 comments on commit 9f112bb

Please sign in to comment.