From 6dbb326a1118b4af74b5ff966786147b7e4211bc Mon Sep 17 00:00:00 2001 From: Tristan Vuong <85768771+tristanvuong2021@users.noreply.github.com> Date: Wed, 13 Mar 2024 09:59:53 -0700 Subject: [PATCH] Replace one update call per row with one update call for multiple rows (#1518) --- .../postgres/writers/SetCmmsMeasurementIds.kt | 26 ++++++++------- .../writers/SetMeasurementFailures.kt | 32 +++++++++++-------- .../postgres/writers/SetMeasurementResults.kt | 32 +++++++++++-------- .../testing/v2/MeasurementsServiceTest.kt | 12 ++++--- 4 files changed, 60 insertions(+), 42 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetCmmsMeasurementIds.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetCmmsMeasurementIds.kt index 871d065b2d2..476a0daa2be 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetCmmsMeasurementIds.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetCmmsMeasurementIds.kt @@ -16,8 +16,9 @@ package org.wfanet.measurement.reporting.deploy.v2.postgres.writers -import org.wfanet.measurement.common.db.r2dbc.boundStatement import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter +import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement +import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement import org.wfanet.measurement.internal.reporting.v2.BatchSetCmmsMeasurementIdsRequest import org.wfanet.measurement.internal.reporting.v2.Measurement import org.wfanet.measurement.reporting.deploy.v2.postgres.readers.MeasurementConsumerReader @@ -41,19 +42,22 @@ class SetCmmsMeasurementIds(private val request: BatchSetCmmsMeasurementIdsReque .measurementConsumerId val statement = - boundStatement( + valuesListBoundStatement( + valuesStartIndex = 2, + paramCount = 2, """ - UPDATE Measurements SET CmmsMeasurementId = $1, State = $2 - WHERE MeasurementConsumerId = $3 AND CmmsCreateMeasurementRequestId::text = $4 - """ + UPDATE Measurements AS m SET CmmsMeasurementId = c.CmmsMeasurementId, State = $1 + FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER}) + AS c(CmmsMeasurementId, CmmsCreateMeasurementRequestId) + WHERE MeasurementConsumerId = $2 AND m.CmmsCreateMeasurementRequestId = c.CmmsCreateMeasurementRequestId::uuid + """, ) { - val state = Measurement.State.PENDING + bind("$1", Measurement.State.PENDING) + bind("$2", measurementConsumerId) request.measurementIdsList.forEach { - addBinding { - bind("$1", it.cmmsMeasurementId) - bind("$2", state) - bind("$3", measurementConsumerId) - bind("$4", it.cmmsCreateMeasurementRequestId) + addValuesBinding { + bindValuesParam(0, it.cmmsMeasurementId) + bindValuesParam(1, it.cmmsCreateMeasurementRequestId) } } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementFailures.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementFailures.kt index 5896a5a68f4..67b09d48af6 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementFailures.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementFailures.kt @@ -16,8 +16,9 @@ package org.wfanet.measurement.reporting.deploy.v2.postgres.writers -import org.wfanet.measurement.common.db.r2dbc.boundStatement import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter +import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement +import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement import org.wfanet.measurement.common.toJson import org.wfanet.measurement.internal.reporting.v2.BatchSetMeasurementFailuresRequest import org.wfanet.measurement.internal.reporting.v2.Measurement @@ -43,22 +44,27 @@ class SetMeasurementFailures(private val request: BatchSetMeasurementFailuresReq .measurementConsumerId val statement = - boundStatement( + valuesListBoundStatement( + valuesStartIndex = 2, + paramCount = 3, """ - UPDATE Measurements SET MeasurementDetails = $1, - MeasurementDetailsJson = $2, State = $3 - WHERE MeasurementConsumerId = $4 AND CmmsMeasurementId = $5 - """ + UPDATE Measurements AS m SET + MeasurementDetails = c.MeasurementDetails, + MeasurementDetailsJson = c.MeasurementDetailsJson, + State = $1 + FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER}) + AS c(MeasurementDetails, MeasurementDetailsJson, CmmsMeasurementId) + WHERE MeasurementConsumerId = $2 AND m.CmmsMeasurementId = c.CmmsMeasurementId + """, ) { - val state = Measurement.State.FAILED + bind("$1", Measurement.State.FAILED) + bind("$2", measurementConsumerId) request.measurementFailuresList.forEach { val details = MeasurementKt.details { failure = it.failure } - addBinding { - bind("$1", details) - bind("$2", details.toJson()) - bind("$3", state) - bind("$4", measurementConsumerId) - bind("$5", it.cmmsMeasurementId) + addValuesBinding { + bindValuesParam(0, details) + bindValuesParam(1, details.toJson()) + bindValuesParam(2, it.cmmsMeasurementId) } } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementResults.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementResults.kt index 128b211d4ec..24e5f6caaa9 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementResults.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/SetMeasurementResults.kt @@ -16,8 +16,9 @@ package org.wfanet.measurement.reporting.deploy.v2.postgres.writers -import org.wfanet.measurement.common.db.r2dbc.boundStatement import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter +import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement +import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement import org.wfanet.measurement.common.toJson import org.wfanet.measurement.internal.reporting.v2.BatchSetMeasurementResultsRequest import org.wfanet.measurement.internal.reporting.v2.Measurement @@ -43,22 +44,27 @@ class SetMeasurementResults(private val request: BatchSetMeasurementResultsReque .measurementConsumerId val statement = - boundStatement( + valuesListBoundStatement( + valuesStartIndex = 2, + paramCount = 3, """ - UPDATE Measurements SET MeasurementDetails = $1, - MeasurementDetailsJson = $2, State = $3 - WHERE MeasurementConsumerId = $4 AND CmmsMeasurementId = $5 - """ + UPDATE Measurements AS m SET + MeasurementDetails = c.MeasurementDetails, + MeasurementDetailsJson = c.MeasurementDetailsJson, + State = $1 + FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER}) + AS c(MeasurementDetails, MeasurementDetailsJson, CmmsMeasurementId) + WHERE MeasurementConsumerId = $2 AND m.CmmsMeasurementId = c.CmmsMeasurementId + """, ) { - val state = Measurement.State.SUCCEEDED + bind("$1", Measurement.State.SUCCEEDED) + bind("$2", measurementConsumerId) request.measurementResultsList.forEach { val details = MeasurementKt.details { results += it.resultsList } - addBinding { - bind("$1", details) - bind("$2", details.toJson()) - bind("$3", state) - bind("$4", measurementConsumerId) - bind("$5", it.cmmsMeasurementId) + addValuesBinding { + bindValuesParam(0, details) + bindValuesParam(1, details.toJson()) + bindValuesParam(2, it.cmmsMeasurementId) } } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/MeasurementsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/MeasurementsServiceTest.kt index 2020fab3358..fb2d827af98 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/MeasurementsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/internal/testing/v2/MeasurementsServiceTest.kt @@ -23,6 +23,7 @@ import com.google.type.interval import io.grpc.Status import io.grpc.StatusRuntimeException import java.time.Clock +import java.util.UUID import kotlin.random.Random import kotlin.test.assertFailsWith import kotlinx.coroutines.runBlocking @@ -327,7 +328,7 @@ abstract class MeasurementsServiceTest