diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateMetrics.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateMetrics.kt index 5acb10f73a9..907d51b9ab4 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateMetrics.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateMetrics.kt @@ -18,12 +18,13 @@ package org.wfanet.measurement.reporting.deploy.v2.postgres.writers import io.r2dbc.postgresql.codec.Interval as PostgresInterval import java.time.Instant +import java.time.OffsetDateTime import java.time.ZoneOffset import java.util.UUID import kotlinx.coroutines.flow.toList -import org.wfanet.measurement.common.db.r2dbc.BoundStatement -import org.wfanet.measurement.common.db.r2dbc.boundStatement import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter +import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement +import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement import org.wfanet.measurement.common.identity.InternalId import org.wfanet.measurement.common.toDuration import org.wfanet.measurement.common.toInstant @@ -56,19 +57,56 @@ import org.wfanet.measurement.reporting.service.internal.ReportingSetNotFoundExc */ class CreateMetrics(private val requests: List) : PostgresWriter>() { - private data class WeightedMeasurementsAndBinders( + private data class MetricCalculationSpecReportingMetricsValues( + val metricId: InternalId, + val createMetricRequestId: UUID, + ) + + private data class WeightedMeasurementsAndInsertData( val weightedMeasurements: Collection, - val measurementsBinders: List Unit>, - val metricMeasurementsBinders: List Unit>, - val primitiveReportingSetBasesBinders: List Unit>, - val primitiveReportingSetBasisFiltersBinders: List Unit>, - val measurementPrimitiveReportingSetBasesBinders: List Unit>, + val measurementsValuesList: List, + val metricMeasurementsValuesList: List, + val primitiveReportingSetBasesValuesList: List, + val primitiveReportingSetBasisFiltersValuesList: List, + val measurementPrimitiveReportingSetBasesValuesList: + List, + ) + + private data class MeasurementsValues( + val measurementId: InternalId, + val cmmsCreateMeasurementRequestId: UUID, + val timeIntervalStart: OffsetDateTime, + val timeIntervalEndExclusive: OffsetDateTime, + ) + + private data class MetricMeasurementsValues( + val metricId: InternalId, + val measurementId: InternalId, + val coefficient: Int, + val binaryRepresentation: Int, + ) + + private data class PrimitiveReportingSetBasesInsertData( + val primitiveReportingSetBasesValuesList: List, + val primitiveReportingSetBasisFiltersValuesList: List, + val measurementPrimitiveReportingSetBasesValuesList: + List, + ) + + private data class PrimitiveReportingSetBasesValues( + val primitiveReportingSetBasisId: InternalId, + val primitiveReportingSetId: InternalId, + ) + + private data class PrimitiveReportingSetBasisFiltersValues( + val primitiveReportingSetBasisId: InternalId, + val primitiveReportingSetBasisFilterId: InternalId, + val filter: String, ) - private data class PrimitiveReportingSetBasesBinders( - val primitiveReportingSetBasesBinders: List Unit>, - val primitiveReportingSetBasisFiltersBinders: List Unit>, - val measurementPrimitiveReportingSetBasesBinders: List Unit>, + private data class MeasurementPrimitiveReportingSetBasesValues( + val measurementId: InternalId, + val primitiveReportingSetBasisId: InternalId, ) override suspend fun TransactionScope.runTransaction(): List { @@ -116,14 +154,14 @@ class CreateMetrics(private val requests: List) : } } - val externalReportingSetIds = mutableSetOf() - - for (request in requests) { - if (!existingMetricsMap.containsKey(request.requestId)) { - externalReportingSetIds.add(request.metric.externalReportingSetId) - for (weightedMeasurement in request.metric.weightedMeasurementsList) { - for (bases in weightedMeasurement.measurement.primitiveReportingSetBasesList) { - externalReportingSetIds.add(bases.externalReportingSetId) + val externalReportingSetIds = buildSet { + for (request in requests) { + if (!existingMetricsMap.containsKey(request.requestId)) { + add(request.metric.externalReportingSetId) + for (weightedMeasurement in request.metric.weightedMeasurementsList) { + for (bases in weightedMeasurement.measurement.primitiveReportingSetBasesList) { + add(bases.externalReportingSetId) + } } } } @@ -141,17 +179,20 @@ class CreateMetrics(private val requests: List) : val metrics = mutableListOf() - val metricCalculationSpecReportingMetricsBinders = - mutableListOf Unit>() - val measurementsBinders = mutableListOf Unit>() - val metricMeasurementsBinders = mutableListOf Unit>() - val primitiveReportingSetBasesBinders = mutableListOf Unit>() - val primitiveReportingSetBasisFiltersBinders = mutableListOf Unit>() - val measurementPrimitiveReportingSetBasesBinders = - mutableListOf Unit>() + val metricCalculationSpecReportingMetricsValuesList = + mutableListOf() + val measurementsValuesList = mutableListOf() + val metricMeasurementsValuesList = mutableListOf() + val primitiveReportingSetBasesValuesList = mutableListOf() + val primitiveReportingSetBasisFiltersValuesList = + mutableListOf() + val measurementPrimitiveReportingSetBasesValuesList = + mutableListOf() val statement = - boundStatement( + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 20, """ INSERT INTO Metrics ( @@ -169,15 +210,15 @@ class CreateMetrics(private val requests: List) : FrequencyDifferentialPrivacyDelta, MaximumFrequencyPerUser, MaximumWatchDurationPerUser, + MaximumFrequency, VidSamplingIntervalStart, VidSamplingIntervalWidth, CreateTime, MetricDetails, - MetricDetailsJson, - MaximumFrequency + MetricDetailsJson ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20) - """ + VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} + """, ) { requests.forEach { val existingMetric: Metric? = existingMetricsMap[it.requestId] @@ -195,80 +236,86 @@ class CreateMetrics(private val requests: List) : if (it.metric.metricSpec.typeCase == MetricSpec.TypeCase.POPULATION_COUNT) 0 else it.metric.metricSpec.vidSamplingInterval.width - addBinding { - bind("$1", measurementConsumerId) - bind("$2", metricId) + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, metricId) if (it.requestId.isNotEmpty()) { - bind("$3", it.requestId) + bindValuesParam(2, it.requestId) } else { - bind("$3", null) + bindValuesParam(2, null) } - bind("$4", reportingSetId) - bind("$5", externalMetricId) - bind("$6", it.metric.timeInterval.startTime.toInstant().atOffset(ZoneOffset.UTC)) - bind("$7", it.metric.timeInterval.endTime.toInstant().atOffset(ZoneOffset.UTC)) - bind("$8", it.metric.metricSpec.typeCase.number) + bindValuesParam(3, reportingSetId) + bindValuesParam(4, externalMetricId) + bindValuesParam( + 5, + it.metric.timeInterval.startTime.toInstant().atOffset(ZoneOffset.UTC), + ) + bindValuesParam( + 6, + it.metric.timeInterval.endTime.toInstant().atOffset(ZoneOffset.UTC), + ) + bindValuesParam(7, it.metric.metricSpec.typeCase.number) @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. when (it.metric.metricSpec.typeCase) { MetricSpec.TypeCase.REACH_AND_FREQUENCY -> { val reachAndFrequency = it.metric.metricSpec.reachAndFrequency - bind("$9", reachAndFrequency.reachPrivacyParams.epsilon) - bind("$10", reachAndFrequency.reachPrivacyParams.delta) - bind("$11", reachAndFrequency.frequencyPrivacyParams.epsilon) - bind("$12", reachAndFrequency.reachPrivacyParams.delta) - bind("$13", null) - bind("$14", null) - bind("$20", reachAndFrequency.maximumFrequency) + bindValuesParam(8, reachAndFrequency.reachPrivacyParams.epsilon) + bindValuesParam(9, reachAndFrequency.reachPrivacyParams.delta) + bindValuesParam(10, reachAndFrequency.frequencyPrivacyParams.epsilon) + bindValuesParam(11, reachAndFrequency.reachPrivacyParams.delta) + bindValuesParam(12, null) + bindValuesParam(13, null) + bindValuesParam(14, reachAndFrequency.maximumFrequency) } MetricSpec.TypeCase.REACH -> { val reach = it.metric.metricSpec.reach - bind("$9", reach.privacyParams.epsilon) - bind("$10", reach.privacyParams.delta) - bind("$11", null) - bind("$12", null) - bind("$13", null) - bind("$14", null) - bind("$20", null) + bindValuesParam(8, reach.privacyParams.epsilon) + bindValuesParam(9, reach.privacyParams.delta) + bindValuesParam(10, null) + bindValuesParam(11, null) + bindValuesParam(12, null) + bindValuesParam(13, null) + bindValuesParam(14, null) } MetricSpec.TypeCase.IMPRESSION_COUNT -> { val impressionCount = it.metric.metricSpec.impressionCount - bind("$9", impressionCount.privacyParams.epsilon) - bind("$10", impressionCount.privacyParams.delta) - bind("$11", null) - bind("$12", null) - bind("$13", impressionCount.maximumFrequencyPerUser) - bind("$14", null) - bind("$20", null) + bindValuesParam(8, impressionCount.privacyParams.epsilon) + bindValuesParam(9, impressionCount.privacyParams.delta) + bindValuesParam(10, null) + bindValuesParam(11, null) + bindValuesParam(12, impressionCount.maximumFrequencyPerUser) + bindValuesParam(13, null) + bindValuesParam(14, null) } MetricSpec.TypeCase.WATCH_DURATION -> { val watchDuration = it.metric.metricSpec.watchDuration - bind("$9", watchDuration.privacyParams.epsilon) - bind("$10", watchDuration.privacyParams.delta) - bind("$11", null) - bind("$12", null) - bind("$13", null) - bind( - "$14", + bindValuesParam(8, watchDuration.privacyParams.epsilon) + bindValuesParam(9, watchDuration.privacyParams.delta) + bindValuesParam(10, null) + bindValuesParam(11, null) + bindValuesParam(12, null) + bindValuesParam( + 13, PostgresInterval.of(watchDuration.maximumWatchDurationPerUser.toDuration()), ) - bind("$20", null) + bindValuesParam(14, null) } MetricSpec.TypeCase.POPULATION_COUNT -> { - bind("$9", 0) - bind("$10", 0) - bind("$11", null) - bind("$12", null) - bind("$13", null) - bind("$14", null) - bind("$20", null) + bindValuesParam(8, 0) + bindValuesParam(9, 0) + bindValuesParam(10, null) + bindValuesParam(11, null) + bindValuesParam(12, null) + bindValuesParam(13, null) + bindValuesParam(14, null) } MetricSpec.TypeCase.TYPE_NOT_SET -> {} } - bind("$15", vidSamplingIntervalStart) - bind("$16", vidSamplingIntervalWidth) - bind("$17", createTime) - bind("$18", it.metric.details) - bind("$19", it.metric.details.toJson()) + bindValuesParam(15, vidSamplingIntervalStart) + bindValuesParam(16, vidSamplingIntervalWidth) + bindValuesParam(17, createTime) + bindValuesParam(18, it.metric.details) + bindValuesParam(19, it.metric.details.toJson()) } if (it.requestId.isNotEmpty()) { @@ -281,17 +328,17 @@ class CreateMetrics(private val requests: List) : } if (createMetricRequestUuid != null) { - metricCalculationSpecReportingMetricsBinders.add { - bind("$1", metricId) - bind("$2", measurementConsumerId) - bind("$3", createMetricRequestUuid) - } + metricCalculationSpecReportingMetricsValuesList.add( + MetricCalculationSpecReportingMetricsValues( + metricId = metricId, + createMetricRequestId = createMetricRequestUuid, + ) + ) } } - val weightedMeasurementsAndBindings = - createWeightedMeasurementsBindings( - measurementConsumerId = measurementConsumerId, + val weightedMeasurementsAndInsertData = + createWeightedMeasurementsInsertData( metricId = metricId, it.metric.weightedMeasurementsList, reportingSetMap, @@ -301,132 +348,196 @@ class CreateMetrics(private val requests: List) : it.metric.copy { this.externalMetricId = externalMetricId weightedMeasurements.clear() - weightedMeasurements.addAll(weightedMeasurementsAndBindings.weightedMeasurements) + weightedMeasurements.addAll(weightedMeasurementsAndInsertData.weightedMeasurements) this.createTime = createTime.toInstant().toProtoTime() } ) - measurementsBinders.addAll(weightedMeasurementsAndBindings.measurementsBinders) - metricMeasurementsBinders.addAll( - weightedMeasurementsAndBindings.metricMeasurementsBinders + measurementsValuesList.addAll(weightedMeasurementsAndInsertData.measurementsValuesList) + metricMeasurementsValuesList.addAll( + weightedMeasurementsAndInsertData.metricMeasurementsValuesList ) - primitiveReportingSetBasesBinders.addAll( - weightedMeasurementsAndBindings.primitiveReportingSetBasesBinders + primitiveReportingSetBasesValuesList.addAll( + weightedMeasurementsAndInsertData.primitiveReportingSetBasesValuesList ) - primitiveReportingSetBasisFiltersBinders.addAll( - weightedMeasurementsAndBindings.primitiveReportingSetBasisFiltersBinders + primitiveReportingSetBasisFiltersValuesList.addAll( + weightedMeasurementsAndInsertData.primitiveReportingSetBasisFiltersValuesList ) - measurementPrimitiveReportingSetBasesBinders.addAll( - weightedMeasurementsAndBindings.measurementPrimitiveReportingSetBasesBinders + measurementPrimitiveReportingSetBasesValuesList.addAll( + weightedMeasurementsAndInsertData.measurementPrimitiveReportingSetBasesValuesList ) } } } val metricCalculationSpecReportingMetricsStatement = - boundStatement( - """ - UPDATE MetricCalculationSpecReportingMetrics SET MetricId = $1 - WHERE MeasurementConsumerId = $2 AND CreateMetricRequestId = $3 + valuesListBoundStatement( + valuesStartIndex = 1, + paramCount = 2, """ + UPDATE MetricCalculationSpecReportingMetrics AS m SET MetricId = c.MetricId + FROM (VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER}) + AS c(MetricId, CreateMetricRequestId) + WHERE MeasurementConsumerId = $1 AND m.CreateMetricRequestId = c.CreateMetricRequestId + """, ) { - metricCalculationSpecReportingMetricsBinders.forEach { addBinding(it) } + bind("$1", measurementConsumerId) + metricCalculationSpecReportingMetricsValuesList.forEach { + addValuesBinding { + bindValuesParam(0, it.metricId) + bindValuesParam(1, it.createMetricRequestId) + } + } } val measurementsStatement = - boundStatement( + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 9, """ - INSERT INTO Measurements - ( - MeasurementConsumerId, - MeasurementId, - CmmsCreateMeasurementRequestId, - CmmsMeasurementId, - TimeIntervalStart, - TimeIntervalEndExclusive, - State, - MeasurementDetails, - MeasurementDetailsJson - ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - """ + INSERT INTO Measurements + ( + MeasurementConsumerId, + MeasurementId, + CmmsCreateMeasurementRequestId, + CmmsMeasurementId, + TimeIntervalStart, + TimeIntervalEndExclusive, + State, + MeasurementDetails, + MeasurementDetailsJson + ) + VALUES + ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} + """, ) { - measurementsBinders.forEach { addBinding(it) } + measurementsValuesList.forEach { + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, it.measurementId) + bindValuesParam(2, it.cmmsCreateMeasurementRequestId) + bindValuesParam(3, null) + bindValuesParam(4, it.timeIntervalStart) + bindValuesParam(5, it.timeIntervalEndExclusive) + bindValuesParam(6, Measurement.State.STATE_UNSPECIFIED) + bindValuesParam(7, Measurement.Details.getDefaultInstance()) + bindValuesParam(8, Measurement.Details.getDefaultInstance().toJson()) + } + } } val metricMeasurementsStatement = - boundStatement( + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 5, """ - INSERT INTO MetricMeasurements - ( - MeasurementConsumerId, - MetricId, - MeasurementId, - Coefficient, - BinaryRepresentation - ) - VALUES ($1, $2, $3, $4, $5) - """ + INSERT INTO MetricMeasurements + ( + MeasurementConsumerId, + MetricId, + MeasurementId, + Coefficient, + BinaryRepresentation + ) + VALUES + ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} + """, ) { - metricMeasurementsBinders.forEach { addBinding(it) } + metricMeasurementsValuesList.forEach { + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, it.metricId) + bindValuesParam(2, it.measurementId) + bindValuesParam(3, it.coefficient) + bindValuesParam(4, it.binaryRepresentation) + } + } } val primitiveReportingSetBasesStatement = - boundStatement( + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 3, """ - INSERT INTO PrimitiveReportingSetBases - ( - MeasurementConsumerId, - PrimitiveReportingSetBasisId, - PrimitiveReportingSetId - ) - VALUES ($1, $2, $3) - """ + INSERT INTO PrimitiveReportingSetBases + ( + MeasurementConsumerId, + PrimitiveReportingSetBasisId, + PrimitiveReportingSetId + ) + VALUES + ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} + """, ) { - primitiveReportingSetBasesBinders.forEach { addBinding(it) } + primitiveReportingSetBasesValuesList.forEach { + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, it.primitiveReportingSetBasisId) + bindValuesParam(2, it.primitiveReportingSetId) + } + } } val primitiveReportingSetBasisFiltersStatement = - boundStatement( + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 4, """ - INSERT INTO PrimitiveReportingSetBasisFilters - ( - MeasurementConsumerId, - PrimitiveReportingSetBasisId, - PrimitiveReportingSetBasisFilterId, - Filter - ) - VALUES ($1, $2, $3, $4) - """ + INSERT INTO PrimitiveReportingSetBasisFilters + ( + MeasurementConsumerId, + PrimitiveReportingSetBasisId, + PrimitiveReportingSetBasisFilterId, + Filter + ) + VALUES + ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} + """, ) { - primitiveReportingSetBasisFiltersBinders.forEach { addBinding(it) } + primitiveReportingSetBasisFiltersValuesList.forEach { + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, it.primitiveReportingSetBasisId) + bindValuesParam(2, it.primitiveReportingSetBasisFilterId) + bindValuesParam(3, it.filter) + } + } } val measurementPrimitiveReportingSetBasesStatement = - boundStatement( + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 3, """ - INSERT INTO MeasurementPrimitiveReportingSetBases - ( - MeasurementConsumerId, - MeasurementId, - PrimitiveReportingSetBasisId - ) - VALUES ($1, $2, $3) - """ + INSERT INTO MeasurementPrimitiveReportingSetBases + ( + MeasurementConsumerId, + MeasurementId, + PrimitiveReportingSetBasisId + ) + VALUES + ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} + """, ) { - measurementPrimitiveReportingSetBasesBinders.forEach { addBinding(it) } + measurementPrimitiveReportingSetBasesValuesList.forEach { + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, it.measurementId) + bindValuesParam(2, it.primitiveReportingSetBasisId) + } + } } if (existingMetricsMap.size < requests.size) { transactionContext.run { executeStatement(statement) - if (metricCalculationSpecReportingMetricsBinders.size > 0) { + if (metricCalculationSpecReportingMetricsValuesList.size > 0) { executeStatement(metricCalculationSpecReportingMetricsStatement) } executeStatement(measurementsStatement) executeStatement(metricMeasurementsStatement) executeStatement(primitiveReportingSetBasesStatement) - if (primitiveReportingSetBasisFiltersBinders.size > 0) { + if (primitiveReportingSetBasisFiltersValuesList.size > 0) { executeStatement(primitiveReportingSetBasisFiltersStatement) } executeStatement(measurementPrimitiveReportingSetBasesStatement) @@ -436,19 +547,19 @@ class CreateMetrics(private val requests: List) : return metrics } - private fun TransactionScope.createWeightedMeasurementsBindings( - measurementConsumerId: InternalId, + private fun TransactionScope.createWeightedMeasurementsInsertData( metricId: InternalId, weightedMeasurements: Collection, reportingSetMap: Map, - ): WeightedMeasurementsAndBinders { + ): WeightedMeasurementsAndInsertData { val updatedWeightedMeasurements = mutableListOf() - val measurementsBinders = mutableListOf Unit>() - val metricMeasurementsBinders = mutableListOf Unit>() - val primitiveReportingSetBasesBinders = mutableListOf Unit>() - val primitiveReportingSetBasisFiltersBinders = mutableListOf Unit>() - val measurementPrimitiveReportingSetBasesBinders = - mutableListOf Unit>() + val measurementsValuesList = mutableListOf() + val metricMeasurementsValuesList = mutableListOf() + val primitiveReportingSetBasesValuesList = mutableListOf() + val primitiveReportingSetBasisFiltersValuesList = + mutableListOf() + val measurementPrimitiveReportingSetBasesValuesList = + mutableListOf() weightedMeasurements.forEach { val measurementId = idGenerator.generateInternalId() @@ -458,95 +569,99 @@ class CreateMetrics(private val requests: List) : measurement = measurement.copy { cmmsCreateMeasurementRequestId = uuid.toString() } } ) - measurementsBinders.add { - bind("$1", measurementConsumerId) - bind("$2", measurementId) - bind("$3", uuid) - bind("$4", null) - bind("$5", it.measurement.timeInterval.startTime.toInstant().atOffset(ZoneOffset.UTC)) - bind("$6", it.measurement.timeInterval.endTime.toInstant().atOffset(ZoneOffset.UTC)) - bind("$7", Measurement.State.STATE_UNSPECIFIED) - bind("$8", Measurement.Details.getDefaultInstance()) - bind("$9", Measurement.Details.getDefaultInstance().toJson()) - } + measurementsValuesList.add( + MeasurementsValues( + measurementId = measurementId, + cmmsCreateMeasurementRequestId = uuid, + timeIntervalStart = + it.measurement.timeInterval.startTime.toInstant().atOffset(ZoneOffset.UTC), + timeIntervalEndExclusive = + it.measurement.timeInterval.endTime.toInstant().atOffset(ZoneOffset.UTC), + ) + ) - metricMeasurementsBinders.add { - bind("$1", measurementConsumerId) - bind("$2", metricId) - bind("$3", measurementId) - bind("$4", it.weight) - bind("$5", it.binaryRepresentation) - } + metricMeasurementsValuesList.add( + MetricMeasurementsValues( + metricId = metricId, + measurementId = measurementId, + coefficient = it.weight, + binaryRepresentation = it.binaryRepresentation, + ) + ) - val primitiveReportingSetBasesBindings = - createPrimitiveReportingSetBasesBindings( - measurementConsumerId = measurementConsumerId, + val primitiveReportingSetBasesInsertData = + createPrimitiveReportingSetBasesInsertData( measurementId = measurementId, it.measurement.primitiveReportingSetBasesList, reportingSetMap, ) - primitiveReportingSetBasesBinders.addAll( - primitiveReportingSetBasesBindings.primitiveReportingSetBasesBinders + primitiveReportingSetBasesValuesList.addAll( + primitiveReportingSetBasesInsertData.primitiveReportingSetBasesValuesList ) - primitiveReportingSetBasisFiltersBinders.addAll( - primitiveReportingSetBasesBindings.primitiveReportingSetBasisFiltersBinders + primitiveReportingSetBasisFiltersValuesList.addAll( + primitiveReportingSetBasesInsertData.primitiveReportingSetBasisFiltersValuesList ) - measurementPrimitiveReportingSetBasesBinders.addAll( - primitiveReportingSetBasesBindings.measurementPrimitiveReportingSetBasesBinders + measurementPrimitiveReportingSetBasesValuesList.addAll( + primitiveReportingSetBasesInsertData.measurementPrimitiveReportingSetBasesValuesList ) } - return WeightedMeasurementsAndBinders( + return WeightedMeasurementsAndInsertData( weightedMeasurements = updatedWeightedMeasurements, - measurementsBinders = measurementsBinders, - metricMeasurementsBinders = metricMeasurementsBinders, - primitiveReportingSetBasesBinders = primitiveReportingSetBasesBinders, - primitiveReportingSetBasisFiltersBinders = primitiveReportingSetBasisFiltersBinders, - measurementPrimitiveReportingSetBasesBinders = measurementPrimitiveReportingSetBasesBinders, + measurementsValuesList = measurementsValuesList, + metricMeasurementsValuesList = metricMeasurementsValuesList, + primitiveReportingSetBasesValuesList = primitiveReportingSetBasesValuesList, + primitiveReportingSetBasisFiltersValuesList = primitiveReportingSetBasisFiltersValuesList, + measurementPrimitiveReportingSetBasesValuesList = + measurementPrimitiveReportingSetBasesValuesList, ) } - private fun TransactionScope.createPrimitiveReportingSetBasesBindings( - measurementConsumerId: InternalId, + private fun TransactionScope.createPrimitiveReportingSetBasesInsertData( measurementId: InternalId, primitiveReportingSetBases: Collection, reportingSetMap: Map, - ): PrimitiveReportingSetBasesBinders { - val primitiveReportingSetBasesBinders = mutableListOf Unit>() - val primitiveReportingSetBasisFiltersBinders = mutableListOf Unit>() - val measurementPrimitiveReportingSetBasesBinders = - mutableListOf Unit>() + ): PrimitiveReportingSetBasesInsertData { + val primitiveReportingSetBasesValuesList = mutableListOf() + val primitiveReportingSetBasisFiltersValuesList = + mutableListOf() + val measurementPrimitiveReportingSetBasesValuesList = + mutableListOf() primitiveReportingSetBases.forEach { val primitiveReportingSetBasisId = idGenerator.generateInternalId() - primitiveReportingSetBasesBinders.add { - bind("$1", measurementConsumerId) - bind("$2", primitiveReportingSetBasisId) - bind("$3", reportingSetMap[it.externalReportingSetId]) - } + primitiveReportingSetBasesValuesList.add( + PrimitiveReportingSetBasesValues( + primitiveReportingSetBasisId = primitiveReportingSetBasisId, + primitiveReportingSetId = reportingSetMap.getValue(it.externalReportingSetId), + ) + ) it.filtersList.forEach { filter -> val primitiveReportingSetBasisFilterId = idGenerator.generateInternalId() - primitiveReportingSetBasisFiltersBinders.add { - bind("$1", measurementConsumerId) - bind("$2", primitiveReportingSetBasisId) - bind("$3", primitiveReportingSetBasisFilterId) - bind("$4", filter) - } + primitiveReportingSetBasisFiltersValuesList.add( + PrimitiveReportingSetBasisFiltersValues( + primitiveReportingSetBasisId = primitiveReportingSetBasisId, + primitiveReportingSetBasisFilterId = primitiveReportingSetBasisFilterId, + filter = filter, + ) + ) } - measurementPrimitiveReportingSetBasesBinders.add { - bind("$1", measurementConsumerId) - bind("$2", measurementId) - bind("$3", primitiveReportingSetBasisId) - } + measurementPrimitiveReportingSetBasesValuesList.add( + MeasurementPrimitiveReportingSetBasesValues( + measurementId = measurementId, + primitiveReportingSetBasisId = primitiveReportingSetBasisId, + ) + ) } - return PrimitiveReportingSetBasesBinders( - primitiveReportingSetBasesBinders = primitiveReportingSetBasesBinders, - primitiveReportingSetBasisFiltersBinders = primitiveReportingSetBasisFiltersBinders, - measurementPrimitiveReportingSetBasesBinders = measurementPrimitiveReportingSetBasesBinders, + return PrimitiveReportingSetBasesInsertData( + primitiveReportingSetBasesValuesList = primitiveReportingSetBasesValuesList, + primitiveReportingSetBasisFiltersValuesList = primitiveReportingSetBasisFiltersValuesList, + measurementPrimitiveReportingSetBasesValuesList = + measurementPrimitiveReportingSetBasesValuesList, ) } } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateReport.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateReport.kt index 496fc9d7cc5..708adc6999b 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateReport.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateReport.kt @@ -23,6 +23,8 @@ import kotlinx.coroutines.flow.toList import org.wfanet.measurement.common.db.r2dbc.BoundStatement import org.wfanet.measurement.common.db.r2dbc.boundStatement import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter +import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement +import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement import org.wfanet.measurement.common.identity.InternalId import org.wfanet.measurement.common.toInstant import org.wfanet.measurement.common.toJson @@ -54,8 +56,8 @@ import org.wfanet.measurement.reporting.service.internal.ReportingSetNotFoundExc * * [ReportAlreadyExistsException] Report already exists */ class CreateReport(private val request: CreateReportRequest) : PostgresWriter() { - private data class ReportingMetricEntriesAndBinders( - val metricCalculationSpecReportingMetricsBinders: List Unit>, + private data class ReportingMetricEntriesAndStatement( + val metricCalculationSpecReportingMetricsStatement: BoundStatement, val updatedReportingMetricEntries: Map, ) @@ -207,8 +209,8 @@ class CreateReport(private val request: CreateReportRequest) : PostgresWriter, reportingMetricMap: Map>, - ): ReportingMetricEntriesAndBinders { - val metricCalculationSpecReportingMetricsBinders = - mutableListOf Unit>() + ): ReportingMetricEntriesAndStatement { val updatedReportingMetricEntries = mutableMapOf() - for (entry in report.reportingMetricEntriesMap.entries) { - val reportingSetId = reportingSetIdsByExternalId.getValue(entry.key) - val updatedMetricCalculationSpecReportingMetricsList = - mutableListOf() - - for (metricCalSpecReportingMetrics in entry.value.metricCalculationSpecReportingMetricsList) { - val externalMetricCalculationSpecId = - metricCalSpecReportingMetrics.externalMetricCalculationSpecId - val metricCalculationSpecResult = - metricCalculationSpecsByExternalId.getValue(externalMetricCalculationSpecId) - val metricCalculationSpecReportingMetricKey = - MetricCalculationSpecReportingMetricKey( - reportingSetId, - metricCalculationSpecResult.metricCalculationSpecId, - ) - val updatedReportingMetricsList = mutableListOf() - - // If we found an existing set of metrics for this `MetricCalculationSpecReportingMetricKey` - if (reportingMetricMap.contains(metricCalculationSpecReportingMetricKey)) { - // Note that metric reuse mechanism won't produce the same order of `ReportingMetric` - // internally, but this won't impact any immutable fields in public resources. - reportingMetricMap.getValue(metricCalculationSpecReportingMetricKey).forEach { - val reportingMetric = - ReportKt.reportingMetric { - createMetricRequestId = it.createMetricRequestId - externalMetricId = it.externalMetricId - details = - ReportKt.ReportingMetricKt.details { - metricSpec = it.metricSpec - timeInterval = it.reportingMetricKey.timeInterval - - val filters: MutableList = it.metricDetails.filtersList.toMutableList() - // The filters in a Metric is the combination of the grouping predicates and - // the filter in `MetricCalculationSpec` - filters.remove(metricCalculationSpecResult.metricCalculationSpec.details.filter) - groupingPredicates += filters + val statement = + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 8, + """ + INSERT INTO MetricCalculationSpecReportingMetrics + ( + MeasurementConsumerId, + ReportId, + ReportingSetId, + MetricCalculationSpecId, + CreateMetricRequestId, + MetricId, + ReportingMetricDetails, + ReportingMetricDetailsJson + ) + VALUES + ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} + """, + ) { + for (entry in report.reportingMetricEntriesMap.entries) { + val reportingSetId = reportingSetIdsByExternalId.getValue(entry.key) + val updatedMetricCalculationSpecReportingMetricsList = + mutableListOf() + + for (metricCalSpecReportingMetrics in + entry.value.metricCalculationSpecReportingMetricsList) { + val externalMetricCalculationSpecId = + metricCalSpecReportingMetrics.externalMetricCalculationSpecId + val metricCalculationSpecResult = + metricCalculationSpecsByExternalId.getValue(externalMetricCalculationSpecId) + val metricCalculationSpecReportingMetricKey = + MetricCalculationSpecReportingMetricKey( + reportingSetId, + metricCalculationSpecResult.metricCalculationSpecId, + ) + val updatedReportingMetricsList = mutableListOf() + + // If we found an existing set of metrics for this + // `MetricCalculationSpecReportingMetricKey` + if (reportingMetricMap.contains(metricCalculationSpecReportingMetricKey)) { + // Note that metric reuse mechanism won't produce the same order of `ReportingMetric` + // internally, but this won't impact any immutable fields in public resources. + reportingMetricMap.getValue(metricCalculationSpecReportingMetricKey).forEach { + val reportingMetric = + ReportKt.reportingMetric { + createMetricRequestId = it.createMetricRequestId + externalMetricId = it.externalMetricId + details = + ReportKt.ReportingMetricKt.details { + metricSpec = it.metricSpec + timeInterval = it.reportingMetricKey.timeInterval + + val filters: MutableList = + it.metricDetails.filtersList.toMutableList() + // The filters in a Metric is the combination of the grouping predicates and + // the filter in `MetricCalculationSpec` + filters.remove( + metricCalculationSpecResult.metricCalculationSpec.details.filter + ) + groupingPredicates += filters + } } + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, reportId) + bindValuesParam(2, reportingSetIdsByExternalId[entry.key]) + bindValuesParam(3, metricCalculationSpecResult.metricCalculationSpecId) + bindValuesParam(4, UUID.fromString(reportingMetric.createMetricRequestId)) + bindValuesParam(5, it.metricId) + bindValuesParam(6, reportingMetric.details) + bindValuesParam(7, reportingMetric.details.toJson()) + } + updatedReportingMetricsList.add(reportingMetric) + } + } else { + metricCalSpecReportingMetrics.reportingMetricsList.forEach { + val createMetricRequestId = UUID.randomUUID() + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, reportId) + bindValuesParam(2, reportingSetIdsByExternalId[entry.key]) + bindValuesParam(3, metricCalculationSpecResult.metricCalculationSpecId) + bindValuesParam(4, createMetricRequestId) + bindValuesParam(5, null) + bindValuesParam(6, it.details) + bindValuesParam(7, it.details.toJson()) + } + updatedReportingMetricsList.add( + it.copy { this.createMetricRequestId = createMetricRequestId.toString() } + ) } - metricCalculationSpecReportingMetricsBinders.add { - bind("$1", measurementConsumerId) - bind("$2", reportId) - bind("$3", reportingSetIdsByExternalId[entry.key]) - bind("$4", metricCalculationSpecResult.metricCalculationSpecId) - bind("$5", UUID.fromString(reportingMetric.createMetricRequestId)) - bind("$6", it.metricId) - bind("$7", reportingMetric.details) - bind("$8", reportingMetric.details.toJson()) - } - updatedReportingMetricsList.add(reportingMetric) - } - } else { - metricCalSpecReportingMetrics.reportingMetricsList.forEach { - val createMetricRequestId = UUID.randomUUID() - metricCalculationSpecReportingMetricsBinders.add { - bind("$1", measurementConsumerId) - bind("$2", reportId) - bind("$3", reportingSetIdsByExternalId[entry.key]) - bind("$4", metricCalculationSpecResult.metricCalculationSpecId) - bind("$5", createMetricRequestId) - bind("$6", null) - bind("$7", it.details) - bind("$8", it.details.toJson()) } - updatedReportingMetricsList.add( - it.copy { this.createMetricRequestId = createMetricRequestId.toString() } + + updatedMetricCalculationSpecReportingMetricsList.add( + metricCalSpecReportingMetrics.copy { + reportingMetrics.clear() + reportingMetrics.addAll(updatedReportingMetricsList) + } ) } - } - updatedMetricCalculationSpecReportingMetricsList.add( - metricCalSpecReportingMetrics.copy { - reportingMetrics.clear() - reportingMetrics.addAll(updatedReportingMetricsList) - } - ) - } - - updatedReportingMetricEntries[entry.key] = - entry.value.copy { - metricCalculationSpecReportingMetrics.clear() - metricCalculationSpecReportingMetrics.addAll( - updatedMetricCalculationSpecReportingMetricsList - ) + updatedReportingMetricEntries[entry.key] = + entry.value.copy { + metricCalculationSpecReportingMetrics.clear() + metricCalculationSpecReportingMetrics.addAll( + updatedMetricCalculationSpecReportingMetricsList + ) + } } - } + } - return ReportingMetricEntriesAndBinders( - metricCalculationSpecReportingMetricsBinders = metricCalculationSpecReportingMetricsBinders, + return ReportingMetricEntriesAndStatement( + metricCalculationSpecReportingMetricsStatement = statement, updatedReportingMetricEntries = updatedReportingMetricEntries, ) } diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateReportingSet.kt b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateReportingSet.kt index 4c0ba379ac0..a4b9df92f14 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateReportingSet.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/deploy/v2/postgres/writers/CreateReportingSet.kt @@ -21,6 +21,8 @@ import io.r2dbc.spi.R2dbcDataIntegrityViolationException import org.wfanet.measurement.common.db.r2dbc.BoundStatement import org.wfanet.measurement.common.db.r2dbc.boundStatement import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter +import org.wfanet.measurement.common.db.r2dbc.postgres.ValuesListBoundStatement +import org.wfanet.measurement.common.db.r2dbc.postgres.valuesListBoundStatement import org.wfanet.measurement.common.identity.InternalId import org.wfanet.measurement.common.toJson import org.wfanet.measurement.internal.reporting.v2.CreateReportingSetRequest @@ -49,11 +51,42 @@ const val INTEGRITY_CONSTRAINT_VIOLATION = "23505" */ class CreateReportingSet(private val request: CreateReportingSetRequest) : PostgresWriter() { - private data class PrimitiveReportingSetBasesBinders( - val primitiveReportingSetBasesBinders: List Unit>, - val weightedSubsetUnionPrimitiveReportingSetBasesBinders: - List Unit>, - val primitiveReportingSetBasisFiltersBinders: List Unit>, + private data class SetExpressionsValues( + val setExpressionId: InternalId, + val operationValue: Int, + val leftHandSetExpressionId: InternalId?, + val leftHandReportingSetId: InternalId?, + val rightHandSetExpressionId: InternalId?, + val rightHandReportingSetId: InternalId?, + ) + + private data class WeightedSubsetUnionsValues( + val weightedSubsetUnionId: InternalId, + val weight: Int, + val binaryRepresentation: Int, + ) + + private data class PrimitiveReportingSetBasesInsertData( + val primitiveReportingSetBasesValues: PrimitiveReportingSetBasesValues, + val weightedSubsetUnionPrimitiveReportingSetBasesValues: + WeightedSubsetUnionPrimitiveReportingSetBasesValues, + val primitiveReportingSetBasisFiltersValuesList: List, + ) + + private data class PrimitiveReportingSetBasesValues( + val primitiveReportingSetBasisId: InternalId, + val primitiveReportingSetId: InternalId, + ) + + private data class WeightedSubsetUnionPrimitiveReportingSetBasesValues( + val weightedSubsetUnionId: InternalId, + val primitiveReportingSetBasisId: InternalId, + ) + + private data class PrimitiveReportingSetBasisFiltersValues( + val primitiveReportingSetBasisId: InternalId, + val primitiveReportingSetBasisFilterId: InternalId, + val filter: String, ) override suspend fun TransactionScope.runTransaction(): ReportingSet { @@ -122,22 +155,47 @@ class CreateReportingSet(private val request: CreateReportingSetRequest) : .collect { reportingSetMap[it.externalReportingSetId] = it.reportingSetId } val setExpressionId = idGenerator.generateInternalId() + val setExpressionsValuesList: List = buildList { + createSetExpressionsValues( + this, + setExpressionId = setExpressionId, + measurementConsumerId = measurementConsumerId, + reportingSetId = reportingSetId, + request.reportingSet.composite, + reportingSetMap, + ) + } val setExpressionsStatement = - boundStatement( + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 8, """ - INSERT INTO SetExpressions (MeasurementConsumerId, ReportingSetId, SetExpressionId, Operation, LeftHandSetExpressionId, LeftHandReportingSetId, RightHandSetExpressionId, RightHandReportingSetId) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - """ + INSERT INTO SetExpressions ( + MeasurementConsumerId, + ReportingSetId, + SetExpressionId, + Operation, + LeftHandSetExpressionId, + LeftHandReportingSetId, + RightHandSetExpressionId, + RightHandReportingSetId + ) + VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} + """, ) { - addSetExpressionsBindings( - this, - setExpressionId = setExpressionId, - measurementConsumerId = measurementConsumerId, - reportingSetId = reportingSetId, - request.reportingSet.composite, - reportingSetMap, - ) + setExpressionsValuesList.forEach { + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, reportingSetId) + bindValuesParam(2, it.setExpressionId) + bindValuesParam(3, it.operationValue) + bindValuesParam(4, it.leftHandSetExpressionId) + bindValuesParam(5, it.leftHandReportingSetId) + bindValuesParam(6, it.rightHandSetExpressionId) + bindValuesParam(7, it.rightHandReportingSetId) + } + } } transactionContext.executeStatement(setExpressionsStatement) @@ -210,7 +268,8 @@ class CreateReportingSet(private val request: CreateReportingSetRequest) : }] = it.eventGroupId } - val eventGroupBinders = mutableListOf Unit>() + val eventGroupBinders = + mutableListOf Unit>() cmmsEventGroupKeys.forEach { eventGroupMap.computeIfAbsent( @@ -221,10 +280,10 @@ class CreateReportingSet(private val request: CreateReportingSetRequest) : ) { val id = idGenerator.generateInternalId() eventGroupBinders.add { - bind("$1", measurementConsumerId) - bind("$2", id) - bind("$3", it.cmmsDataProviderId) - bind("$4", it.cmmsEventGroupId) + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, id) + bindValuesParam(2, it.cmmsDataProviderId) + bindValuesParam(3, it.cmmsEventGroupId) } id } @@ -232,30 +291,44 @@ class CreateReportingSet(private val request: CreateReportingSetRequest) : val eventGroupsStatement: BoundStatement? = if (eventGroupBinders.size > 0) { - boundStatement( + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 4, """ - INSERT INTO EventGroups (MeasurementConsumerId, EventGroupId, CmmsDataProviderId, CmmsEventGroupId) - VALUES ($1, $2, $3, $4) - """ + INSERT INTO EventGroups ( + MeasurementConsumerId, + EventGroupId, + CmmsDataProviderId, + CmmsEventGroupId + ) + VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER + } + """, ) { - eventGroupBinders.forEach { addBinding(it) } + eventGroupBinders.forEach { addValuesBinding(it) } } } else { null } val reportingSetEventGroupsStatement = - boundStatement( + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 3, """ - INSERT INTO ReportingSetEventGroups (MeasurementConsumerId, ReportingSetId, EventGroupId) - VALUES ($1, $2, $3) - """ + INSERT INTO ReportingSetEventGroups ( + MeasurementConsumerId, + ReportingSetId, + EventGroupId + ) + VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} + """, ) { eventGroupMap.values.forEach { - addBinding { - bind("$1", measurementConsumerId) - bind("$2", reportingSetId) - bind("$3", it) + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, reportingSetId) + bindValuesParam(2, it) } } } @@ -309,8 +382,8 @@ class CreateReportingSet(private val request: CreateReportingSetRequest) : return primitiveReportingSetBasesList.asSequence().map { it.externalReportingSetId }.toSet() } - private fun TransactionScope.addSetExpressionsBindings( - statementBuilder: BoundStatement.Builder, + private fun TransactionScope.createSetExpressionsValues( + values: MutableList, setExpressionId: InternalId, measurementConsumerId: InternalId, reportingSetId: InternalId, @@ -326,8 +399,8 @@ class CreateReportingSet(private val request: CreateReportingSetRequest) : when (setExpression.lhs.operandCase) { SetExpression.Operand.OperandCase.EXPRESSION -> { leftHandSetExpressionId = idGenerator.generateInternalId() - addSetExpressionsBindings( - statementBuilder, + createSetExpressionsValues( + values, leftHandSetExpressionId, measurementConsumerId, reportingSetId, @@ -352,8 +425,8 @@ class CreateReportingSet(private val request: CreateReportingSetRequest) : when (setExpression.rhs.operandCase) { SetExpression.Operand.OperandCase.EXPRESSION -> { rightHandSetExpressionId = idGenerator.generateInternalId() - addSetExpressionsBindings( - statementBuilder, + createSetExpressionsValues( + values, rightHandSetExpressionId, measurementConsumerId, reportingSetId, @@ -374,16 +447,16 @@ class CreateReportingSet(private val request: CreateReportingSetRequest) : } } - statementBuilder.addBinding { - bind("$1", measurementConsumerId) - bind("$2", reportingSetId) - bind("$3", setExpressionId) - bind("$4", setExpression.operationValue) - bind("$5", leftHandSetExpressionId) - bind("$6", leftHandReportingSetId) - bind("$7", rightHandSetExpressionId) - bind("$8", rightHandReportingSetId) - } + values.add( + SetExpressionsValues( + setExpressionId = setExpressionId, + operationValue = setExpression.operationValue, + leftHandSetExpressionId = leftHandSetExpressionId, + leftHandReportingSetId = leftHandReportingSetId, + rightHandSetExpressionId = rightHandSetExpressionId, + rightHandReportingSetId = rightHandReportingSetId, + ) + ) } private suspend fun TransactionScope.insertWeightedSubsetUnions( @@ -392,148 +465,183 @@ class CreateReportingSet(private val request: CreateReportingSetRequest) : weightedSubsetUnions: List, reportingSetMap: Map = mapOf(), ) { - val weightedSubsetUnionsBinders = mutableListOf Unit>() - val primitiveReportingSetBasesBinders = mutableListOf Unit>() - val weightedSubsetUnionPrimitiveReportingSetBasesBinders = - mutableListOf Unit>() - val primitiveReportingSetBasisFiltersBinders = mutableListOf Unit>() + val weightedSubsetUnionsValuesList = mutableListOf() + val primitiveReportingSetBasesValuesList = mutableListOf() + val weightedSubsetUnionPrimitiveReportingSetBasesValuesList = + mutableListOf() + val primitiveReportingSetBasisFiltersValuesList = + mutableListOf() weightedSubsetUnions.forEach { weightedSubsetUnion -> val weightedSubsetUnionId = idGenerator.generateInternalId() - weightedSubsetUnionsBinders.add { - bind("$1", measurementConsumerId) - bind("$2", reportingSetId) - bind("$3", weightedSubsetUnionId) - bind("$4", weightedSubsetUnion.weight) - bind("$5", weightedSubsetUnion.binaryRepresentation) - } + weightedSubsetUnionsValuesList.add( + WeightedSubsetUnionsValues( + weightedSubsetUnionId = weightedSubsetUnionId, + weight = weightedSubsetUnion.weight, + binaryRepresentation = weightedSubsetUnion.binaryRepresentation, + ) + ) weightedSubsetUnion.primitiveReportingSetBasesList.forEach { - val binders = - createPrimitiveReportingSetBasisBindings( - measurementConsumerId, - reportingSetId, - weightedSubsetUnionId, - it, - reportingSetMap, - ) - primitiveReportingSetBasesBinders.addAll(binders.primitiveReportingSetBasesBinders) - weightedSubsetUnionPrimitiveReportingSetBasesBinders.addAll( - binders.weightedSubsetUnionPrimitiveReportingSetBasesBinders + val insertData = + createPrimitiveReportingSetBasisInsertData(weightedSubsetUnionId, it, reportingSetMap) + primitiveReportingSetBasesValuesList.add(insertData.primitiveReportingSetBasesValues) + weightedSubsetUnionPrimitiveReportingSetBasesValuesList.add( + insertData.weightedSubsetUnionPrimitiveReportingSetBasesValues ) - primitiveReportingSetBasisFiltersBinders.addAll( - binders.primitiveReportingSetBasisFiltersBinders + primitiveReportingSetBasisFiltersValuesList.addAll( + insertData.primitiveReportingSetBasisFiltersValuesList ) } } val weightedSubsetUnionsStatement = - boundStatement( - """ - INSERT INTO WeightedSubsetUnions (MeasurementConsumerId, ReportingSetId, WeightedSubsetUnionId, Weight, BinaryRepresentation) - VALUES ($1, $2, $3, $4, $5) + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 5, """ + INSERT INTO WeightedSubsetUnions ( + MeasurementConsumerId, + ReportingSetId, + WeightedSubsetUnionId, + Weight, + BinaryRepresentation + ) + VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} + """, ) { - weightedSubsetUnionsBinders.forEach { addBinding(it) } + weightedSubsetUnionsValuesList.forEach { + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, reportingSetId) + bindValuesParam(2, it.weightedSubsetUnionId) + bindValuesParam(3, it.weight) + bindValuesParam(4, it.binaryRepresentation) + } + } } val primitiveReportingSetBasesStatement = - boundStatement( - """ - INSERT INTO PrimitiveReportingSetBases (MeasurementConsumerId, PrimitiveReportingSetBasisId, PrimitiveReportingSetId) - VALUES ($1, $2, $3) + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 3, """ + INSERT INTO PrimitiveReportingSetBases ( + MeasurementConsumerId, + PrimitiveReportingSetBasisId, + PrimitiveReportingSetId + ) + VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} + """, ) { - primitiveReportingSetBasesBinders.forEach { addBinding(it) } + primitiveReportingSetBasesValuesList.forEach { + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, it.primitiveReportingSetBasisId) + bindValuesParam(2, it.primitiveReportingSetId) + } + } } val weightedSubsetUnionPrimitiveReportingSetBasesStatement = - boundStatement( - """ - INSERT INTO WeightedSubsetUnionPrimitiveReportingSetBases (MeasurementConsumerId, ReportingSetId, WeightedSubsetUnionId, PrimitiveReportingSetBasisId) - VALUES ($1, $2, $3, $4) + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 4, """ + INSERT INTO WeightedSubsetUnionPrimitiveReportingSetBases ( + MeasurementConsumerId, + ReportingSetId, + WeightedSubsetUnionId, + PrimitiveReportingSetBasisId + ) + VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} + """, ) { - weightedSubsetUnionPrimitiveReportingSetBasesBinders.forEach { addBinding(it) } + weightedSubsetUnionPrimitiveReportingSetBasesValuesList.forEach { + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, reportingSetId) + bindValuesParam(2, it.weightedSubsetUnionId) + bindValuesParam(3, it.primitiveReportingSetBasisId) + } + } } val primitiveReportingSetBasisFiltersStatement = - boundStatement( - """ - INSERT INTO PrimitiveReportingSetBasisFilters (MeasurementConsumerId, PrimitiveReportingSetBasisId, PrimitiveReportingSetBasisFilterId, Filter) - VALUES ($1, $2, $3, $4) + valuesListBoundStatement( + valuesStartIndex = 0, + paramCount = 4, """ + INSERT INTO PrimitiveReportingSetBasisFilters ( + MeasurementConsumerId, + PrimitiveReportingSetBasisId, + PrimitiveReportingSetBasisFilterId, + Filter + ) + VALUES ${ValuesListBoundStatement.VALUES_LIST_PLACEHOLDER} + """, ) { - primitiveReportingSetBasisFiltersBinders.forEach { addBinding(it) } + primitiveReportingSetBasisFiltersValuesList.forEach { + addValuesBinding { + bindValuesParam(0, measurementConsumerId) + bindValuesParam(1, it.primitiveReportingSetBasisId) + bindValuesParam(2, it.primitiveReportingSetBasisFilterId) + bindValuesParam(3, it.filter) + } + } } transactionContext.run { executeStatement(weightedSubsetUnionsStatement) executeStatement(primitiveReportingSetBasesStatement) executeStatement(weightedSubsetUnionPrimitiveReportingSetBasesStatement) - if (primitiveReportingSetBasisFiltersBinders.size > 0) { + if (primitiveReportingSetBasisFiltersValuesList.size > 0) { executeStatement(primitiveReportingSetBasisFiltersStatement) } } } - private fun TransactionScope.createPrimitiveReportingSetBasisBindings( - measurementConsumerId: InternalId, - reportingSetId: InternalId, + private fun TransactionScope.createPrimitiveReportingSetBasisInsertData( weightedSubsetUnionId: InternalId, primitiveReportingSetBasis: PrimitiveReportingSetBasis, reportingSetMap: Map = mapOf(), - ): PrimitiveReportingSetBasesBinders { + ): PrimitiveReportingSetBasesInsertData { val primitiveReportingSetBasisId = idGenerator.generateInternalId() val primitiveReportingSetId = reportingSetMap[primitiveReportingSetBasis.externalReportingSetId] ?: throw ReportingSetNotFoundException() - val primitiveReportingSetBasesBinder: BoundStatement.Binder.() -> Unit = { - bind("$1", measurementConsumerId) - bind("$2", primitiveReportingSetBasisId) - bind("$3", primitiveReportingSetId) - } + val primitiveReportingSetBasesValues = + PrimitiveReportingSetBasesValues( + primitiveReportingSetBasisId = primitiveReportingSetBasisId, + primitiveReportingSetId = primitiveReportingSetId, + ) - val weightedSubsetUnionPrimitiveReportingSetBasesBinder: BoundStatement.Binder.() -> Unit = { - bind("$1", measurementConsumerId) - bind("$2", reportingSetId) - bind("$3", weightedSubsetUnionId) - bind("$4", primitiveReportingSetBasisId) - } + val weightedSubsetUnionPrimitiveReportingSetBasesValues = + WeightedSubsetUnionPrimitiveReportingSetBasesValues( + weightedSubsetUnionId = weightedSubsetUnionId, + primitiveReportingSetBasisId = primitiveReportingSetBasisId, + ) - val primitiveReportingSetBasisFiltersBinders = mutableListOf Unit>() - primitiveReportingSetBasis.filtersList.forEach { - primitiveReportingSetBasisFiltersBinders.add( - insertPrimitiveReportingSetBasisFilter( - measurementConsumerId, - primitiveReportingSetBasisId, - it, + val primitiveReportingSetBasisFiltersValuesList = + mutableListOf() + primitiveReportingSetBasis.filtersList.forEach { filter -> + val primitiveReportingSetBasisFilterId = idGenerator.generateInternalId() + primitiveReportingSetBasisFiltersValuesList.add( + PrimitiveReportingSetBasisFiltersValues( + primitiveReportingSetBasisId = primitiveReportingSetBasisId, + primitiveReportingSetBasisFilterId = primitiveReportingSetBasisFilterId, + filter = filter, ) ) } - return PrimitiveReportingSetBasesBinders( - primitiveReportingSetBasesBinders = listOf(primitiveReportingSetBasesBinder), - weightedSubsetUnionPrimitiveReportingSetBasesBinders = - listOf(weightedSubsetUnionPrimitiveReportingSetBasesBinder), - primitiveReportingSetBasisFiltersBinders, + return PrimitiveReportingSetBasesInsertData( + primitiveReportingSetBasesValues = primitiveReportingSetBasesValues, + weightedSubsetUnionPrimitiveReportingSetBasesValues = + weightedSubsetUnionPrimitiveReportingSetBasesValues, + primitiveReportingSetBasisFiltersValuesList, ) } - - private fun TransactionScope.insertPrimitiveReportingSetBasisFilter( - measurementConsumerId: InternalId, - primitiveReportingSetBasisId: InternalId, - filter: String, - ): BoundStatement.Binder.() -> Unit { - val primitiveReportingSetBasisFilterId = idGenerator.generateInternalId() - - return { - bind("$1", measurementConsumerId) - bind("$2", primitiveReportingSetBasisId) - bind("$3", primitiveReportingSetBasisFilterId) - bind("$4", filter) - } - } }