Skip to content

Commit

Permalink
Replace one update call per row with one update call for multiple rows (
Browse files Browse the repository at this point in the history
  • Loading branch information
tristanvuong2021 authored Mar 13, 2024
1 parent 746b7a6 commit 6dbb326
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

package org.wfanet.measurement.reporting.deploy.v2.postgres.writers

import org.wfanet.measurement.common.db.r2dbc.boundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter
import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement
import org.wfanet.measurement.internal.reporting.v2.BatchSetCmmsMeasurementIdsRequest
import org.wfanet.measurement.internal.reporting.v2.Measurement
import org.wfanet.measurement.reporting.deploy.v2.postgres.readers.MeasurementConsumerReader
Expand All @@ -41,19 +42,22 @@ class SetCmmsMeasurementIds(private val request: BatchSetCmmsMeasurementIdsReque
.measurementConsumerId

val statement =
boundStatement(
valuesListBoundStatement(
valuesStartIndex = 2,
paramCount = 2,
"""
UPDATE Measurements SET CmmsMeasurementId = $1, State = $2
WHERE MeasurementConsumerId = $3 AND CmmsCreateMeasurementRequestId::text = $4
"""
UPDATE Measurements AS m SET CmmsMeasurementId = c.CmmsMeasurementId, State = $1
FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER})
AS c(CmmsMeasurementId, CmmsCreateMeasurementRequestId)
WHERE MeasurementConsumerId = $2 AND m.CmmsCreateMeasurementRequestId = c.CmmsCreateMeasurementRequestId::uuid
""",
) {
val state = Measurement.State.PENDING
bind("$1", Measurement.State.PENDING)
bind("$2", measurementConsumerId)
request.measurementIdsList.forEach {
addBinding {
bind("$1", it.cmmsMeasurementId)
bind("$2", state)
bind("$3", measurementConsumerId)
bind("$4", it.cmmsCreateMeasurementRequestId)
addValuesBinding {
bindValuesParam(0, it.cmmsMeasurementId)
bindValuesParam(1, it.cmmsCreateMeasurementRequestId)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

package org.wfanet.measurement.reporting.deploy.v2.postgres.writers

import org.wfanet.measurement.common.db.r2dbc.boundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter
import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement
import org.wfanet.measurement.common.toJson
import org.wfanet.measurement.internal.reporting.v2.BatchSetMeasurementFailuresRequest
import org.wfanet.measurement.internal.reporting.v2.Measurement
Expand All @@ -43,22 +44,27 @@ class SetMeasurementFailures(private val request: BatchSetMeasurementFailuresReq
.measurementConsumerId

val statement =
boundStatement(
valuesListBoundStatement(
valuesStartIndex = 2,
paramCount = 3,
"""
UPDATE Measurements SET MeasurementDetails = $1,
MeasurementDetailsJson = $2, State = $3
WHERE MeasurementConsumerId = $4 AND CmmsMeasurementId = $5
"""
UPDATE Measurements AS m SET
MeasurementDetails = c.MeasurementDetails,
MeasurementDetailsJson = c.MeasurementDetailsJson,
State = $1
FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER})
AS c(MeasurementDetails, MeasurementDetailsJson, CmmsMeasurementId)
WHERE MeasurementConsumerId = $2 AND m.CmmsMeasurementId = c.CmmsMeasurementId
""",
) {
val state = Measurement.State.FAILED
bind("$1", Measurement.State.FAILED)
bind("$2", measurementConsumerId)
request.measurementFailuresList.forEach {
val details = MeasurementKt.details { failure = it.failure }
addBinding {
bind("$1", details)
bind("$2", details.toJson())
bind("$3", state)
bind("$4", measurementConsumerId)
bind("$5", it.cmmsMeasurementId)
addValuesBinding {
bindValuesParam(0, details)
bindValuesParam(1, details.toJson())
bindValuesParam(2, it.cmmsMeasurementId)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

package org.wfanet.measurement.reporting.deploy.v2.postgres.writers

import org.wfanet.measurement.common.db.r2dbc.boundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter
import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement
import org.wfanet.measurement.common.toJson
import org.wfanet.measurement.internal.reporting.v2.BatchSetMeasurementResultsRequest
import org.wfanet.measurement.internal.reporting.v2.Measurement
Expand All @@ -43,22 +44,27 @@ class SetMeasurementResults(private val request: BatchSetMeasurementResultsReque
.measurementConsumerId

val statement =
boundStatement(
valuesListBoundStatement(
valuesStartIndex = 2,
paramCount = 3,
"""
UPDATE Measurements SET MeasurementDetails = $1,
MeasurementDetailsJson = $2, State = $3
WHERE MeasurementConsumerId = $4 AND CmmsMeasurementId = $5
"""
UPDATE Measurements AS m SET
MeasurementDetails = c.MeasurementDetails,
MeasurementDetailsJson = c.MeasurementDetailsJson,
State = $1
FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER})
AS c(MeasurementDetails, MeasurementDetailsJson, CmmsMeasurementId)
WHERE MeasurementConsumerId = $2 AND m.CmmsMeasurementId = c.CmmsMeasurementId
""",
) {
val state = Measurement.State.SUCCEEDED
bind("$1", Measurement.State.SUCCEEDED)
bind("$2", measurementConsumerId)
request.measurementResultsList.forEach {
val details = MeasurementKt.details { results += it.resultsList }
addBinding {
bind("$1", details)
bind("$2", details.toJson())
bind("$3", state)
bind("$4", measurementConsumerId)
bind("$5", it.cmmsMeasurementId)
addValuesBinding {
bindValuesParam(0, details)
bindValuesParam(1, details.toJson())
bindValuesParam(2, it.cmmsMeasurementId)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.google.type.interval
import io.grpc.Status
import io.grpc.StatusRuntimeException
import java.time.Clock
import java.util.UUID
import kotlin.random.Random
import kotlin.test.assertFailsWith
import kotlinx.coroutines.runBlocking
Expand Down Expand Up @@ -327,7 +328,7 @@ abstract class MeasurementsServiceTest<T : MeasurementsGrpcKt.MeasurementsCorout
}
measurementIds +=
BatchSetCmmsMeasurementIdsRequestKt.measurementIds {
cmmsCreateMeasurementRequestId = "1"
cmmsCreateMeasurementRequestId = UUID.randomUUID().toString()
cmmsMeasurementId = "1235"
}
}
Expand All @@ -350,7 +351,7 @@ abstract class MeasurementsServiceTest<T : MeasurementsGrpcKt.MeasurementsCorout
cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID
measurementIds +=
BatchSetCmmsMeasurementIdsRequestKt.measurementIds {
cmmsCreateMeasurementRequestId = "1"
cmmsCreateMeasurementRequestId = UUID.randomUUID().toString()
cmmsMeasurementId = "1235"
}
}
Expand All @@ -371,7 +372,7 @@ abstract class MeasurementsServiceTest<T : MeasurementsGrpcKt.MeasurementsCorout
cmmsMeasurementConsumerId = CMMS_MEASUREMENT_CONSUMER_ID
measurementIds +=
BatchSetCmmsMeasurementIdsRequestKt.measurementIds {
cmmsCreateMeasurementRequestId = "1"
cmmsCreateMeasurementRequestId = UUID.randomUUID().toString()
cmmsMeasurementId = "1235"
}
}
Expand All @@ -391,7 +392,7 @@ abstract class MeasurementsServiceTest<T : MeasurementsGrpcKt.MeasurementsCorout
batchSetCmmsMeasurementIdsRequest {
measurementIds +=
BatchSetCmmsMeasurementIdsRequestKt.measurementIds {
cmmsCreateMeasurementRequestId = "1234"
cmmsCreateMeasurementRequestId = UUID.randomUUID().toString()
cmmsMeasurementId = "1234"
}
}
Expand All @@ -412,7 +413,7 @@ abstract class MeasurementsServiceTest<T : MeasurementsGrpcKt.MeasurementsCorout
for (i in 1L..(MAX_BATCH_SIZE + 1)) {
measurementIds +=
BatchSetCmmsMeasurementIdsRequestKt.measurementIds {
cmmsCreateMeasurementRequestId = "123"
cmmsCreateMeasurementRequestId = UUID.randomUUID().toString()
cmmsMeasurementId = "1234"
}
}
Expand Down Expand Up @@ -1381,6 +1382,7 @@ abstract class MeasurementsServiceTest<T : MeasurementsGrpcKt.MeasurementsCorout
MetricKt.weightedMeasurement {
weight = 2
measurement = measurement {
cmmsCreateMeasurementRequestId = UUID.randomUUID().toString()
this.cmmsMeasurementConsumerId = cmmsMeasurementConsumerId
timeInterval = interval {
startTime = timestamp { seconds = 10 }
Expand Down

0 comments on commit 6dbb326

Please sign in to comment.