Skip to content

Commit

Permalink
Fix Ordering of Responses from Internal Batch Get Methods in Reportin…
Browse files Browse the repository at this point in the history
…g V2 (#1056)
  • Loading branch information
tristanvuong2021 authored and ple13 committed Aug 16, 2024
1 parent 972c43d commit 901fa10
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import org.wfanet.measurement.internal.reporting.v2.batchGetMetricsResponse
import org.wfanet.measurement.reporting.deploy.v2.postgres.readers.MetricReader
import org.wfanet.measurement.reporting.deploy.v2.postgres.writers.CreateMetrics
import org.wfanet.measurement.reporting.service.internal.MeasurementConsumerNotFoundException
import org.wfanet.measurement.reporting.service.internal.MetricNotFoundException
import org.wfanet.measurement.reporting.service.internal.ReportingSetNotFoundException

private const val MAX_BATCH_SIZE = 1000
Expand Down Expand Up @@ -131,14 +132,16 @@ class PostgresMetricsService(
.map { it.metric }
.withSerializableErrorRetries()
.toList()
} catch (e: MetricNotFoundException) {
throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Metric not found")
} catch (e: IllegalStateException) {
failGrpc(Status.NOT_FOUND) { "Metric is not found" }
failGrpc(Status.NOT_FOUND) { "Metric not found" }
} finally {
readContext.close()
}

if (metrics.size < request.externalMetricIdsList.size) {
failGrpc(Status.NOT_FOUND) { "Metric is not found" }
failGrpc(Status.NOT_FOUND) { "Metric not found" }
}

return batchGetMetricsResponse { this.metrics += metrics }
Expand All @@ -158,8 +161,6 @@ class PostgresMetricsService(
.map { it.metric }
.withSerializableErrorRetries()
)
} catch (e: IllegalStateException) {
failGrpc(Status.NOT_FOUND) { "Metric is not found" }
} finally {
readContext.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class PostgresReportingSetsService(
.map { it.reportingSet }
.withSerializableErrorRetries()
.toList()
} catch (e: ReportingSetNotFoundException) {
throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "Reporting Set not found")
} finally {
readContext.close()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ kt_jvm_library(
name = "readers",
srcs = glob(["*.kt"]),
deps = [
"//src/main/kotlin/org/wfanet/measurement/reporting/service/internal:internal_exception",
"//src/main/proto/wfa/measurement/internal/reporting/v2:measurement_kt_jvm_proto",
"//src/main/proto/wfa/measurement/internal/reporting/v2:metrics_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/internal/reporting/v2:reporting_sets_service_kt_jvm_grpc_proto",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import org.wfanet.measurement.internal.reporting.v2.measurement
import org.wfanet.measurement.internal.reporting.v2.metric
import org.wfanet.measurement.internal.reporting.v2.metricSpec
import org.wfanet.measurement.internal.reporting.v2.timeInterval
import org.wfanet.measurement.reporting.service.internal.MetricNotFoundException

class MetricReader(private val readContext: ReadContext) {
data class Result(
Expand Down Expand Up @@ -181,9 +182,28 @@ class MetricReader(private val readContext: ReadContext) {
createMetricRequestIds.forEach { bind(bindingMap.getValue(it), it) }
}

return createResultFlow(statement)
return flow {
val metricInfoMap = buildResultMap(statement)

for (entry in metricInfoMap) {
val metricInfo = entry.value

val metric = metricInfo.buildMetric()

val createMetricRequestId = metricInfo.createMetricRequestId ?: ""
emit(
Result(
measurementConsumerId = metricInfo.measurementConsumerId,
metricId = metricInfo.metricId,
createMetricRequestId = createMetricRequestId,
metric = metric
)
)
}
}
}

/** Throws [MetricNotFoundException] if any Metric not found. */
fun batchGetMetrics(
request: BatchGetMetricsRequest,
): Flow<Result> {
Expand Down Expand Up @@ -218,7 +238,26 @@ class MetricReader(private val readContext: ReadContext) {
request.externalMetricIdsList.forEach { bind(bindingMap.getValue(it), it) }
}

return createResultFlow(statement)
return flow {
val metricInfoMap = buildResultMap(statement)

for (externalMetricId in request.externalMetricIdsList) {
val metricInfo =
metricInfoMap[ExternalId(externalMetricId)] ?: throw MetricNotFoundException()

val metric = metricInfo.buildMetric()

val createMetricRequestId = metricInfo.createMetricRequestId ?: ""
emit(
Result(
measurementConsumerId = metricInfo.measurementConsumerId,
metricId = metricInfo.metricId,
createMetricRequestId = createMetricRequestId,
metric = metric
)
)
}
}
}

fun readMetrics(
Expand Down Expand Up @@ -252,59 +291,19 @@ class MetricReader(private val readContext: ReadContext) {
}
}

return createResultFlow(statement)
}

private fun createResultFlow(statement: BoundStatement): Flow<Result> {
return flow {
val metricInfoMap = buildResultMap(statement)

for (entry in metricInfoMap) {
val metricId = entry.key
val metricInfo = entry.value

val metric = metric {
cmmsMeasurementConsumerId = metricInfo.cmmsMeasurementConsumerId
externalMetricId = metricInfo.externalMetricId.value
externalReportingSetId = metricInfo.externalReportingSetId.value
createTime = metricInfo.createTime
timeInterval = metricInfo.timeInterval
metricSpec = metricInfo.metricSpec
metricInfo.weightedMeasurementInfoMap.values.forEach {
weightedMeasurements +=
MetricKt.weightedMeasurement {
weight = it.weight
measurement = measurement {
cmmsMeasurementConsumerId = metricInfo.cmmsMeasurementConsumerId
if (it.measurementInfo.cmmsMeasurementId != null) {
cmmsMeasurementId = it.measurementInfo.cmmsMeasurementId
}
cmmsCreateMeasurementRequestId = it.measurementInfo.cmmsCreateMeasurementRequestId
timeInterval = it.measurementInfo.timeInterval
it.measurementInfo.primitiveReportingSetBasisInfoMap.values.forEach {
primitiveReportingSetBases +=
ReportingSetKt.primitiveReportingSetBasis {
externalReportingSetId = it.externalReportingSetId.value
filters += it.filterSet
}
}
state = it.measurementInfo.state
if (it.measurementInfo.details != Measurement.Details.getDefaultInstance()) {
details = it.measurementInfo.details
}
}
}
}
if (metricInfo.details != Metric.Details.getDefaultInstance()) {
details = metricInfo.details
}
}
val metric = metricInfo.buildMetric()

val createMetricRequestId = metricInfo.createMetricRequestId ?: ""
emit(
Result(
measurementConsumerId = metricInfo.measurementConsumerId,
metricId = metricId,
metricId = metricInfo.metricId,
createMetricRequestId = createMetricRequestId,
metric = metric
)
Expand All @@ -313,10 +312,50 @@ class MetricReader(private val readContext: ReadContext) {
}
}

private fun MetricInfo.buildMetric(): Metric {
val metricInfo = this
return metric {
cmmsMeasurementConsumerId = metricInfo.cmmsMeasurementConsumerId
externalMetricId = metricInfo.externalMetricId.value
externalReportingSetId = metricInfo.externalReportingSetId.value
createTime = metricInfo.createTime
timeInterval = metricInfo.timeInterval
metricSpec = metricInfo.metricSpec
metricInfo.weightedMeasurementInfoMap.values.forEach {
weightedMeasurements +=
MetricKt.weightedMeasurement {
weight = it.weight
measurement = measurement {
cmmsMeasurementConsumerId = metricInfo.cmmsMeasurementConsumerId
if (it.measurementInfo.cmmsMeasurementId != null) {
cmmsMeasurementId = it.measurementInfo.cmmsMeasurementId
}
cmmsCreateMeasurementRequestId = it.measurementInfo.cmmsCreateMeasurementRequestId
timeInterval = it.measurementInfo.timeInterval
it.measurementInfo.primitiveReportingSetBasisInfoMap.values.forEach {
primitiveReportingSetBases +=
ReportingSetKt.primitiveReportingSetBasis {
externalReportingSetId = it.externalReportingSetId.value
filters += it.filterSet
}
}
state = it.measurementInfo.state
if (it.measurementInfo.details != Measurement.Details.getDefaultInstance()) {
details = it.measurementInfo.details
}
}
}
}
if (metricInfo.details != Metric.Details.getDefaultInstance()) {
details = metricInfo.details
}
}
}

/** Returns a map that maintains the order of the query result. */
private suspend fun buildResultMap(statement: BoundStatement): Map<InternalId, MetricInfo> {
// Key is metricId.
val metricInfoMap: MutableMap<InternalId, MetricInfo> = linkedMapOf()
private suspend fun buildResultMap(statement: BoundStatement): Map<ExternalId, MetricInfo> {
// Key is externalMetricId.
val metricInfoMap: MutableMap<ExternalId, MetricInfo> = linkedMapOf()

val translate: (row: ResultRow) -> Unit = { row: ResultRow ->
val measurementConsumerId: InternalId = row["MeasurementConsumerId"]
Expand Down Expand Up @@ -353,7 +392,7 @@ class MetricReader(private val readContext: ReadContext) {
val primitiveReportingSetBasisFilter: String? = row["PrimitiveReportingSetBasisFilter"]

val metricInfo =
metricInfoMap.computeIfAbsent(metricId) {
metricInfoMap.computeIfAbsent(externalMetricId) {
val metricTimeInterval = timeInterval {
startTime = metricTimeIntervalStart.toProtoTime()
endTime = metricTimeIntervalEnd.toProtoTime()
Expand Down
Loading

0 comments on commit 901fa10

Please sign in to comment.