Skip to content

Commit

Permalink
Merge branch 'main' into tristanvuong-minor-improvements-reporting-v2…
Browse files Browse the repository at this point in the history
…-queries
  • Loading branch information
tristanvuong2021 authored Apr 8, 2024
2 parents 9856cf4 + baaf158 commit 6e5e997
Show file tree
Hide file tree
Showing 13 changed files with 659 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ kt_jvm_library(
"@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ import org.wfanet.measurement.common.db.r2dbc.BoundStatement
import org.wfanet.measurement.common.db.r2dbc.ReadContext
import org.wfanet.measurement.common.db.r2dbc.ResultRow
import org.wfanet.measurement.common.db.r2dbc.boundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement
import org.wfanet.measurement.common.identity.InternalId
import org.wfanet.measurement.common.toInstant
import org.wfanet.measurement.common.toProtoDuration
Expand Down Expand Up @@ -81,6 +83,7 @@ class MetricReader(private val readContext: ReadContext) {
val metricSpec: MetricSpec,
val weightedMeasurementInfoMap: MutableMap<MetricMeasurementKey, WeightedMeasurementInfo>,
val details: Metric.Details,
val state: Metric.State,
)

private data class MetricMeasurementKey(
Expand Down Expand Up @@ -133,14 +136,15 @@ class MetricReader(private val readContext: ReadContext) {
Metrics.VidSamplingIntervalWidth,
Metrics.CreateTime,
Metrics.MetricDetails,
Metrics.State as MetricsState,
MetricMeasurements.Coefficient,
MetricMeasurements.BinaryRepresentation,
Measurements.MeasurementId,
Measurements.CmmsCreateMeasurementRequestId,
Measurements.CmmsMeasurementId,
Measurements.TimeIntervalStart AS MeasurementsTimeIntervalStart,
Measurements.TimeIntervalEndExclusive AS MeasurementsTimeIntervalEndExclusive,
Measurements.State,
Measurements.State as MeasurementsState,
Measurements.MeasurementDetails,
PrimitiveReportingSetBases.PrimitiveReportingSetBasisId,
PrimitiveReportingSets.ExternalReportingSetId AS PrimitiveExternalReportingSetId,
Expand Down Expand Up @@ -181,26 +185,64 @@ class MetricReader(private val readContext: ReadContext) {
JOIN Metrics USING(MeasurementConsumerId)
$baseSqlJoins
WHERE Metrics.MeasurementConsumerId = $1
AND CreateMetricRequestId IN
AND CreateMetricRequestId IN (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER})
"""
.trimIndent()
)

var i = 2
val bindingMap = mutableMapOf<String, String>()
val inList =
createMetricRequestIds.joinToString(separator = ",", prefix = "(", postfix = ")") {
val index = "$$i"
bindingMap[it] = index
i++
index
val statement =
valuesListBoundStatement(valuesStartIndex = 1, paramCount = 1, sql.toString()) {
bind("$1", measurementConsumerId)
createMetricRequestIds.forEach { addValuesBinding { bindValuesParam(0, it) } }
}
sql.append(inList)

return flow {
val metricInfoMap = buildResultMap(statement)

for (entry in metricInfoMap) {
val metricInfo = entry.value

val metric = metricInfo.buildMetric()

val createMetricRequestId = metricInfo.createMetricRequestId ?: ""
emit(
Result(
measurementConsumerId = metricInfo.measurementConsumerId,
metricId = metricInfo.metricId,
createMetricRequestId = createMetricRequestId,
metric = metric,
)
)
}
}
}

fun readMetricsByCmmsMeasurementId(
measurementConsumerId: InternalId,
cmmsMeasurementIds: Collection<String>,
): Flow<Result> {
if (cmmsMeasurementIds.isEmpty()) {
return emptyFlow()
}

val sql =
StringBuilder(
"""
$baseSqlSelect
FROM
MeasurementConsumers
JOIN Metrics USING(MeasurementConsumerId)
$baseSqlJoins
WHERE Metrics.MeasurementConsumerId = $1
AND CmmsMeasurementId IN (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER})
"""
.trimIndent()
)

val statement =
boundStatement(sql.toString()) {
valuesListBoundStatement(valuesStartIndex = 1, paramCount = 1, sql.toString()) {
bind("$1", measurementConsumerId)
createMetricRequestIds.forEach { bind(bindingMap.getValue(it), it) }
cmmsMeasurementIds.forEach { addValuesBinding { bindValuesParam(0, it) } }
}

return flow {
Expand Down Expand Up @@ -625,6 +667,7 @@ class MetricReader(private val readContext: ReadContext) {
if (metricInfo.details != Metric.Details.getDefaultInstance()) {
details = metricInfo.details
}
state = metricInfo.state
}
}

Expand Down Expand Up @@ -662,12 +705,14 @@ class MetricReader(private val readContext: ReadContext) {
val cmmsMeasurementId: String? = row["CmmsMeasurementId"]
val measurementTimeIntervalStart: Instant = row["MeasurementsTimeIntervalStart"]
val measurementTimeIntervalEnd: Instant = row["MeasurementsTimeIntervalEndExclusive"]
val measurementState: Measurement.State = Measurement.State.forNumber(row["State"])
val measurementState: Measurement.State =
row.getProtoEnum("MeasurementsState", Measurement.State::forNumber)
val measurementDetails: Measurement.Details =
row.getProtoMessage("MeasurementDetails", Measurement.Details.parser())
val primitiveReportingSetBasisId: InternalId = row["PrimitiveReportingSetBasisId"]
val primitiveExternalReportingSetId: String = row["PrimitiveExternalReportingSetId"]
val primitiveReportingSetBasisFilter: String? = row["PrimitiveReportingSetBasisFilter"]
val metricState: Metric.State = row.getProtoEnum("MetricsState", Metric.State::forNumber)

val metricInfo =
metricInfoMap.computeIfAbsent(externalMetricId) {
Expand Down Expand Up @@ -767,6 +812,7 @@ class MetricReader(private val readContext: ReadContext) {
metricSpec = metricSpec,
details = metricDetails,
weightedMeasurementInfoMap = mutableMapOf(),
state = metricState,
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ class CreateMetrics(private val requests: List<CreateMetricRequest>) :
val statement =
valuesListBoundStatement(
valuesStartIndex = 0,
paramCount = 20,
paramCount = 21,
"""
INSERT INTO Metrics
(
Expand All @@ -216,11 +216,13 @@ class CreateMetrics(private val requests: List<CreateMetricRequest>) :
VidSamplingIntervalWidth,
CreateTime,
MetricDetails,
MetricDetailsJson
MetricDetailsJson,
State
)
VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER}
""",
) {
val createTime = Instant.now().atOffset(ZoneOffset.UTC).truncatedTo(ChronoUnit.MICROS)
requests.forEach {
val existingMetric: Metric? = existingMetricsMap[it.requestId]
if (existingMetric != null) {
Expand All @@ -229,7 +231,6 @@ class CreateMetrics(private val requests: List<CreateMetricRequest>) :
val metricId = idGenerator.generateInternalId()
val externalMetricId: String = it.externalMetricId
val reportingSetId: InternalId? = reportingSetMap[it.metric.externalReportingSetId]
val createTime = Instant.now().atOffset(ZoneOffset.UTC).truncatedTo(ChronoUnit.MICROS)
val vidSamplingIntervalStart =
if (it.metric.metricSpec.typeCase == MetricSpec.TypeCase.POPULATION_COUNT) 0
else it.metric.metricSpec.vidSamplingInterval.start
Expand Down Expand Up @@ -317,6 +318,7 @@ class CreateMetrics(private val requests: List<CreateMetricRequest>) :
bindValuesParam(17, createTime)
bindValuesParam(18, it.metric.details)
bindValuesParam(19, it.metric.details.toJson())
bindValuesParam(20, Metric.State.RUNNING)
}

if (it.requestId.isNotEmpty()) {
Expand Down Expand Up @@ -351,6 +353,7 @@ class CreateMetrics(private val requests: List<CreateMetricRequest>) :
weightedMeasurements.clear()
weightedMeasurements.addAll(weightedMeasurementsAndInsertData.weightedMeasurements)
this.createTime = createTime.toInstant().toProtoTime()
state = Metric.State.RUNNING
}
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ package org.wfanet.measurement.reporting.deploy.v2.postgres.writers
import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter
import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement
import org.wfanet.measurement.common.identity.InternalId
import org.wfanet.measurement.common.toJson
import org.wfanet.measurement.internal.reporting.v2.BatchSetMeasurementFailuresRequest
import org.wfanet.measurement.internal.reporting.v2.Measurement
import org.wfanet.measurement.internal.reporting.v2.MeasurementKt
import org.wfanet.measurement.internal.reporting.v2.Metric
import org.wfanet.measurement.reporting.deploy.v2.postgres.readers.MeasurementConsumerReader
import org.wfanet.measurement.reporting.deploy.v2.postgres.readers.MetricReader
import org.wfanet.measurement.reporting.service.internal.MeasurementConsumerNotFoundException
import org.wfanet.measurement.reporting.service.internal.MeasurementNotFoundException

Expand Down Expand Up @@ -72,5 +75,35 @@ class SetMeasurementFailures(private val request: BatchSetMeasurementFailuresReq
if (result.numRowsUpdated < request.measurementFailuresList.size) {
throw MeasurementNotFoundException()
}

// Read all metrics tied to Measurements that were updated.
val metricIds: List<InternalId> = buildList {
MetricReader(transactionContext)
.readMetricsByCmmsMeasurementId(
measurementConsumerId,
request.measurementFailuresList.map { it.cmmsMeasurementId },
)
.collect { metricReaderResult -> add(metricReaderResult.metricId) }
}

if (metricIds.isNotEmpty()) {
val metricStateUpdateStatement =
valuesListBoundStatement(
valuesStartIndex = 2,
paramCount = 1,
"""
UPDATE Metrics AS m SET State = $1
FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER})
AS c(MetricId)
WHERE MeasurementConsumerId = $2 AND m.MetricId = c.MetricId
""",
) {
bind("$1", Metric.State.FAILED)
bind("$2", measurementConsumerId)
metricIds.forEach { addValuesBinding { bindValuesParam(0, it) } }
}

transactionContext.executeStatement(metricStateUpdateStatement)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ package org.wfanet.measurement.reporting.deploy.v2.postgres.writers
import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter
import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement
import org.wfanet.measurement.common.identity.InternalId
import org.wfanet.measurement.common.toJson
import org.wfanet.measurement.internal.reporting.v2.BatchSetMeasurementResultsRequest
import org.wfanet.measurement.internal.reporting.v2.Measurement
import org.wfanet.measurement.internal.reporting.v2.MeasurementKt
import org.wfanet.measurement.internal.reporting.v2.Metric
import org.wfanet.measurement.reporting.deploy.v2.postgres.readers.MeasurementConsumerReader
import org.wfanet.measurement.reporting.deploy.v2.postgres.readers.MetricReader
import org.wfanet.measurement.reporting.service.internal.MeasurementConsumerNotFoundException
import org.wfanet.measurement.reporting.service.internal.MeasurementNotFoundException

Expand Down Expand Up @@ -72,5 +75,43 @@ class SetMeasurementResults(private val request: BatchSetMeasurementResultsReque
if (result.numRowsUpdated < request.measurementResultsList.size) {
throw MeasurementNotFoundException()
}

// Read all metrics tied to Measurements that were updated and determine any state changes.
val metricIds: List<InternalId> = buildList {
MetricReader(transactionContext)
.readMetricsByCmmsMeasurementId(
measurementConsumerId,
request.measurementResultsList.map { it.cmmsMeasurementId },
)
.collect { metricReaderResult ->
if (metricReaderResult.metric.state == Metric.State.RUNNING) {
val measurementStates =
metricReaderResult.metric.weightedMeasurementsList.map { it.measurement.state }
if (measurementStates.all { it == Measurement.State.SUCCEEDED }) {
add(metricReaderResult.metricId)
}
}
}
}

if (metricIds.isNotEmpty()) {
val metricStateUpdateStatement =
valuesListBoundStatement(
valuesStartIndex = 2,
paramCount = 1,
"""
UPDATE Metrics AS m SET State = $1
FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER})
AS c(MetricId)
WHERE MeasurementConsumerId = $2 AND m.MetricId = c.MetricId
""",
) {
bind("$1", Metric.State.SUCCEEDED)
bind("$2", measurementConsumerId)
metricIds.forEach { addValuesBinding { bindValuesParam(0, it) } }
}

transactionContext.executeStatement(metricStateUpdateStatement)
}
}
}
Loading

0 comments on commit 6e5e997

Please sign in to comment.