diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/BUILD.bazel index c91a4277eaa..8221bb00395 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/BUILD.bazel @@ -18,6 +18,7 @@ kt_jvm_library( "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", ], ) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationBlobReferenceReader.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationBlobReferenceReader.kt new file mode 100644 index 00000000000..cf0d8a92e1f --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationBlobReferenceReader.kt @@ -0,0 +1,179 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.deploy.common.postgres.readers + +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.flow.toList +import org.wfanet.measurement.common.db.r2dbc.ReadContext +import org.wfanet.measurement.common.db.r2dbc.ResultRow +import org.wfanet.measurement.common.db.r2dbc.boundStatement +import org.wfanet.measurement.internal.duchy.ComputationBlobDependency +import org.wfanet.measurement.internal.duchy.ComputationStageBlobMetadata +import org.wfanet.measurement.internal.duchy.computationStageBlobMetadata + +/** Performs read operations on ComputationBlobReferences tables */ +class ComputationBlobReferenceReader { + + /** + * Reads the [ComputationBlobDependency] of a computation. + * + * @param localId local identifier of the computation + * @param stage stage enum of the computation + * @param blobId local identifier of the blob + * @return [ComputationBlobDependency] if the blob exists, or null. + */ + suspend fun readBlobDependency( + readContext: ReadContext, + localId: Long, + stage: Long, + blobId: Long, + ): ComputationBlobDependency? { + val sql = + boundStatement( + """ + SELECT DependencyType + FROM ComputationBlobReferences + WHERE + ComputationId = $1 + AND + ComputationStage = $2 + AND + BlobId = $3 + """ + .trimIndent() + ) { + bind("$1", localId) + bind("$2", stage) + bind("$3", blobId) + } + + return readContext + .executeQuery(sql) + .consume { row -> row.getProtoEnum("DependencyType", ComputationBlobDependency::forNumber) } + .firstOrNull() + } + + /** + * Reads a map of blobId to pathToBlob of a computation based on localComputationId. + * + * @param localId local identifier of the computation + * @param stage stage enum of the computation + * @param dependencyType enum value of the dependency type + * @return map of blobId to pathToBlob + */ + suspend fun readBlobIdToPathMap( + readContext: ReadContext, + localId: Long, + stage: Long, + dependencyType: Long + ): Map { + val sql = + boundStatement( + """ + SELECT BlobId, PathToBlob + FROM ComputationBlobReferences + WHERE + ComputationId = $1 + AND + ComputationStage = $2 + AND + DependencyType = $3 + """ + .trimIndent() + ) { + bind("$1", localId) + bind("$2", stage) + bind("$3", dependencyType) + } + + return readContext + .executeQuery(sql) + .consume { it.get("BlobId") to it.get("PathToBlob") } + .toList() + .toMap() + } + + /** + * Reads a list of computationBlobKeys by localComputationId + * + * @param readContext The transaction context for reading from the Postgres database. + * @param localComputationId A local identifier for a computation + * @return A list of computation blob keys + */ + suspend fun readComputationBlobKeys( + readContext: ReadContext, + localComputationId: Long + ): List { + val statement = + boundStatement( + """ + SELECT PathToBlob + FROM ComputationBlobReferences + WHERE ComputationId = $1 AND PathToBlob IS NOT NULL + """ + .trimIndent() + ) { + bind("$1", localComputationId) + } + + return readContext + .executeQuery(statement) + .consume { row -> row.get("PathToBlob") } + .toList() + } + + /** + * Reads a list of [ComputationStageBlobMetadata] by localComputationId + * + * @param readContext The transaction context for reading from the Postgres database. + * @param localComputationId A local identifier for a computation + * @param computationStage stage enum of the computation + * @return A list of [ComputationStageBlobMetadata] + */ + suspend fun readBlobMetadata( + readContext: ReadContext, + localComputationId: Long, + computationStage: Long + ): List { + val statement = + boundStatement( + """ + SELECT BlobId, PathToBlob, DependencyType + FROM ComputationBlobReferences + WHERE + ComputationId = $1 + AND + ComputationStage = $2 + """ + .trimIndent() + ) { + bind("$1", localComputationId) + bind("$2", computationStage) + } + + return readContext.executeQuery(statement).consume(::buildBlobMetadata).toList() + } + + private fun buildBlobMetadata(row: ResultRow): ComputationStageBlobMetadata { + val path = row.get("PathToBlob") + return computationStageBlobMetadata { + blobId = row["BlobId"] + dependencyType = row.getProtoEnum("DependencyType", ComputationBlobDependency::forNumber) + if (path != null) { + this.path = path + } + } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt index db157612417..e0f15da051a 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt @@ -17,11 +17,12 @@ package org.wfanet.measurement.duchy.deploy.common.postgres.readers import com.google.protobuf.Timestamp import java.time.Instant import kotlinx.coroutines.flow.firstOrNull -import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.flow.toSet import org.wfanet.measurement.common.db.r2dbc.DatabaseClient import org.wfanet.measurement.common.db.r2dbc.ReadContext import org.wfanet.measurement.common.db.r2dbc.ResultRow import org.wfanet.measurement.common.db.r2dbc.boundStatement +import org.wfanet.measurement.common.grpc.grpcRequire import org.wfanet.measurement.common.toProtoTime import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnumHelper import org.wfanet.measurement.duchy.db.computation.ComputationStageLongValues @@ -32,16 +33,10 @@ import org.wfanet.measurement.internal.duchy.ComputationStageDetails import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType import org.wfanet.measurement.internal.duchy.ExternalRequisitionKey -import org.wfanet.measurement.internal.duchy.RequisitionDetails import org.wfanet.measurement.internal.duchy.RequisitionMetadata -import org.wfanet.measurement.internal.duchy.computationStageBlobMetadata import org.wfanet.measurement.internal.duchy.computationToken -import org.wfanet.measurement.internal.duchy.externalRequisitionKey -import org.wfanet.measurement.internal.duchy.requisitionMetadata /** - * Performs read operations on Computations tables - * * @param computationProtocolStagesEnumHelper [ComputationProtocolStagesEnumHelper] a helper class * to work with Enum representations of [ComputationType] and [ComputationStage]. */ @@ -50,7 +45,10 @@ class ComputationReader( ComputationProtocolStagesEnumHelper ) { - private data class Computation( + private val blobReferenceReader = ComputationBlobReferenceReader() + private val requisitionReader = RequisitionReader() + + data class Computation( val globalComputationId: String, val localComputationId: Long, val protocol: Long, @@ -74,29 +72,10 @@ class ComputationReader( version = row.get("UpdateTime").toEpochMilli(), stageSpecificDetails = row.getProtoMessage("StageDetails", ComputationStageDetails.parser()), lockOwner = row["LockOwner"], - lockExpirationTime = row.get("LockExpirationTime").toProtoTime() + lockExpirationTime = row.get("LockExpirationTime")?.toProtoTime() ) } - private fun buildBlob(row: ResultRow): ComputationStageBlobMetadata { - return computationStageBlobMetadata { - blobId = row["BlobId"] - path = row["PathToBlob"] - dependencyType = row["DependencyType"] - } - } - - private fun buildRequisition(row: ResultRow): RequisitionMetadata { - return requisitionMetadata { - externalKey = externalRequisitionKey { - externalRequisitionId = row["ExternalRequisitionId"] - requisitionFingerprint = row["RequisitionFingerprint"] - } - row.get("PathToBlob")?.let { path = it } - details = row.getProtoMessage("RequisitionDetails", RequisitionDetails.parser()) - } - } - private fun buildComputationToken( computation: Computation, blobs: List, @@ -126,54 +105,7 @@ class ComputationReader( } } - private suspend fun readBlobs( - readContext: ReadContext, - localComputationId: Long, - computationStage: Long - ): List { - val statement = - boundStatement( - """ - SELECT BlobId, PathToBlob, DependencyType - FROM ComputationBlobReferences - WHERE - ComputationId = $1 - AND - ComputationStage = $2 - """ - .trimIndent() - ) { - bind("$1", localComputationId) - bind("$2", computationStage) - } - - return readContext.executeQuery(statement).consume(::buildBlob).toList() - } - - private suspend fun readRequisitions( - readContext: ReadContext, - localComputationId: Long, - ): List { - val statement = - boundStatement( - """ - SELECT - ExternalRequisitionId, RequisitionFingerprint, PathToBlob, RequisitionDetails - FROM Requisitions - WHERE ComputationId = $1 - """ - .trimIndent() - ) { - bind("$1", localComputationId) - } - - return readContext.executeQuery(statement).consume(::buildRequisition).toList() - } - - private suspend fun readComputation( - readContext: ReadContext, - globalComputationId: String - ): Computation? { + suspend fun readComputation(readContext: ReadContext, globalComputationId: String): Computation? { val statement = boundStatement( """ @@ -196,7 +128,7 @@ class ComputationReader( ) { bind("$1", globalComputationId) } - return readContext.executeQuery(statement).consume(ComputationReader::Computation).firstOrNull() + return readContext.executeQuery(statement).consume(::Computation).firstOrNull() } private suspend fun readComputation( @@ -227,18 +159,17 @@ class ComputationReader( """ ) { bind("$1", externalRequisitionKey.externalRequisitionId) - bind("$2", externalRequisitionKey.requisitionFingerprint) + bind("$2", externalRequisitionKey.requisitionFingerprint.toByteArray()) } - return readContext.executeQuery(statement).consume(ComputationReader::Computation).firstOrNull() + return readContext.executeQuery(statement).consume(::Computation).firstOrNull() } /** - * Gets a [ComputationToken] by globalComputationId. + * Reads a [ComputationToken] by globalComputationId. * - * @param readContext The transaction context for reading from the Postgres database. + * @param client The [DatabaseClient] to the Postgres database. * @param globalComputationId A global identifier for a computation. - * @return [ReadComputationTokenResult] when a Computation with globalComputationId is found. - * @return null otherwise. + * @return [ComputationToken] when a Computation with globalComputationId is found, or null. */ suspend fun readComputationToken( client: DatabaseClient, @@ -250,8 +181,13 @@ class ComputationReader( readComputation(readContext, globalComputationId) ?: return null val blobs = - readBlobs(readContext, computation.localComputationId, computation.computationStage) - val requisitions = readRequisitions(readContext, computation.localComputationId) + blobReferenceReader.readBlobMetadata( + readContext, + computation.localComputationId, + computation.computationStage + ) + val requisitions = + requisitionReader.readRequisitionMetadata(readContext, computation.localComputationId) return buildComputationToken(computation, blobs, requisitions) } finally { @@ -260,12 +196,11 @@ class ComputationReader( } /** - * Gets a [ComputationToken] by externalRequisitionKey. + * Reads a [ComputationToken] by externalRequisitionKey. * - * @param readContext The transaction context for reading from the Postgres database. + * @param client The [DatabaseClient] to the Postgres database. * @param externalRequisitionKey The [ExternalRequisitionKey] for a computation. - * @return [ReadComputationTokenResult] when a Computation with externalRequisitionKey is found. - * @return null otherwise. + * @return [ComputationToken] when a Computation with externalRequisitionKey is found, or null. */ suspend fun readComputationToken( client: DatabaseClient, @@ -276,12 +211,73 @@ class ComputationReader( val computation = readComputation(readContext, externalRequisitionKey) ?: return null val blobs = - readBlobs(readContext, computation.localComputationId, computation.computationStage) - val requisitions = readRequisitions(readContext, computation.localComputationId) + blobReferenceReader.readBlobMetadata( + readContext, + computation.localComputationId, + computation.computationStage + ) + val requisitions = + requisitionReader.readRequisitionMetadata(readContext, computation.localComputationId) return buildComputationToken(computation, blobs, requisitions) } finally { readContext.close() } } + + /** + * Reads a set of globalComputationIds + * + * @param readContext The transaction context for reading from the Postgres database. + * @param stages A list of stage's long values + * @param updatedBefore An [Instant] to filter for the computations that has been updated before + * this + * @return A set of global computation Ids + */ + suspend fun readGlobalComputationIds( + readContext: ReadContext, + stages: List, + updatedBefore: Instant? = null + ): Set { + val computationTypes = + stages.map { computationProtocolStagesEnumHelper.stageToProtocol(it) }.distinct() + grpcRequire(computationTypes.count() == 1) { + "All stages should have the same ComputationType." + } + + /** + * Binding list of String into the IN clause does not work as expected with r2dbc library. + * Hence, manually joining targeting stages into a comma separated string and stub it into the + * query. + */ + val stagesString = + stages + .map { computationProtocolStagesEnumHelper.computationStageEnumToLongValues(it).stage } + .toList() + .joinToString(",") + val baseSql = + """ + SELECT GlobalComputationId + FROM Computations + WHERE + ComputationStage IN ($stagesString) + AND + Protocol = $1 + """ + + val query = + if (updatedBefore == null) { + boundStatement(baseSql) { bind("$1", computationTypes[0]) } + } else { + boundStatement(baseSql + " AND UpdatedTime <= $2") { + bind("$1", computationTypes[0]) + bind("$2", updatedBefore) + } + } + + return readContext + .executeQuery(query) + .consume { row -> row.get("GlobalComputationId") } + .toSet() + } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationStageAttemptReader.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationStageAttemptReader.kt new file mode 100644 index 00000000000..4eca3bcd878 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationStageAttemptReader.kt @@ -0,0 +1,62 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.deploy.common.postgres.readers + +import kotlinx.coroutines.flow.firstOrNull +import org.wfanet.measurement.common.db.r2dbc.ReadContext +import org.wfanet.measurement.common.db.r2dbc.boundStatement +import org.wfanet.measurement.internal.duchy.ComputationStageAttemptDetails + +/** Performs read operations on ComputationStageAttempts tables */ +class ComputationStageAttemptReader { + + /** + * Reads a [ComputationStageAttemptDetails] by localComputationId and stage. + * + * @param readContext The transaction context for reading from the Postgres database. + * @param localId The local identifier for the target computation. + * @param stage The target stage of this computation. + * @return [ComputationStageAttemptDetails] when computation stage details is found, or null. + */ + suspend fun readComputationStageDetails( + readContext: ReadContext, + localId: Long, + stage: Long, + currentAttempt: Long + ): ComputationStageAttemptDetails? { + val readComputationStageDetailsSql = + boundStatement( + """ + SELECT Details + FROM ComputationStageAttempts + WHERE + ComputationId = $1 + AND + ComputationStage = $2 + AND + Attempt = $3 + """ + ) { + bind("$1", localId) + bind("$2", stage) + bind("$3", currentAttempt) + } + + return readContext + .executeQuery(readComputationStageDetailsSql) + .consume { it.getProtoMessage("Details", ComputationStageAttemptDetails.parser()) } + .firstOrNull() + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ContinuationTokenReader.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ContinuationTokenReader.kt index 61893a29e76..7af62229b75 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ContinuationTokenReader.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ContinuationTokenReader.kt @@ -19,6 +19,7 @@ import org.wfanet.measurement.common.db.r2dbc.ReadContext import org.wfanet.measurement.common.db.r2dbc.ResultRow import org.wfanet.measurement.common.db.r2dbc.boundStatement +/** Performs read operations on HeraldContinuationTokens tables */ class ContinuationTokenReader { companion object { private const val parameterizedQueryString = @@ -36,8 +37,7 @@ class ContinuationTokenReader { /** * Reads a ContinuationToken from the HeraldContinuationTokens table. * - * @return [Result] when a ContinuationToken is found. - * @return null when there is no ContinuationToken. + * @return [Result] when a ContinuationToken is found, or null. */ suspend fun getContinuationToken(readContext: ReadContext): Result? { val statement = boundStatement(parameterizedQueryString) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/RequisitionReader.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/RequisitionReader.kt new file mode 100644 index 00000000000..7d94ff4f464 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/RequisitionReader.kt @@ -0,0 +1,134 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.deploy.common.postgres.readers + +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.flow.toList +import org.wfanet.measurement.common.db.r2dbc.ReadContext +import org.wfanet.measurement.common.db.r2dbc.ResultRow +import org.wfanet.measurement.common.db.r2dbc.boundStatement +import org.wfanet.measurement.internal.duchy.ExternalRequisitionKey +import org.wfanet.measurement.internal.duchy.RequisitionDetails +import org.wfanet.measurement.internal.duchy.RequisitionMetadata +import org.wfanet.measurement.internal.duchy.externalRequisitionKey +import org.wfanet.measurement.internal.duchy.requisitionMetadata + +/** Performs read operations on Requisitions tables */ +class RequisitionReader { + data class RequisitionResult(val computationId: Long, val requisitionId: Long) { + constructor( + row: ResultRow + ) : this( + computationId = row.get("ComputationId"), + requisitionId = row.get("RequisitionId") + ) + } + + /** + * Reads a set of globalComputationIds + * + * @param readContext The transaction context for reading from the Postgres database. + * @param key [ExternalRequisitionKey] external requisition key that identifies a requisition. + * @return [RequisitionResult] when the target requisition is found, or null. + */ + suspend fun readRequisitionByExternalKey( + readContext: ReadContext, + key: ExternalRequisitionKey + ): RequisitionResult? { + val sql = + boundStatement( + """ + SELECT ComputationId, RequisitionId + FROM Requisitions + WHERE + ExternalRequisitionId = $1 + AND + RequisitionFingerprint = $2 + """ + .trimIndent() + ) { + bind("$1", key.externalRequisitionId) + bind("$2", key.requisitionFingerprint.toByteArray()) + } + return readContext.executeQuery(sql).consume(RequisitionReader::RequisitionResult).firstOrNull() + } + + /** + * Gets a list of requisitionBlobKeys by localComputationId + * + * @param readContext The transaction context for reading from the Postgres database. + * @param localComputationId A local identifier for a computation + * @return A list of requisition blob keys + */ + suspend fun readRequisitionBlobKeys( + readContext: ReadContext, + localComputationId: Long + ): List { + val statement = + boundStatement( + """ + SELECT PathToBlob + FROM Requisitions + WHERE ComputationId = $1 AND PathToBlob IS NOT NULL + """ + .trimIndent() + ) { + bind("$1", localComputationId) + } + + return readContext + .executeQuery(statement) + .consume { row -> row.get("PathToBlob") } + .toList() + } + + /** + * Reads a list of [RequisitionMetadata] by localComputationId + * + * @param readContext The transaction context for reading from the Postgres database. + * @param localComputationId A local identifier for a computation + * @return A list of requisition blob keys + */ + suspend fun readRequisitionMetadata( + readContext: ReadContext, + localComputationId: Long, + ): List { + val statement = + boundStatement( + """ + SELECT + ExternalRequisitionId, RequisitionFingerprint, PathToBlob, RequisitionDetails + FROM Requisitions + WHERE ComputationId = $1 + """ + .trimIndent() + ) { + bind("$1", localComputationId) + } + + return readContext.executeQuery(statement).consume(::buildRequisitionMetadata).toList() + } + + private fun buildRequisitionMetadata(row: ResultRow): RequisitionMetadata { + return requisitionMetadata { + externalKey = externalRequisitionKey { + externalRequisitionId = row["ExternalRequisitionId"] + requisitionFingerprint = row["RequisitionFingerprint"] + } + row.get("PathToBlob")?.let { path = it } + details = row.getProtoMessage("RequisitionDetails", RequisitionDetails.parser()) + } + } +}