Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace one update call per row with one update call for multiple rows #1518

Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

package org.wfanet.measurement.reporting.deploy.v2.postgres.writers

import org.wfanet.measurement.common.db.r2dbc.boundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter
import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement
import org.wfanet.measurement.internal.reporting.v2.BatchSetCmmsMeasurementIdsRequest
import org.wfanet.measurement.internal.reporting.v2.Measurement
import org.wfanet.measurement.reporting.deploy.v2.postgres.readers.MeasurementConsumerReader
Expand All @@ -41,19 +42,22 @@ class SetCmmsMeasurementIds(private val request: BatchSetCmmsMeasurementIdsReque
.measurementConsumerId

val statement =
boundStatement(
valuesListBoundStatement(
valuesStartIndex = 2,
paramCount = 2,
"""
UPDATE Measurements SET CmmsMeasurementId = $1, State = $2
WHERE MeasurementConsumerId = $3 AND CmmsCreateMeasurementRequestId::text = $4
"""
UPDATE Measurements AS m SET CmmsMeasurementId = c.CmmsMeasurementId, State = $1
FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER})
AS c(CmmsMeasurementId, CmmsCreateMeasurementRequestId)
WHERE MeasurementConsumerId = $2 AND m.CmmsCreateMeasurementRequestId = c.CmmsCreateMeasurementRequestId::uuid
""",
) {
val state = Measurement.State.PENDING
bind("$1", Measurement.State.PENDING)
bind("$2", measurementConsumerId)
request.measurementIdsList.forEach {
addBinding {
bind("$1", it.cmmsMeasurementId)
bind("$2", state)
bind("$3", measurementConsumerId)
bind("$4", it.cmmsCreateMeasurementRequestId)
addValuesBinding {
bindValuesParam(0, it.cmmsMeasurementId)
bindValuesParam(1, it.cmmsCreateMeasurementRequestId)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

package org.wfanet.measurement.reporting.deploy.v2.postgres.writers

import org.wfanet.measurement.common.db.r2dbc.boundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter
import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement
import org.wfanet.measurement.common.toJson
import org.wfanet.measurement.internal.reporting.v2.BatchSetMeasurementFailuresRequest
import org.wfanet.measurement.internal.reporting.v2.Measurement
Expand All @@ -43,22 +44,27 @@ class SetMeasurementFailures(private val request: BatchSetMeasurementFailuresReq
.measurementConsumerId

val statement =
boundStatement(
valuesListBoundStatement(
valuesStartIndex = 2,
paramCount = 3,
"""
UPDATE Measurements SET MeasurementDetails = $1,
MeasurementDetailsJson = $2, State = $3
WHERE MeasurementConsumerId = $4 AND CmmsMeasurementId = $5
"""
UPDATE Measurements AS m SET
MeasurementDetails = c.MeasurementDetails,
MeasurementDetailsJson = c.MeasurementDetailsJson,
State = $1
FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER})
AS c(MeasurementDetails, MeasurementDetailsJson, CmmsMeasurementId)
WHERE MeasurementConsumerId = $2 AND m.CmmsMeasurementId = c.CmmsMeasurementId
""",
) {
val state = Measurement.State.FAILED
bind("$1", Measurement.State.FAILED)
bind("$2", measurementConsumerId)
request.measurementFailuresList.forEach {
val details = MeasurementKt.details { failure = it.failure }
addBinding {
bind("$1", details)
bind("$2", details.toJson())
bind("$3", state)
bind("$4", measurementConsumerId)
bind("$5", it.cmmsMeasurementId)
addValuesBinding {
bindValuesParam(0, details)
bindValuesParam(1, details.toJson())
bindValuesParam(2, it.cmmsMeasurementId)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

package org.wfanet.measurement.reporting.deploy.v2.postgres.writers

import org.wfanet.measurement.common.db.r2dbc.boundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter
import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement
import org.wfanet.measurement.common.toJson
import org.wfanet.measurement.internal.reporting.v2.BatchSetMeasurementResultsRequest
import org.wfanet.measurement.internal.reporting.v2.Measurement
Expand All @@ -43,22 +44,27 @@ class SetMeasurementResults(private val request: BatchSetMeasurementResultsReque
.measurementConsumerId

val statement =
boundStatement(
valuesListBoundStatement(
valuesStartIndex = 2,
paramCount = 3,
"""
UPDATE Measurements SET MeasurementDetails = $1,
MeasurementDetailsJson = $2, State = $3
WHERE MeasurementConsumerId = $4 AND CmmsMeasurementId = $5
"""
UPDATE Measurements AS m SET
MeasurementDetails = c.MeasurementDetails,
MeasurementDetailsJson = c.MeasurementDetailsJson,
State = $1
FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER})
AS c(MeasurementDetails, MeasurementDetailsJson, CmmsMeasurementId)
WHERE MeasurementConsumerId = $2 AND m.CmmsMeasurementId = c.CmmsMeasurementId
""",
) {
val state = Measurement.State.SUCCEEDED
bind("$1", Measurement.State.SUCCEEDED)
bind("$2", measurementConsumerId)
request.measurementResultsList.forEach {
val details = MeasurementKt.details { results += it.resultsList }
addBinding {
bind("$1", details)
bind("$2", details.toJson())
bind("$3", state)
bind("$4", measurementConsumerId)
bind("$5", it.cmmsMeasurementId)
addValuesBinding {
bindValuesParam(0, details)
bindValuesParam(1, details.toJson())
bindValuesParam(2, it.cmmsMeasurementId)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.google.type.interval
import io.grpc.Status
import io.grpc.StatusRuntimeException
import java.time.Clock
import java.util.UUID
import kotlin.random.Random
import kotlin.test.assertFailsWith
import kotlinx.coroutines.runBlocking
Expand Down Expand Up @@ -327,7 +328,7 @@ abstract class MeasurementsServiceTest<T : MeasurementsGrpcKt.MeasurementsCorout
}
measurementIds +=
BatchSetCmmsMeasurementIdsRequestKt.measurementIds {
cmmsCreateMeasurementRequestId = "1"
cmmsCreateMeasurementRequestId = UUID.randomUUID().toString()
cmmsMeasurementId = "1235"
}
}
Expand All @@ -350,7 +351,7 @@ abstract class MeasurementsServiceTest<T : MeasurementsGrpcKt.MeasurementsCorout
cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID
measurementIds +=
BatchSetCmmsMeasurementIdsRequestKt.measurementIds {
cmmsCreateMeasurementRequestId = "1"
cmmsCreateMeasurementRequestId = UUID.randomUUID().toString()
cmmsMeasurementId = "1235"
}
}
Expand All @@ -371,7 +372,7 @@ abstract class MeasurementsServiceTest<T : MeasurementsGrpcKt.MeasurementsCorout
cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID
measurementIds +=
BatchSetCmmsMeasurementIdsRequestKt.measurementIds {
cmmsCreateMeasurementRequestId = "1"
cmmsCreateMeasurementRequestId = UUID.randomUUID().toString()
cmmsMeasurementId = "1235"
}
}
Expand All @@ -391,7 +392,7 @@ abstract class MeasurementsServiceTest<T : MeasurementsGrpcKt.MeasurementsCorout
batchSetCmmsMeasurementIdsRequest {
measurementIds +=
BatchSetCmmsMeasurementIdsRequestKt.measurementIds {
cmmsCreateMeasurementRequestId = "1234"
cmmsCreateMeasurementRequestId = UUID.randomUUID().toString()
cmmsMeasurementId = "1234"
}
}
Expand All @@ -412,7 +413,7 @@ abstract class MeasurementsServiceTest<T : MeasurementsGrpcKt.MeasurementsCorout
for (i in 1L..(MAX_BATCH_SIZE + 1)) {
measurementIds +=
BatchSetCmmsMeasurementIdsRequestKt.measurementIds {
cmmsCreateMeasurementRequestId = "123"
cmmsCreateMeasurementRequestId = UUID.randomUUID().toString()
cmmsMeasurementId = "1234"
}
}
Expand Down Expand Up @@ -1381,6 +1382,7 @@ abstract class MeasurementsServiceTest<T : MeasurementsGrpcKt.MeasurementsCorout
MetricKt.weightedMeasurement {
weight = 2
measurement = measurement {
cmmsCreateMeasurementRequestId = UUID.randomUUID().toString()
this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId
timeInterval = interval {
startTime = timestamp { seconds = 10 }
Expand Down
Loading