From 7e71aaf41010df619719bf2b394098affcd0db0b Mon Sep 17 00:00:00 2001 From: Yuhong Wang Date: Wed, 26 Jul 2023 17:24:50 +0000 Subject: [PATCH 01/12] Add postgres computations service --- .../duchy/db/computation/BUILD.bazel | 1 + .../db/computation/ComputationEditToken.kt | 39 +++ .../duchy/db/computation/ComputationTypes.kt | 8 + .../db/computation/ComputationsDatabase.kt | 18 - .../testing/FakeComputationsDatabase.kt | 2 +- .../duchy/deploy/common/postgres/BUILD.bazel | 2 +- .../postgres/PostgresComputationsService.kt | 315 ++++++++++++++++++ .../postgres/readers/ComputationReader.kt | 7 +- .../readers/ComputationStageAttemptReader.kt | 2 +- .../common/postgres/testing/BUILD.bazel | 2 +- .../writers/AdvanceComputationStage.kt | 2 +- .../postgres/writers/FinishComputation.kt | 2 +- ...cpSpannerComputationsDatabaseTransactor.kt | 2 +- .../computations/ComputationsService.kt | 25 +- .../testing/ComputationsServiceTest.kt | 79 ++++- .../deploy/{ => common}/postgres/BUILD.bazel | 24 +- .../{ => common}/postgres/DuchySchemaTest.kt | 2 +- .../PostgresComputationStatsServiceTest.kt | 11 +- .../PostgresComputationsServiceTest.kt | 78 +++++ .../PostgresContinuationTokensServiceTest.kt | 3 +- ...annerComputationsDatabaseTransactorTest.kt | 2 +- 21 files changed, 566 insertions(+), 60 deletions(-) create mode 100644 src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationEditToken.kt rename src/test/kotlin/org/wfanet/measurement/duchy/deploy/{ => common}/postgres/BUILD.bazel (62%) rename src/test/kotlin/org/wfanet/measurement/duchy/deploy/{ => common}/postgres/DuchySchemaTest.kt (97%) rename src/test/kotlin/org/wfanet/measurement/duchy/deploy/{ => common}/postgres/PostgresComputationStatsServiceTest.kt (90%) create mode 100644 src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsServiceTest.kt rename src/test/kotlin/org/wfanet/measurement/duchy/deploy/{ => common}/postgres/PostgresContinuationTokensServiceTest.kt (91%) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/BUILD.bazel index e69469ff8c0..34ecd421996 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/BUILD.bazel @@ -18,5 +18,6 @@ kt_jvm_library( "//src/main/proto/wfa/measurement/internal/duchy:computations_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/internal/duchy/protocol:liquid_legions_v2_kt_jvm_proto", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", ], ) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationEditToken.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationEditToken.kt new file mode 100644 index 00000000000..1209ce783c7 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationEditToken.kt @@ -0,0 +1,39 @@ +package org.wfanet.measurement.duchy.db.computation + +import org.wfanet.measurement.common.grpc.failGrpc +import org.wfanet.measurement.internal.duchy.ComputationStage +import org.wfanet.measurement.internal.duchy.ComputationToken +import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType + +/** Information about a computation needed to edit a computation. */ +data class ComputationEditToken( + /** The identifier for the computation used locally. */ + val localId: Long, + /** The protocol used for the computation. */ + val protocol: ProtocolT, + /** The stage of the computation when the token was created. */ + val stage: StageT, + /** The number of the current attempt of this stage for this computation. */ + val attempt: Int, + /** + * The version number of the last known edit to the computation. The version is a monotonically + * increasing number used as a guardrail to protect against concurrent edits to the same + * computation. + */ + val editVersion: Long +) + +fun ComputationToken.toDatabaseEditToken(): + ComputationEditToken { + val protocol = computationStage.toComputationType() + if (protocol == ComputationType.UNRECOGNIZED) { + failGrpc { "Computation type for $this is unknown" } + } + return ComputationEditToken( + localId = localComputationId, + protocol = protocol, + stage = computationStage, + attempt = attempt, + editVersion = version + ) +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationTypes.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationTypes.kt index 3fe85eff996..f30e12ae1ad 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationTypes.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationTypes.kt @@ -15,6 +15,7 @@ package org.wfanet.measurement.duchy.db.computation import org.wfanet.measurement.common.numberAsLong +import org.wfanet.measurement.internal.duchy.ComputationStage import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType /** Helper class for working with [ComputationType] protocols. */ @@ -28,3 +29,10 @@ object ComputationTypes : ComputationTypeEnumHelper { return ComputationType.forNumber(value.toInt()) ?: ComputationType.UNRECOGNIZED } } + +fun ComputationStage.toComputationType() = + when (stageCase) { + ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + ComputationType.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 + else -> ComputationType.UNRECOGNIZED + } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationsDatabase.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationsDatabase.kt index 0d15c890da7..92a7ee0630d 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationsDatabase.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationsDatabase.kt @@ -176,24 +176,6 @@ interface ComputationsDatabaseTransactor( - /** The identifier for the computation used locally. */ - val localId: Long, - /** The protocol used for the computation. */ - val protocol: ProtocolT, - /** The stage of the computation when the token was created. */ - val stage: StageT, - /** The number of the current attempt of this stage for this computation. */ - val attempt: Int, - /** - * The version number of the last known edit to the computation. The version is a monotonically - * increasing number used as a guardrail to protect against concurrent edits to the same - * computation. - */ - val editVersion: Long - ) } /** diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt index c44a54e7fd6..1c60675b339 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt @@ -21,11 +21,11 @@ import kotlin.experimental.ExperimentalTypeInference import org.wfanet.measurement.common.toJson import org.wfanet.measurement.duchy.db.computation.AfterTransition import org.wfanet.measurement.duchy.db.computation.BlobRef +import org.wfanet.measurement.duchy.db.computation.ComputationEditToken import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStages import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnumHelper import org.wfanet.measurement.duchy.db.computation.ComputationStatMetric import org.wfanet.measurement.duchy.db.computation.ComputationsDatabase -import org.wfanet.measurement.duchy.db.computation.ComputationsDatabaseTransactor.ComputationEditToken import org.wfanet.measurement.duchy.db.computation.EndComputationReason import org.wfanet.measurement.duchy.db.computation.toCompletedReason import org.wfanet.measurement.duchy.service.internal.computations.newEmptyOutputBlobMetadata diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/BUILD.bazel index b602dc25b22..65c4f3e15a7 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/BUILD.bazel @@ -5,7 +5,7 @@ kt_jvm_library( srcs = glob(["*Service.kt"]), visibility = [ "//src/main/kotlin/org/wfanet/measurement/duchy/common/deploy/postgres:__pkg__", - "//src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres:__pkg__", + "//src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres:__pkg__", ], deps = [ "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers", diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt index 5fa8a7f21a4..64610405a0e 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt @@ -14,7 +14,9 @@ package org.wfanet.measurement.duchy.deploy.common.postgres +import com.google.protobuf.Empty import io.grpc.Status +import io.grpc.StatusException import java.time.Clock import java.time.Duration import java.util.logging.Level @@ -25,21 +27,45 @@ import org.wfanet.measurement.common.grpc.grpcRequire import org.wfanet.measurement.common.identity.IdGenerator import org.wfanet.measurement.common.protoTimestamp import org.wfanet.measurement.common.toDuration +import org.wfanet.measurement.common.toInstant +import org.wfanet.measurement.duchy.db.computation.AfterTransition +import org.wfanet.measurement.duchy.db.computation.BlobRef import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStageDetailsHelper import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnumHelper import org.wfanet.measurement.duchy.db.computation.ComputationTypeEnumHelper +import org.wfanet.measurement.duchy.db.computation.EndComputationReason +import org.wfanet.measurement.duchy.db.computation.toDatabaseEditToken +import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationBlobReferenceReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader +import org.wfanet.measurement.duchy.deploy.common.postgres.readers.RequisitionReader +import org.wfanet.measurement.duchy.deploy.common.postgres.writers.AdvanceComputationStage import org.wfanet.measurement.duchy.deploy.common.postgres.writers.ClaimWork import org.wfanet.measurement.duchy.deploy.common.postgres.writers.CreateComputation +import org.wfanet.measurement.duchy.deploy.common.postgres.writers.DeleteComputation +import org.wfanet.measurement.duchy.deploy.common.postgres.writers.EnqueueComputation +import org.wfanet.measurement.duchy.deploy.common.postgres.writers.FinishComputation +import org.wfanet.measurement.duchy.deploy.common.postgres.writers.RecordOutputBlobPath +import org.wfanet.measurement.duchy.deploy.common.postgres.writers.RecordRequisitionBlobPath +import org.wfanet.measurement.duchy.deploy.common.postgres.writers.UpdateComputationDetails import org.wfanet.measurement.duchy.name import org.wfanet.measurement.duchy.number import org.wfanet.measurement.duchy.service.internal.ComputationAlreadyExistsException import org.wfanet.measurement.duchy.service.internal.ComputationDetailsNotFoundException import org.wfanet.measurement.duchy.service.internal.ComputationInitialStageInvalidException import org.wfanet.measurement.duchy.service.internal.ComputationNotFoundException +import org.wfanet.measurement.duchy.service.internal.computations.toAdvanceComputationStageResponse import org.wfanet.measurement.duchy.service.internal.computations.toClaimWorkResponse import org.wfanet.measurement.duchy.service.internal.computations.toCreateComputationResponse +import org.wfanet.measurement.duchy.service.internal.computations.toFinishComputationResponse import org.wfanet.measurement.duchy.service.internal.computations.toGetComputationTokenResponse +import org.wfanet.measurement.duchy.service.internal.computations.toRecordOutputBlobPathResponse +import org.wfanet.measurement.duchy.service.internal.computations.toRecordRequisitionBlobPathResponse +import org.wfanet.measurement.duchy.service.internal.computations.toUpdateComputationDetailsResponse +import org.wfanet.measurement.duchy.storage.ComputationStore +import org.wfanet.measurement.duchy.storage.RequisitionStore +import org.wfanet.measurement.duchy.toProtocolStage +import org.wfanet.measurement.internal.duchy.AdvanceComputationStageRequest +import org.wfanet.measurement.internal.duchy.AdvanceComputationStageResponse import org.wfanet.measurement.internal.duchy.ClaimWorkRequest import org.wfanet.measurement.internal.duchy.ClaimWorkResponse import org.wfanet.measurement.internal.duchy.ComputationDetails @@ -49,9 +75,26 @@ import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoroutineImplBase import org.wfanet.measurement.internal.duchy.CreateComputationRequest import org.wfanet.measurement.internal.duchy.CreateComputationResponse +import org.wfanet.measurement.internal.duchy.DeleteComputationRequest +import org.wfanet.measurement.internal.duchy.EnqueueComputationRequest +import org.wfanet.measurement.internal.duchy.EnqueueComputationResponse +import org.wfanet.measurement.internal.duchy.FinishComputationRequest +import org.wfanet.measurement.internal.duchy.FinishComputationResponse +import org.wfanet.measurement.internal.duchy.GetComputationIdsRequest +import org.wfanet.measurement.internal.duchy.GetComputationIdsResponse import org.wfanet.measurement.internal.duchy.GetComputationTokenRequest import org.wfanet.measurement.internal.duchy.GetComputationTokenRequest.KeyCase import org.wfanet.measurement.internal.duchy.GetComputationTokenResponse +import org.wfanet.measurement.internal.duchy.PurgeComputationsRequest +import org.wfanet.measurement.internal.duchy.PurgeComputationsResponse +import org.wfanet.measurement.internal.duchy.RecordOutputBlobPathRequest +import org.wfanet.measurement.internal.duchy.RecordOutputBlobPathResponse +import org.wfanet.measurement.internal.duchy.RecordRequisitionBlobPathRequest +import org.wfanet.measurement.internal.duchy.RecordRequisitionBlobPathResponse +import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsRequest +import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsResponse +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.purgeComputationsResponse import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey import org.wfanet.measurement.system.v1alpha.CreateComputationLogEntryRequest @@ -71,12 +114,16 @@ class PostgresComputationsService( private val client: DatabaseClient, private val idGenerator: IdGenerator, private val duchyName: String, + private val computationStorageClient: ComputationStore, + private val requisitionStorageClient: RequisitionStore, private val computationLogEntriesClient: ComputationLogEntriesCoroutineStub, private val clock: Clock = Clock.systemUTC(), private val defaultLockDuration: Duration = Duration.ofMinutes(5), ) : ComputationsCoroutineImplBase() { private val computationReader = ComputationReader(protocolStagesEnumHelper) + private val computationBlobReferenceReader = ComputationBlobReferenceReader() + private val requisitionReader = RequisitionReader() override suspend fun createComputation( request: CreateComputationRequest @@ -167,6 +214,255 @@ class PostgresComputationsService( return token.toGetComputationTokenResponse() } + override suspend fun deleteComputation(request: DeleteComputationRequest): Empty { + val computationBlobKeys = + computationBlobReferenceReader.readComputationBlobKeys( + client.singleUse(), + request.localComputationId + ) + for (blobKey in computationBlobKeys) { + try { + computationStorageClient.get(blobKey)?.delete() + } catch (e: StatusException) { + if (e.status.code != Status.Code.NOT_FOUND) { + throw e + } + } + } + + val requisitionBlobKeys = + requisitionReader.readRequisitionBlobKeys(client.singleUse(), request.localComputationId) + for (blobKey in requisitionBlobKeys) { + try { + requisitionStorageClient.get(blobKey)?.delete() + } catch (e: StatusException) { + if (e.status.code != Status.NOT_FOUND.code) { + throw e + } + } + } + + DeleteComputation(request.localComputationId).execute(client, idGenerator) + return Empty.getDefaultInstance() + } + + override suspend fun purgeComputations( + request: PurgeComputationsRequest + ): PurgeComputationsResponse { + var deleted = 0 + try { + val globalIds: Set = + computationReader.readGlobalComputationIds( + client.singleUse(), + request.stagesList, + request.updatedBefore.toInstant() + ) + if (!request.force) { + return purgeComputationsResponse { + purgeCount = globalIds.size + purgeSample += globalIds + } + } + for (globalId in globalIds) { + val token = computationReader.readComputationToken(client, globalId) ?: continue + val computationStageEnum = token.computationStage + val endComputationStage = getEndingComputationStage(computationStageEnum) + + if (!isTerminated(computationStageEnum)) { + FinishComputation( + token.toDatabaseEditToken(), + endingStage = endComputationStage, + endComputationReason = EndComputationReason.FAILED, + computationDetails = token.computationDetails, + clock = clock, + protocolStagesEnumHelper = protocolStagesEnumHelper, + protocolStageDetailsHelper = computationProtocolStageDetailsHelper, + ) + .execute(client, idGenerator) + sendStatusUpdateToKingdom( + newCreateComputationLogEntryRequest( + token.globalComputationId, + endComputationStage, + ) + ) + } + DeleteComputation(token.localComputationId).execute(client, idGenerator) + deleted += 1 + } + } catch (e: Exception) { + logger.log(Level.WARNING, "Exception during Computations cleaning. $e") + } + return purgeComputationsResponse { this.purgeCount = deleted } + } + + override suspend fun finishComputation( + request: FinishComputationRequest + ): FinishComputationResponse { + FinishComputation( + request.token.toDatabaseEditToken(), + endingStage = request.endingComputationStage, + endComputationReason = + when (val it = request.reason) { + ComputationDetails.CompletedReason.SUCCEEDED -> EndComputationReason.SUCCEEDED + ComputationDetails.CompletedReason.FAILED -> EndComputationReason.FAILED + ComputationDetails.CompletedReason.CANCELED -> EndComputationReason.CANCELED + else -> error("Unknown CompletedReason $it") + }, + computationDetails = request.token.computationDetails, + clock = clock, + protocolStagesEnumHelper = protocolStagesEnumHelper, + protocolStageDetailsHelper = computationProtocolStageDetailsHelper, + ) + .execute(client, idGenerator) + + sendStatusUpdateToKingdom( + newCreateComputationLogEntryRequest( + request.token.globalComputationId, + request.endingComputationStage + ) + ) + + val token = + computationReader.readComputationToken(client, request.token.globalComputationId) + ?: failGrpc(Status.INTERNAL) { + "Finished computation ${request.token.globalComputationId} not found." + } + + return token.toFinishComputationResponse() + } + + override suspend fun updateComputationDetails( + request: UpdateComputationDetailsRequest + ): UpdateComputationDetailsResponse { + require(request.token.computationDetails.protocolCase == request.details.protocolCase) { + "The protocol type cannot change." + } + UpdateComputationDetails( + clock = clock, + localId = request.token.localComputationId, + editVersion = request.token.version, + computationDetails = request.details, + requisitionEntries = request.requisitionsList + ) + .execute(client, idGenerator) + + val token = + computationReader.readComputationToken(client, request.token.globalComputationId) + ?: failGrpc(Status.INTERNAL) { + "Updated computation ${request.token.globalComputationId} not found." + } + return token.toUpdateComputationDetailsResponse() + } + + override suspend fun recordOutputBlobPath( + request: RecordOutputBlobPathRequest + ): RecordOutputBlobPathResponse { + + RecordOutputBlobPath( + clock = clock, + localId = request.token.localComputationId, + editVersion = request.token.version, + stage = request.token.computationStage, + blobRef = BlobRef(request.outputBlobId, request.blobPath), + protocolStagesEnumHelper = protocolStagesEnumHelper + ) + .execute(client, idGenerator) + + val token = + computationReader.readComputationToken(client, request.token.globalComputationId) + ?: failGrpc(Status.INTERNAL) { + "Computation ${request.token.globalComputationId} not found." + } + return token.toRecordOutputBlobPathResponse() + } + + override suspend fun advanceComputationStage( + request: AdvanceComputationStageRequest + ): AdvanceComputationStageResponse { + val lockExtension: Duration = + if (request.hasLockExtension()) request.lockExtension.toDuration() else defaultLockDuration + val afterTransition = + when (val it = request.afterTransition) { + AdvanceComputationStageRequest.AfterTransition.ADD_UNCLAIMED_TO_QUEUE -> + AfterTransition.ADD_UNCLAIMED_TO_QUEUE + AdvanceComputationStageRequest.AfterTransition.DO_NOT_ADD_TO_QUEUE -> + AfterTransition.DO_NOT_ADD_TO_QUEUE + AdvanceComputationStageRequest.AfterTransition.RETAIN_AND_EXTEND_LOCK -> + AfterTransition.CONTINUE_WORKING + else -> error("Unsupported AdvanceComputationStageRequest.AfterTransition '$it'. ") + } + + AdvanceComputationStage( + request.token.toDatabaseEditToken(), + nextStage = request.nextComputationStage, + nextStageDetails = request.stageDetails, + inputBlobPaths = request.inputBlobsList, + passThroughBlobPaths = request.passThroughBlobsList, + outputBlobs = request.outputBlobs, + afterTransition = afterTransition, + lockExtension = lockExtension, + clock = clock, + protocolStagesEnumHelper = protocolStagesEnumHelper + ) + .execute(client, idGenerator) + + sendStatusUpdateToKingdom( + newCreateComputationLogEntryRequest( + request.token.globalComputationId, + request.nextComputationStage + ) + ) + + val token = + computationReader.readComputationToken(client, request.token.globalComputationId) + ?: failGrpc(Status.INTERNAL) { + "Computation ${request.token.globalComputationId} not found." + } + return token.toAdvanceComputationStageResponse() + } + + override suspend fun getComputationIds( + request: GetComputationIdsRequest + ): GetComputationIdsResponse { + val ids = computationReader.readGlobalComputationIds(client.singleUse(), request.stagesList) + return GetComputationIdsResponse.newBuilder().addAllGlobalIds(ids).build() + } + + override suspend fun enqueueComputation( + request: EnqueueComputationRequest + ): EnqueueComputationResponse { + grpcRequire(request.delaySecond >= 0) { + "DelaySecond ${request.delaySecond} should be non-negative." + } + EnqueueComputation( + request.token.localComputationId, + request.token.version, + request.delaySecond.toLong(), + clock, + ) + .execute(client, idGenerator) + return EnqueueComputationResponse.getDefaultInstance() + } + + override suspend fun recordRequisitionBlobPath( + request: RecordRequisitionBlobPathRequest + ): RecordRequisitionBlobPathResponse { + RecordRequisitionBlobPath( + clock = clock, + localId = request.token.localComputationId, + externalRequisitionKey = request.key, + pathToBlob = request.blobPath + ) + .execute(client, idGenerator) + + val token = + computationReader.readComputationToken(client, request.token.globalComputationId) + ?: failGrpc(Status.INTERNAL) { + "Computation ${request.token.globalComputationId} not found." + } + return token.toRecordRequisitionBlobPathResponse() + } + private fun newCreateComputationLogEntryRequest( globalId: String, computationStage: ComputationStage, @@ -196,6 +492,25 @@ class PostgresComputationsService( } } + private fun isTerminated(computationStage: ComputationStage): Boolean { + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + return when (computationStage.stageCase) { + ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + computationStage.liquidLegionsSketchAggregationV2 == + LiquidLegionsSketchAggregationV2.Stage.COMPLETE + ComputationStage.StageCase.STAGE_NOT_SET -> false + } + } + + private fun getEndingComputationStage(computationStage: ComputationStage): ComputationStage { + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + return when (computationStage.stageCase) { + ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> + LiquidLegionsSketchAggregationV2.Stage.COMPLETE.toProtocolStage() + ComputationStage.StageCase.STAGE_NOT_SET -> error("protocol not set") + } + } + companion object { private val logger: Logger = Logger.getLogger(this::class.java.name) } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt index 4f4ffd2ad17..61141d0fd17 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt @@ -107,7 +107,10 @@ class ComputationReader( } } - suspend fun readComputation(readContext: ReadContext, globalComputationId: String): Computation? { + private suspend fun readComputation( + readContext: ReadContext, + globalComputationId: String + ): Computation? { val statement = boundStatement( """ @@ -271,7 +274,7 @@ class ComputationReader( if (updatedBefore == null) { boundStatement(baseSql) { bind("$1", computationTypes[0]) } } else { - boundStatement(baseSql + " AND UpdatedTime <= $2") { + boundStatement("$baseSql AND UpdateTime <= $2") { bind("$1", computationTypes[0]) bind("$2", updatedBefore) } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationStageAttemptReader.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationStageAttemptReader.kt index 2ed6081ea73..275d6df7330 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationStageAttemptReader.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationStageAttemptReader.kt @@ -89,7 +89,7 @@ class ComputationStageAttemptReader { """ SELECT ComputationId, ComputationStage, Attempt, Details FROM ComputationStageAttempts - WHERE s.ComputationId = $1 + WHERE ComputationId = $1 AND EndTime IS NULL """ ) { diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/testing/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/testing/BUILD.bazel index 6ba7dad719f..1390b937044 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/testing/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/testing/BUILD.bazel @@ -3,7 +3,7 @@ load("@io_bazel_rules_kotlin//kotlin:jvm.bzl", "kt_jvm_library") package( default_testonly = True, default_visibility = [ - "//src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres:__subpackages__", + "//src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres:__subpackages__", ], ) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt index a54a4fb3c6e..5c7c3874fc7 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt @@ -20,8 +20,8 @@ import java.time.Duration import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter import org.wfanet.measurement.common.numberAsLong import org.wfanet.measurement.duchy.db.computation.AfterTransition +import org.wfanet.measurement.duchy.db.computation.ComputationEditToken import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnumHelper -import org.wfanet.measurement.duchy.db.computation.ComputationsDatabaseTransactor.ComputationEditToken import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationBlobReferenceReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationStageAttemptReader import org.wfanet.measurement.internal.duchy.ComputationBlobDependency diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/FinishComputation.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/FinishComputation.kt index 6fc013fb283..67ab2d2c5a2 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/FinishComputation.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/FinishComputation.kt @@ -18,9 +18,9 @@ import com.google.protobuf.Message import java.time.Clock import java.util.logging.Logger import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter +import org.wfanet.measurement.duchy.db.computation.ComputationEditToken import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStageDetailsHelper import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnumHelper -import org.wfanet.measurement.duchy.db.computation.ComputationsDatabaseTransactor.ComputationEditToken import org.wfanet.measurement.duchy.db.computation.EndComputationReason import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationStageAttemptReader import org.wfanet.measurement.internal.duchy.ComputationStageAttemptDetails diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactor.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactor.kt index c16aa42564a..593d4a2abc2 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactor.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactor.kt @@ -28,9 +28,9 @@ import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.flow.toList import org.wfanet.measurement.duchy.db.computation.AfterTransition import org.wfanet.measurement.duchy.db.computation.BlobRef +import org.wfanet.measurement.duchy.db.computation.ComputationEditToken import org.wfanet.measurement.duchy.db.computation.ComputationStatMetric import org.wfanet.measurement.duchy.db.computation.ComputationsDatabaseTransactor -import org.wfanet.measurement.duchy.db.computation.ComputationsDatabaseTransactor.ComputationEditToken import org.wfanet.measurement.duchy.db.computation.EndComputationReason import org.wfanet.measurement.gcloud.common.gcloudTimestamp import org.wfanet.measurement.gcloud.common.toGcloudByteArray diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt index 1807c4d3f1a..05f4351c05e 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt @@ -21,7 +21,6 @@ import java.time.Clock import java.time.Duration import java.util.logging.Level import java.util.logging.Logger -import org.wfanet.measurement.common.grpc.failGrpc import org.wfanet.measurement.common.grpc.grpcRequire import org.wfanet.measurement.common.protoTimestamp import org.wfanet.measurement.common.toDuration @@ -29,8 +28,8 @@ import org.wfanet.measurement.common.toInstant import org.wfanet.measurement.duchy.db.computation.AfterTransition import org.wfanet.measurement.duchy.db.computation.BlobRef import org.wfanet.measurement.duchy.db.computation.ComputationsDatabase -import org.wfanet.measurement.duchy.db.computation.ComputationsDatabaseTransactor.ComputationEditToken import org.wfanet.measurement.duchy.db.computation.EndComputationReason +import org.wfanet.measurement.duchy.db.computation.toDatabaseEditToken import org.wfanet.measurement.duchy.name import org.wfanet.measurement.duchy.number import org.wfanet.measurement.duchy.storage.ComputationStore @@ -43,7 +42,6 @@ import org.wfanet.measurement.internal.duchy.ClaimWorkResponse import org.wfanet.measurement.internal.duchy.ComputationDetails import org.wfanet.measurement.internal.duchy.ComputationStage import org.wfanet.measurement.internal.duchy.ComputationToken -import org.wfanet.measurement.internal.duchy.ComputationTypeEnum.ComputationType import org.wfanet.measurement.internal.duchy.ComputationsGrpcKt.ComputationsCoroutineImplBase import org.wfanet.measurement.internal.duchy.CreateComputationRequest import org.wfanet.measurement.internal.duchy.CreateComputationResponse @@ -98,7 +96,9 @@ class ComputationsService( ) ) token.toClaimWorkResponse() - } else ClaimWorkResponse.getDefaultInstance() + } else { + ClaimWorkResponse.getDefaultInstance() + } } override suspend fun createComputation( @@ -391,20 +391,3 @@ class ComputationsService( private val logger: Logger = Logger.getLogger(this::class.java.name) } } - -private fun ComputationToken.toDatabaseEditToken(): - ComputationEditToken = - ComputationEditToken( - localId = localComputationId, - protocol = computationStage.toComputationType(), - stage = computationStage, - attempt = attempt, - editVersion = version - ) - -private fun ComputationStage.toComputationType() = - when (stageCase) { - ComputationStage.StageCase.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 -> - ComputationType.LIQUID_LEGIONS_SKETCH_AGGREGATION_V2 - else -> failGrpc { "Computation type for $this is unknown" } - } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt index d314ce9046d..584ef2a547c 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt @@ -522,6 +522,83 @@ abstract class ComputationsServiceTest { assertThat(exception.message).contains("editVersion mismatch") } + @Test + fun `advanceComputationStage enqueues the computation when afterTransition is ADD_UNCLAIMED_TO_QUEUE`() = + runBlocking { + service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST) + val claimWorkResponse = service.claimWork(DEFAULT_CLAIM_WORK_REQUEST) + + val nextStage = computationStage { + liquidLegionsSketchAggregationV2 = Stage.WAIT_REQUISITIONS_AND_KEY_SET + } + val advanceComputationStageResp = + service.advanceComputationStage( + advanceComputationStageRequest { + token = claimWorkResponse.token + nextComputationStage = nextStage + afterTransition = AfterTransition.ADD_UNCLAIMED_TO_QUEUE + } + ) + + assertThat(advanceComputationStageResp.token.attempt).isEqualTo(0) + assertThat(advanceComputationStageResp.token.lockOwner).isEmpty() + assertThat(advanceComputationStageResp.token.lockExpirationTime) + .isEqualTo(clock.last().toProtoTime()) + } + + @Test + fun `advanceComputationStage releases the computation lock when afterTransition is DO_NOT_ADD_TO_QUEUE`() = + runBlocking { + service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST) + val claimWorkResponse = service.claimWork(DEFAULT_CLAIM_WORK_REQUEST) + + val nextStage = computationStage { + liquidLegionsSketchAggregationV2 = Stage.WAIT_REQUISITIONS_AND_KEY_SET + } + val advanceComputationStageResp = + service.advanceComputationStage( + advanceComputationStageRequest { + token = claimWorkResponse.token + nextComputationStage = nextStage + afterTransition = AfterTransition.DO_NOT_ADD_TO_QUEUE + } + ) + + assertThat(advanceComputationStageResp.token.lockOwner).isEmpty() + assertThat(advanceComputationStageResp.token.lockExpirationTime).isEqualToDefaultInstance() + } + + @Test + fun `advanceComputationStage throws when output blobs are not fulfilled`() = runBlocking { + service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST) + val claimWorkResponse = service.claimWork(DEFAULT_CLAIM_WORK_REQUEST) + val advanceResp = + service.advanceComputationStage( + advanceComputationStageRequest { + token = claimWorkResponse.token + nextComputationStage = computationStage { + liquidLegionsSketchAggregationV2 = Stage.WAIT_REQUISITIONS_AND_KEY_SET + } + afterTransition = AfterTransition.RETAIN_AND_EXTEND_LOCK + outputBlobs = 2 + } + ) + + val nextStage = computationStage { liquidLegionsSketchAggregationV2 = Stage.CONFIRMATION_PHASE } + val exception = + assertFailsWith { + service.advanceComputationStage( + advanceComputationStageRequest { + token = advanceResp.token + nextComputationStage = nextStage + afterTransition = AfterTransition.DO_NOT_ADD_TO_QUEUE + } + ) + } + + assertThat(exception.message).contains("written") + } + @Test fun `finishComputation returns computation in terminal stage`() = runBlocking { val createComputationResp = service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST) @@ -892,6 +969,6 @@ abstract class ComputationsServiceTest { service.recordRequisitionBlobPath(recordRequisitionBlobPathRequest) } - assertThat(exception.message).contains("No Computation found") + assertThat(exception.message).contains("found") } } diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/BUILD.bazel similarity index 62% rename from src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/BUILD.bazel rename to src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/BUILD.bazel index 9269ed8f3d0..20222ec0d1d 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/BUILD.bazel @@ -3,7 +3,7 @@ load("@io_bazel_rules_kotlin//kotlin:jvm.bzl", "kt_jvm_test") kt_jvm_test( name = "DuchySchemaTest", srcs = ["DuchySchemaTest.kt"], - test_class = "org.wfanet.measurement.duchy.deploy.postgres.DuchySchemaTest", + test_class = "org.wfanet.measurement.duchy.deploy.common.postgres.DuchySchemaTest", deps = [ "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/testing", "@wfa_common_jvm//imports/java/com/google/common/truth", @@ -16,7 +16,7 @@ kt_jvm_test( kt_jvm_test( name = "PostgresContinuationTokensServiceTest", srcs = ["PostgresContinuationTokensServiceTest.kt"], - test_class = "org.wfanet.measurement.duchy.deploy.postgres.PostgresContinuationTokensServiceTest", + test_class = "org.wfanet.measurement.duchy.deploy.common.postgres.PostgresContinuationTokensServiceTest", deps = [ "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres:services", "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/testing", @@ -31,7 +31,25 @@ kt_jvm_test( kt_jvm_test( name = "PostgresComputationStatsServiceTest", srcs = ["PostgresComputationStatsServiceTest.kt"], - test_class = "org.wfanet.measurement.duchy.deploy.postgres.PostgresComputationStatsServiceTest", + test_class = "org.wfanet.measurement.duchy.deploy.common.postgres.PostgresComputationStatsServiceTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing", + "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres:services", + "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/testing", + "//src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations", + "//src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing", + "@wfa_common_jvm//imports/java/com/google/common/truth", + "@wfa_common_jvm//imports/java/com/opentable/db/postgres:pg_embedded", + "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:embedded_postgres", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/storage/filesystem:client", + ], +) + +kt_jvm_test( + name = "PostgresComputationsServiceTest", + srcs = ["PostgresComputationsServiceTest.kt"], + test_class = "org.wfanet.measurement.duchy.deploy.common.postgres.PostgresComputationsServiceTest", deps = [ "//src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing", "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres:services", diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/DuchySchemaTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/DuchySchemaTest.kt similarity index 97% rename from src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/DuchySchemaTest.kt rename to src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/DuchySchemaTest.kt index 4450bc3a54f..526c6365b86 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/DuchySchemaTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/DuchySchemaTest.kt @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package org.wfanet.measurement.duchy.deploy.postgres +package org.wfanet.measurement.duchy.deploy.common.postgres import com.google.common.truth.Truth.assertThat import kotlinx.coroutines.flow.toList diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/PostgresComputationStatsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationStatsServiceTest.kt similarity index 90% rename from src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/PostgresComputationStatsServiceTest.kt rename to src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationStatsServiceTest.kt index 0b4ac95986e..a2d3e280064 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/PostgresComputationStatsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationStatsServiceTest.kt @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -package org.wfanet.measurement.duchy.deploy.postgres +package org.wfanet.measurement.duchy.deploy.common.postgres import java.time.Clock import kotlin.random.Random +import org.junit.Rule import org.junit.rules.TemporaryFolder import org.junit.runner.RunWith import org.junit.runners.JUnit4 @@ -23,11 +24,10 @@ import org.wfanet.measurement.common.db.r2dbc.postgres.testing.EmbeddedPostgresD import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule import org.wfanet.measurement.common.grpc.testing.mockService import org.wfanet.measurement.common.identity.RandomIdGenerator +import org.wfanet.measurement.common.testing.chainRulesSequentially import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStageDetails import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStages import org.wfanet.measurement.duchy.db.computation.ComputationTypes -import org.wfanet.measurement.duchy.deploy.common.postgres.PostgresComputationStatsService -import org.wfanet.measurement.duchy.deploy.common.postgres.PostgresComputationsService import org.wfanet.measurement.duchy.deploy.common.postgres.testing.Schemata.DUCHY_CHANGELOG_PATH import org.wfanet.measurement.duchy.service.internal.testing.ComputationStatsServiceTest import org.wfanet.measurement.duchy.storage.ComputationStore @@ -59,6 +59,7 @@ class PostgresComputationStatsServiceTest : requisitionStore = RequisitionStore(storageClient) addService(mockComputationLogEntriesService) } + @get:Rule val ruleChain = chainRulesSequentially(tempDirectory, grpcTestServerRule) private val systemComputationLogEntriesClient = ComputationLogEntriesCoroutineStub(grpcTestServerRule.channel) @@ -74,7 +75,9 @@ class PostgresComputationStatsServiceTest : client = client, idGenerator = idGenerator, duchyName = ALSACE, - computationLogEntriesClient = systemComputationLogEntriesClient + computationStorageClient = ComputationStore(storageClient), + requisitionStorageClient = RequisitionStore(storageClient), + computationLogEntriesClient = systemComputationLogEntriesClient, ) } } diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsServiceTest.kt new file mode 100644 index 00000000000..86f27a1f822 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsServiceTest.kt @@ -0,0 +1,78 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.deploy.common.postgres + +import java.time.Clock +import kotlin.random.Random +import org.junit.Rule +import org.junit.rules.TemporaryFolder +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.wfanet.measurement.common.db.r2dbc.postgres.testing.EmbeddedPostgresDatabaseProvider +import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule +import org.wfanet.measurement.common.grpc.testing.mockService +import org.wfanet.measurement.common.identity.RandomIdGenerator +import org.wfanet.measurement.common.testing.chainRulesSequentially +import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStageDetails +import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStages +import org.wfanet.measurement.duchy.db.computation.ComputationTypes +import org.wfanet.measurement.duchy.deploy.common.postgres.testing.Schemata.DUCHY_CHANGELOG_PATH +import org.wfanet.measurement.duchy.service.internal.testing.ComputationsServiceTest +import org.wfanet.measurement.duchy.storage.ComputationStore +import org.wfanet.measurement.duchy.storage.RequisitionStore +import org.wfanet.measurement.storage.filesystem.FileSystemStorageClient +import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineImplBase +import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub + +private const val ALSACE = "Alsace" + +@RunWith(JUnit4::class) +class PostgresComputationsServiceTest : ComputationsServiceTest() { + + private lateinit var storageClient: FileSystemStorageClient + private lateinit var computationStore: ComputationStore + private lateinit var requisitionStore: RequisitionStore + private val tempDirectory = TemporaryFolder() + private val mockComputationLogEntriesService: ComputationLogEntriesCoroutineImplBase = + mockService() + + private val client = EmbeddedPostgresDatabaseProvider(DUCHY_CHANGELOG_PATH).createNewDatabase() + private val idGenerator = RandomIdGenerator(Clock.systemUTC(), Random(1)) + + private val grpcTestServerRule = GrpcTestServerRule { + storageClient = FileSystemStorageClient(tempDirectory.root) + computationStore = ComputationStore(storageClient) + requisitionStore = RequisitionStore(storageClient) + addService(mockComputationLogEntriesService) + } + @get:Rule val ruleChain = chainRulesSequentially(tempDirectory, grpcTestServerRule) + private val systemComputationLogEntriesClient = + ComputationLogEntriesCoroutineStub(grpcTestServerRule.channel) + + override fun newService(clock: Clock): PostgresComputationsService { + return PostgresComputationsService( + computationTypeEnumHelper = ComputationTypes, + protocolStagesEnumHelper = ComputationProtocolStages, + computationProtocolStageDetailsHelper = ComputationProtocolStageDetails, + client = client, + idGenerator = idGenerator, + duchyName = ALSACE, + computationStorageClient = computationStore, + requisitionStorageClient = requisitionStore, + computationLogEntriesClient = systemComputationLogEntriesClient, + clock = clock + ) + } +} diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/PostgresContinuationTokensServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresContinuationTokensServiceTest.kt similarity index 91% rename from src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/PostgresContinuationTokensServiceTest.kt rename to src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresContinuationTokensServiceTest.kt index 370c4db61ac..febbbeb025e 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/PostgresContinuationTokensServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresContinuationTokensServiceTest.kt @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package org.wfanet.measurement.duchy.deploy.postgres +package org.wfanet.measurement.duchy.deploy.common.postgres import java.time.Clock import kotlin.random.Random @@ -20,7 +20,6 @@ import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.wfanet.measurement.common.db.r2dbc.postgres.testing.EmbeddedPostgresDatabaseProvider import org.wfanet.measurement.common.identity.RandomIdGenerator -import org.wfanet.measurement.duchy.deploy.common.postgres.PostgresContinuationTokensService import org.wfanet.measurement.duchy.deploy.common.postgres.testing.Schemata.DUCHY_CHANGELOG_PATH import org.wfanet.measurement.duchy.service.internal.testing.ContinuationTokensServiceTest diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactorTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactorTest.kt index b676cde6f9f..92ac6159cb3 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactorTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactorTest.kt @@ -32,12 +32,12 @@ import org.junit.runners.JUnit4 import org.wfanet.measurement.common.testing.TestClockWithNamedInstants import org.wfanet.measurement.duchy.db.computation.AfterTransition import org.wfanet.measurement.duchy.db.computation.BlobRef +import org.wfanet.measurement.duchy.db.computation.ComputationEditToken import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStageDetailsHelper import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnumHelper import org.wfanet.measurement.duchy.db.computation.ComputationStageLongValues import org.wfanet.measurement.duchy.db.computation.ComputationStatMetric import org.wfanet.measurement.duchy.db.computation.ComputationTypeEnumHelper -import org.wfanet.measurement.duchy.db.computation.ComputationsDatabaseTransactor.ComputationEditToken import org.wfanet.measurement.duchy.db.computation.EndComputationReason import org.wfanet.measurement.duchy.deploy.gcloud.spanner.computation.FakeProtocolStages.A import org.wfanet.measurement.duchy.deploy.gcloud.spanner.computation.FakeProtocolStages.B From 33e19735d28a450e3c023616a6096502dc59defd Mon Sep 17 00:00:00 2001 From: Yuhong Wang Date: Wed, 26 Jul 2023 17:32:49 +0000 Subject: [PATCH 02/12] license --- .../duchy/db/computation/ComputationEditToken.kt | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationEditToken.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationEditToken.kt index 1209ce783c7..24b2a4dd6e2 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationEditToken.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationEditToken.kt @@ -1,3 +1,17 @@ +// Copyright 2020 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package org.wfanet.measurement.duchy.db.computation import org.wfanet.measurement.common.grpc.failGrpc From 313ffcdd9d1e4b28a41796d7e377781cda99cf3e Mon Sep 17 00:00:00 2001 From: Yuhong Wang Date: Fri, 28 Jul 2023 14:19:27 +0000 Subject: [PATCH 03/12] feedbacks --- .../postgres/PostgresComputationsService.kt | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt index 64610405a0e..8b3f0f31e06 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt @@ -93,6 +93,7 @@ import org.wfanet.measurement.internal.duchy.RecordRequisitionBlobPathRequest import org.wfanet.measurement.internal.duchy.RecordRequisitionBlobPathResponse import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsRequest import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsResponse +import org.wfanet.measurement.internal.duchy.getComputationIdsResponse import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 import org.wfanet.measurement.internal.duchy.purgeComputationsResponse import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub @@ -302,11 +303,11 @@ class PostgresComputationsService( request.token.toDatabaseEditToken(), endingStage = request.endingComputationStage, endComputationReason = - when (val it = request.reason) { + when (request.reason) { ComputationDetails.CompletedReason.SUCCEEDED -> EndComputationReason.SUCCEEDED ComputationDetails.CompletedReason.FAILED -> EndComputationReason.FAILED ComputationDetails.CompletedReason.CANCELED -> EndComputationReason.CANCELED - else -> error("Unknown CompletedReason $it") + else -> error("Unknown CompletedReason ${request.reason}") }, computationDetails = request.token.computationDetails, clock = clock, @@ -357,7 +358,6 @@ class PostgresComputationsService( override suspend fun recordOutputBlobPath( request: RecordOutputBlobPathRequest ): RecordOutputBlobPathResponse { - RecordOutputBlobPath( clock = clock, localId = request.token.localComputationId, @@ -382,14 +382,17 @@ class PostgresComputationsService( val lockExtension: Duration = if (request.hasLockExtension()) request.lockExtension.toDuration() else defaultLockDuration val afterTransition = - when (val it = request.afterTransition) { + when (request.afterTransition) { AdvanceComputationStageRequest.AfterTransition.ADD_UNCLAIMED_TO_QUEUE -> AfterTransition.ADD_UNCLAIMED_TO_QUEUE AdvanceComputationStageRequest.AfterTransition.DO_NOT_ADD_TO_QUEUE -> AfterTransition.DO_NOT_ADD_TO_QUEUE AdvanceComputationStageRequest.AfterTransition.RETAIN_AND_EXTEND_LOCK -> AfterTransition.CONTINUE_WORKING - else -> error("Unsupported AdvanceComputationStageRequest.AfterTransition '$it'. ") + else -> + error( + "Unsupported AdvanceComputationStageRequest.AfterTransition '${request.afterTransition}'. " + ) } AdvanceComputationStage( @@ -425,7 +428,7 @@ class PostgresComputationsService( request: GetComputationIdsRequest ): GetComputationIdsResponse { val ids = computationReader.readGlobalComputationIds(client.singleUse(), request.stagesList) - return GetComputationIdsResponse.newBuilder().addAllGlobalIds(ids).build() + return getComputationIdsResponse { globalIds += ids } } override suspend fun enqueueComputation( From 02fee232d307642c03f0e620c024ede2c3d69df8 Mon Sep 17 00:00:00 2001 From: Yuhong Wang Date: Mon, 31 Jul 2023 19:03:00 +0000 Subject: [PATCH 04/12] feedbacks --- .../db/computation/ComputationEditToken.kt | 7 +- .../testing/FakeComputationsDatabase.kt | 3 +- .../postgres/PostgresComputationsService.kt | 243 +++++++++--------- .../postgres/readers/ComputationReader.kt | 153 +++++++++-- .../writers/AdvanceComputationStage.kt | 11 +- .../common/postgres/writers/ClaimWork.kt | 106 ++------ .../postgres/writers/CreateComputation.kt | 13 +- .../postgres/writers/FinishComputation.kt | 11 +- .../postgres/writers/RecordOutputBlobPath.kt | 28 +- .../writers/RecordRequisitionBlobPath.kt | 11 +- .../writers/UpdateComputationDetails.kt | 21 +- .../internal/DuchyInternalException.kt | 7 + .../computations/ComputationsService.kt | 6 +- .../testing/ComputationsServiceTest.kt | 159 ++++++++---- .../internal/duchy/error_code.proto | 3 + .../PostgresComputationStatsServiceTest.kt | 4 +- .../PostgresComputationsServiceTest.kt | 4 +- 17 files changed, 479 insertions(+), 311 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationEditToken.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationEditToken.kt index 24b2a4dd6e2..f87e89a2c69 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationEditToken.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/ComputationEditToken.kt @@ -34,7 +34,9 @@ data class ComputationEditToken( * increasing number used as a guardrail to protect against concurrent edits to the same * computation. */ - val editVersion: Long + val editVersion: Long, + /** The global identifier for the computation. */ + val globalId: String, ) fun ComputationToken.toDatabaseEditToken(): @@ -48,6 +50,7 @@ fun ComputationToken.toDatabaseEditToken(): protocol = protocol, stage = computationStage, attempt = attempt, - editVersion = version + editVersion = version, + globalId = globalComputationId ) } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt index 1c60675b339..6a835f839ac 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/db/computation/testing/FakeComputationsDatabase.kt @@ -355,7 +355,8 @@ private constructor( }, stage = it.computationStage, attempt = it.attempt, - editVersion = it.version + editVersion = it.version, + globalId = it.globalComputationId, ) } .firstOrNull() diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt index 8b3f0f31e06..aaeacbae5d7 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt @@ -53,6 +53,7 @@ import org.wfanet.measurement.duchy.service.internal.ComputationAlreadyExistsExc import org.wfanet.measurement.duchy.service.internal.ComputationDetailsNotFoundException import org.wfanet.measurement.duchy.service.internal.ComputationInitialStageInvalidException import org.wfanet.measurement.duchy.service.internal.ComputationNotFoundException +import org.wfanet.measurement.duchy.service.internal.UnknownDataError import org.wfanet.measurement.duchy.service.internal.computations.toAdvanceComputationStageResponse import org.wfanet.measurement.duchy.service.internal.computations.toClaimWorkResponse import org.wfanet.measurement.duchy.service.internal.computations.toCreateComputationResponse @@ -95,6 +96,7 @@ import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsRequest import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsResponse import org.wfanet.measurement.internal.duchy.getComputationIdsResponse import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 +import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage import org.wfanet.measurement.internal.duchy.purgeComputationsResponse import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey @@ -115,8 +117,8 @@ class PostgresComputationsService( private val client: DatabaseClient, private val idGenerator: IdGenerator, private val duchyName: String, - private val computationStorageClient: ComputationStore, - private val requisitionStorageClient: RequisitionStore, + private val computationStore: ComputationStore, + private val requisitionStore: RequisitionStore, private val computationLogEntriesClient: ComputationLogEntriesCoroutineStub, private val clock: Clock = Clock.systemUTC(), private val defaultLockDuration: Duration = Duration.ofMinutes(5), @@ -133,29 +135,30 @@ class PostgresComputationsService( "global_computation_id is not specified." } - try { - CreateComputation( - request.globalComputationId, - request.computationType, - protocolStagesEnumHelper.getValidInitialStage(request.computationType).first(), - request.stageDetails, - request.computationDetails, - request.requisitionsList, - clock, - computationTypeEnumHelper, - protocolStagesEnumHelper, - computationProtocolStageDetailsHelper - ) - .execute(client, idGenerator) - } catch (ex: ComputationInitialStageInvalidException) { - throw ex.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) - } catch (ex: ComputationAlreadyExistsException) { - throw ex.asStatusRuntimeException(Status.Code.ALREADY_EXISTS) - } - val token = - computationReader.readComputationToken(client, request.globalComputationId) - ?: failGrpc(Status.INTERNAL) { "Created computation not found." } + try { + CreateComputation( + request.globalComputationId, + request.computationType, + protocolStagesEnumHelper.getValidInitialStage(request.computationType).first(), + request.stageDetails, + request.computationDetails, + request.requisitionsList, + clock, + computationTypeEnumHelper, + protocolStagesEnumHelper, + computationProtocolStageDetailsHelper, + computationReader + ) + .execute(client, idGenerator) + } catch (ex: ComputationInitialStageInvalidException) { + throw ex.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } catch (ex: ComputationAlreadyExistsException) { + throw ex.asStatusRuntimeException(Status.Code.ALREADY_EXISTS) + } catch (ex: UnknownDataError) { + throw ex.asStatusRuntimeException(Status.Code.INTERNAL) + } + return token.toCreateComputationResponse() } @@ -164,7 +167,7 @@ class PostgresComputationsService( val lockDuration = if (request.hasLockDuration()) request.lockDuration.toDuration() else defaultLockDuration - val claimed = + val claimedToken = try { ClaimWork( request.computationType, @@ -173,6 +176,7 @@ class PostgresComputationsService( clock, computationTypeEnumHelper, protocolStagesEnumHelper, + computationReader, ) .execute(client, idGenerator) } catch (e: ComputationNotFoundException) { @@ -181,21 +185,18 @@ class PostgresComputationsService( throw e.asStatusRuntimeException(Status.Code.INTERNAL) } - return claimed?.let { - val token = - ComputationReader(protocolStagesEnumHelper).readComputationToken(client, it) - ?: failGrpc(Status.INTERNAL) { "Claimed computation $claimed not found." } - + if (claimedToken != null) { sendStatusUpdateToKingdom( newCreateComputationLogEntryRequest( - token.globalComputationId, - token.computationStage, - token.attempt.toLong() + claimedToken.globalComputationId, + claimedToken.computationStage, + claimedToken.attempt.toLong() ) ) - token.toClaimWorkResponse() + return claimedToken.toClaimWorkResponse() } - ?: ClaimWorkResponse.getDefaultInstance() + + return ClaimWorkResponse.getDefaultInstance() } override suspend fun getComputationToken( @@ -223,7 +224,7 @@ class PostgresComputationsService( ) for (blobKey in computationBlobKeys) { try { - computationStorageClient.get(blobKey)?.delete() + computationStore.get(blobKey)?.delete() } catch (e: StatusException) { if (e.status.code != Status.Code.NOT_FOUND) { throw e @@ -235,7 +236,7 @@ class PostgresComputationsService( requisitionReader.readRequisitionBlobKeys(client.singleUse(), request.localComputationId) for (blobKey in requisitionBlobKeys) { try { - requisitionStorageClient.get(blobKey)?.delete() + requisitionStore.get(blobKey)?.delete() } catch (e: StatusException) { if (e.status.code != Status.NOT_FOUND.code) { throw e @@ -250,12 +251,19 @@ class PostgresComputationsService( override suspend fun purgeComputations( request: PurgeComputationsRequest ): PurgeComputationsResponse { + val terminalStages = + request.stagesList.filter { + protocolStagesEnumHelper.validTerminalStage( + protocolStagesEnumHelper.stageToProtocol(it), + it + ) + } var deleted = 0 try { val globalIds: Set = computationReader.readGlobalComputationIds( client.singleUse(), - request.stagesList, + terminalStages, request.updatedBefore.toInstant() ) if (!request.force) { @@ -278,6 +286,7 @@ class PostgresComputationsService( clock = clock, protocolStagesEnumHelper = protocolStagesEnumHelper, protocolStageDetailsHelper = computationProtocolStageDetailsHelper, + computationReader = computationReader, ) .execute(client, idGenerator) sendStatusUpdateToKingdom( @@ -299,22 +308,28 @@ class PostgresComputationsService( override suspend fun finishComputation( request: FinishComputationRequest ): FinishComputationResponse { - FinishComputation( - request.token.toDatabaseEditToken(), - endingStage = request.endingComputationStage, - endComputationReason = - when (request.reason) { - ComputationDetails.CompletedReason.SUCCEEDED -> EndComputationReason.SUCCEEDED - ComputationDetails.CompletedReason.FAILED -> EndComputationReason.FAILED - ComputationDetails.CompletedReason.CANCELED -> EndComputationReason.CANCELED - else -> error("Unknown CompletedReason ${request.reason}") - }, - computationDetails = request.token.computationDetails, - clock = clock, - protocolStagesEnumHelper = protocolStagesEnumHelper, - protocolStageDetailsHelper = computationProtocolStageDetailsHelper, - ) - .execute(client, idGenerator) + val token = + try { + FinishComputation( + request.token.toDatabaseEditToken(), + endingStage = request.endingComputationStage, + endComputationReason = + when (request.reason) { + ComputationDetails.CompletedReason.SUCCEEDED -> EndComputationReason.SUCCEEDED + ComputationDetails.CompletedReason.FAILED -> EndComputationReason.FAILED + ComputationDetails.CompletedReason.CANCELED -> EndComputationReason.CANCELED + else -> error("Unknown CompletedReason ${request.reason}") + }, + computationDetails = request.token.computationDetails, + clock = clock, + protocolStagesEnumHelper = protocolStagesEnumHelper, + protocolStageDetailsHelper = computationProtocolStageDetailsHelper, + computationReader = computationReader, + ) + .execute(client, idGenerator) + } catch (ex: UnknownDataError) { + throw ex.asStatusRuntimeException(Status.Code.INTERNAL) + } sendStatusUpdateToKingdom( newCreateComputationLogEntryRequest( @@ -323,12 +338,6 @@ class PostgresComputationsService( ) ) - val token = - computationReader.readComputationToken(client, request.token.globalComputationId) - ?: failGrpc(Status.INTERNAL) { - "Finished computation ${request.token.globalComputationId} not found." - } - return token.toFinishComputationResponse() } @@ -338,41 +347,41 @@ class PostgresComputationsService( require(request.token.computationDetails.protocolCase == request.details.protocolCase) { "The protocol type cannot change." } - UpdateComputationDetails( - clock = clock, - localId = request.token.localComputationId, - editVersion = request.token.version, - computationDetails = request.details, - requisitionEntries = request.requisitionsList - ) - .execute(client, idGenerator) val token = - computationReader.readComputationToken(client, request.token.globalComputationId) - ?: failGrpc(Status.INTERNAL) { - "Updated computation ${request.token.globalComputationId} not found." - } + try { + UpdateComputationDetails( + token = request.token.toDatabaseEditToken(), + clock = clock, + computationDetails = request.details, + requisitionEntries = request.requisitionsList, + computationReader = computationReader + ) + .execute(client, idGenerator) + } catch (ex: UnknownDataError) { + throw ex.asStatusRuntimeException(Status.Code.INTERNAL) + } + return token.toUpdateComputationDetailsResponse() } override suspend fun recordOutputBlobPath( request: RecordOutputBlobPathRequest ): RecordOutputBlobPathResponse { - RecordOutputBlobPath( - clock = clock, - localId = request.token.localComputationId, - editVersion = request.token.version, - stage = request.token.computationStage, - blobRef = BlobRef(request.outputBlobId, request.blobPath), - protocolStagesEnumHelper = protocolStagesEnumHelper - ) - .execute(client, idGenerator) - val token = - computationReader.readComputationToken(client, request.token.globalComputationId) - ?: failGrpc(Status.INTERNAL) { - "Computation ${request.token.globalComputationId} not found." - } + try { + RecordOutputBlobPath( + token = request.token.toDatabaseEditToken(), + clock = clock, + blobRef = BlobRef(request.outputBlobId, request.blobPath), + protocolStagesEnumHelper = protocolStagesEnumHelper, + computationReader = computationReader, + ) + .execute(client, idGenerator) + } catch (ex: UnknownDataError) { + throw ex.asStatusRuntimeException(Status.Code.INTERNAL) + } + return token.toRecordOutputBlobPathResponse() } @@ -395,19 +404,25 @@ class PostgresComputationsService( ) } - AdvanceComputationStage( - request.token.toDatabaseEditToken(), - nextStage = request.nextComputationStage, - nextStageDetails = request.stageDetails, - inputBlobPaths = request.inputBlobsList, - passThroughBlobPaths = request.passThroughBlobsList, - outputBlobs = request.outputBlobs, - afterTransition = afterTransition, - lockExtension = lockExtension, - clock = clock, - protocolStagesEnumHelper = protocolStagesEnumHelper - ) - .execute(client, idGenerator) + val token = + try { + AdvanceComputationStage( + request.token.toDatabaseEditToken(), + nextStage = request.nextComputationStage, + nextStageDetails = request.stageDetails, + inputBlobPaths = request.inputBlobsList, + passThroughBlobPaths = request.passThroughBlobsList, + outputBlobs = request.outputBlobs, + afterTransition = afterTransition, + lockExtension = lockExtension, + clock = clock, + protocolStagesEnumHelper = protocolStagesEnumHelper, + computationReader = computationReader, + ) + .execute(client, idGenerator) + } catch (ex: UnknownDataError) { + throw ex.asStatusRuntimeException(Status.Code.INTERNAL) + } sendStatusUpdateToKingdom( newCreateComputationLogEntryRequest( @@ -416,11 +431,6 @@ class PostgresComputationsService( ) ) - val token = - computationReader.readComputationToken(client, request.token.globalComputationId) - ?: failGrpc(Status.INTERNAL) { - "Computation ${request.token.globalComputationId} not found." - } return token.toAdvanceComputationStageResponse() } @@ -450,19 +460,20 @@ class PostgresComputationsService( override suspend fun recordRequisitionBlobPath( request: RecordRequisitionBlobPathRequest ): RecordRequisitionBlobPathResponse { - RecordRequisitionBlobPath( - clock = clock, - localId = request.token.localComputationId, - externalRequisitionKey = request.key, - pathToBlob = request.blobPath - ) - .execute(client, idGenerator) - val token = - computationReader.readComputationToken(client, request.token.globalComputationId) - ?: failGrpc(Status.INTERNAL) { - "Computation ${request.token.globalComputationId} not found." - } + try { + RecordRequisitionBlobPath( + clock = clock, + localId = request.token.localComputationId, + externalRequisitionKey = request.key, + pathToBlob = request.blobPath, + computationReader = computationReader, + ) + .execute(client, idGenerator) + } catch (ex: UnknownDataError) { + throw ex.asStatusRuntimeException(Status.Code.INTERNAL) + } + return token.toRecordRequisitionBlobPathResponse() } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt index 61141d0fd17..cf076bb5835 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt @@ -16,6 +16,7 @@ package org.wfanet.measurement.duchy.deploy.common.postgres.readers import com.google.protobuf.Timestamp import java.time.Instant +import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.flow.toSet import org.wfanet.measurement.common.db.r2dbc.DatabaseClient @@ -78,6 +79,29 @@ class ComputationReader( ) } + data class UnclaimedTaskQueryResult( + val computationId: Long, + val globalId: String, + val computationStage: Long, + val creationTime: Instant, + val updateTime: Instant, + val nextAttempt: Long + ) + + private fun buildUnclaimedTaskQueryResult(row: ResultRow): UnclaimedTaskQueryResult = + UnclaimedTaskQueryResult( + row["ComputationId"], + row["GlobalComputationId"], + row["ComputationStage"], + row["CreationTime"], + row["UpdateTime"], + row["NextAttempt"] + ) + + data class LockOwnerQueryResult(val lockOwner: String?, val updateTime: Instant) { + constructor(row: ResultRow) : this(lockOwner = row["LockOwner"], updateTime = row["UpdateTime"]) + } + private fun buildComputationToken( computation: Computation, blobs: List, @@ -169,6 +193,31 @@ class ComputationReader( return readContext.executeQuery(statement).consume(::Computation).firstOrNull() } + /** + * Reads a [ComputationToken] by globalComputationId. + * + * @param readContext The [ReadContext] for reading from the Postgres database. + * @param globalComputationId A global identifier for a computation. + * @return [ComputationToken] when a Computation with globalComputationId is found, or null. + */ + suspend fun readComputationToken( + readContext: ReadContext, + globalComputationId: String + ): ComputationToken? { + val computation: Computation = readComputation(readContext, globalComputationId) ?: return null + + val blobs = + blobReferenceReader.readBlobMetadata( + readContext, + computation.localComputationId, + computation.computationStage + ) + val requisitions = + requisitionReader.readRequisitionMetadata(readContext, computation.localComputationId) + + return buildComputationToken(computation, blobs, requisitions) + } + /** * Reads a [ComputationToken] by globalComputationId. * @@ -182,24 +231,37 @@ class ComputationReader( ): ComputationToken? { val readContext = client.readTransaction() try { - val computation: Computation = - readComputation(readContext, globalComputationId) ?: return null - - val blobs = - blobReferenceReader.readBlobMetadata( - readContext, - computation.localComputationId, - computation.computationStage - ) - val requisitions = - requisitionReader.readRequisitionMetadata(readContext, computation.localComputationId) - - return buildComputationToken(computation, blobs, requisitions) + return readComputationToken(readContext, globalComputationId) } finally { readContext.close() } } + /** + * Reads a [ComputationToken] by externalRequisitionKey. + * + * @param readContext The [ReadContext] for reading from the Postgres database. + * @param externalRequisitionKey The [ExternalRequisitionKey] for a computation. + * @return [ComputationToken] when a Computation with externalRequisitionKey is found, or null. + */ + suspend fun readComputationToken( + readContext: ReadContext, + externalRequisitionKey: ExternalRequisitionKey + ): ComputationToken? { + val computation = readComputation(readContext, externalRequisitionKey) ?: return null + + val blobs = + blobReferenceReader.readBlobMetadata( + readContext, + computation.localComputationId, + computation.computationStage + ) + val requisitions = + requisitionReader.readRequisitionMetadata(readContext, computation.localComputationId) + + return buildComputationToken(computation, blobs, requisitions) + } + /** * Reads a [ComputationToken] by externalRequisitionKey. * @@ -213,18 +275,7 @@ class ComputationReader( ): ComputationToken? { val readContext = client.readTransaction() try { - val computation = readComputation(readContext, externalRequisitionKey) ?: return null - - val blobs = - blobReferenceReader.readBlobMetadata( - readContext, - computation.localComputationId, - computation.computationStage - ) - val requisitions = - requisitionReader.readRequisitionMetadata(readContext, computation.localComputationId) - - return buildComputationToken(computation, blobs, requisitions) + return readComputationToken(readContext, externalRequisitionKey) } finally { readContext.close() } @@ -285,4 +336,56 @@ class ComputationReader( .consume { row -> row.get("GlobalComputationId") } .toSet() } + + /** + * Reads a list of unclaimed computation tasks + * + * @param readContext The transaction context for reading from the Postgres database. + * @param protocol The enum value of target computation type. + * @param timestamp An [Instant] to filter for the expired computation locks. + * @return a flow of [UnclaimedTaskQueryResult] + */ + suspend fun listUnclaimedTasks( + readContext: ReadContext, + protocol: Long, + timestamp: Instant, + ): Flow { + val listUnclaimedTasksSql = + boundStatement( + """ + SELECT c.ComputationId, c.GlobalComputationId, + c.Protocol, c.ComputationStage, c.UpdateTime, + c.CreationTime, cs.NextAttempt + FROM Computations AS c + JOIN ComputationStages AS cs + ON c.ComputationId = cs.ComputationId + AND c.ComputationStage = cs.ComputationStage + WHERE c.Protocol = $1 + AND c.LockExpirationTime IS NOT NULL + AND c.LockExpirationTime <= $2 + ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC + LIMIT 50; + """ + ) { + bind("$1", protocol) + bind("$2", timestamp) + } + + return readContext.executeQuery(listUnclaimedTasksSql).consume(::buildUnclaimedTaskQueryResult) + } + + suspend fun readLockOwner(readContext: ReadContext, computationId: Long): LockOwnerQueryResult? { + val readLockOwnerSql = + boundStatement( + """ + SELECT LockOwner, UpdateTime + FROM Computations + WHERE + ComputationId = $1; + """ + ) { + bind("$1", computationId) + } + return readContext.executeQuery(readLockOwnerSql).consume(::LockOwnerQueryResult).firstOrNull() + } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt index 5c7c3874fc7..69b0db28fd7 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt @@ -23,9 +23,12 @@ import org.wfanet.measurement.duchy.db.computation.AfterTransition import org.wfanet.measurement.duchy.db.computation.ComputationEditToken import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnumHelper import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationBlobReferenceReader +import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationStageAttemptReader +import org.wfanet.measurement.duchy.service.internal.UnknownDataError import org.wfanet.measurement.internal.duchy.ComputationBlobDependency import org.wfanet.measurement.internal.duchy.ComputationStageAttemptDetails +import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.copy /** @@ -53,8 +56,9 @@ class AdvanceComputationStage( private val lockExtension: Duration, private val clock: Clock, private val protocolStagesEnumHelper: ComputationProtocolStagesEnumHelper, -) : PostgresWriter() { - override suspend fun TransactionScope.runTransaction() { + private val computationReader: ComputationReader, +) : PostgresWriter() { + override suspend fun TransactionScope.runTransaction(): ComputationToken { val currentStage = token.stage val localId = token.localId val editVersion = token.editVersion @@ -186,5 +190,8 @@ class AdvanceComputationStage( dependencyType = ComputationBlobDependency.OUTPUT ) } + + return computationReader.readComputationToken(transactionContext, token.globalId) + ?: throw UnknownDataError() } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt index 201c6c08ea6..ef1e784a356 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt @@ -16,19 +16,17 @@ package org.wfanet.measurement.duchy.deploy.common.postgres.writers import java.time.Clock import java.time.Duration -import java.time.Instant -import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.filter import kotlinx.coroutines.flow.firstOrNull -import org.wfanet.measurement.common.db.r2dbc.ResultRow -import org.wfanet.measurement.common.db.r2dbc.boundStatement import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnumHelper -import org.wfanet.measurement.duchy.db.computation.ComputationStageLongValues import org.wfanet.measurement.duchy.db.computation.ComputationTypeEnumHelper +import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationStageAttemptReader import org.wfanet.measurement.duchy.service.internal.ComputationNotFoundException +import org.wfanet.measurement.duchy.service.internal.UnknownDataError import org.wfanet.measurement.internal.duchy.ComputationStageAttemptDetails +import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.copy /** @@ -54,36 +52,13 @@ class ClaimWork( private val clock: Clock, private val computationTypeEnumHelper: ComputationTypeEnumHelper, private val protocolStagesEnumHelper: ComputationProtocolStagesEnumHelper, -) : PostgresWriter() { + private val computationReader: ComputationReader, +) : PostgresWriter() { - private data class UnclaimedTaskQueryResult( - val computationId: Long, - val globalId: String, - val computationStage: StageT, - val creationTime: Instant, - val updateTime: Instant, - val nextAttempt: Long - ) - - private fun buildUnclaimedTaskQueryResult(row: ResultRow): UnclaimedTaskQueryResult = - UnclaimedTaskQueryResult( - row["ComputationId"], - row["GlobalComputationId"], - protocolStagesEnumHelper.longValuesToComputationStageEnum( - ComputationStageLongValues(row["Protocol"], row["ComputationStage"]) - ), - row["CreationTime"], - row["UpdateTime"], - row["NextAttempt"] - ) - - private data class LockOwnerQueryResult(val lockOwner: String?, val updateTime: Instant) - - private fun buildLockOwnerQueryResult(row: ResultRow): LockOwnerQueryResult = - LockOwnerQueryResult(lockOwner = row["LockOwner"], updateTime = row["UpdateTime"]) - - override suspend fun TransactionScope.runTransaction(): String? { - return listUnclaimedTasks(protocol, clock.instant()) + override suspend fun TransactionScope.runTransaction(): ComputationToken? { + val protocolEnum = computationTypeEnumHelper.protocolEnumToLong(protocol) + return computationReader + .listUnclaimedTasks(transactionContext, protocolEnum, clock.instant()) // First the possible tasks to claim are selected from the computations table, then for each // item in the list we try to claim the lock in a transaction which will only succeed if the // lock is still available. This pattern means only the item which is being updated @@ -91,37 +66,10 @@ class ClaimWork( .filter { claim(it) } // If the value is null, no tasks were claimed. .firstOrNull() - ?.globalId - } - - private suspend fun TransactionScope.listUnclaimedTasks( - protocol: ProtocolT, - timestamp: Instant - ): Flow> { - val listUnclaimedTasksSql = - boundStatement( - """ - SELECT c.ComputationId, c.GlobalComputationId, - c.Protocol, c.ComputationStage, c.UpdateTime, - c.CreationTime, cs.NextAttempt - FROM Computations AS c - JOIN ComputationStages AS cs - ON c.ComputationId = cs.ComputationId - AND c.ComputationStage = cs.ComputationStage - WHERE c.Protocol = $1 - AND c.LockExpirationTime IS NOT NULL - AND c.LockExpirationTime <= $2 - ORDER BY c.CreationTime ASC, c.LockExpirationTime ASC, c.UpdateTime ASC - LIMIT 50; - """ - ) { - bind("$1", computationTypeEnumHelper.protocolEnumToLong(protocol)) - bind("$2", timestamp) + ?.let { + computationReader.readComputationToken(transactionContext, it.globalId) + ?: throw UnknownDataError() } - - return transactionContext - .executeQuery(listUnclaimedTasksSql) - .consume(::buildUnclaimedTaskQueryResult) } /** @@ -129,9 +77,11 @@ class ClaimWork( * lock is acquired a new row is written to the ComputationStageAttempts table. */ private suspend fun TransactionScope.claim( - unclaimedTask: UnclaimedTaskQueryResult + unclaimedTask: ComputationReader.UnclaimedTaskQueryResult ): Boolean { - val currentLockOwner = readLockOwner(unclaimedTask.computationId) + val currentLockOwner = + computationReader.readLockOwner(transactionContext, unclaimedTask.computationId) + ?: throw UnknownDataError() // Verify that the row hasn't been updated since the previous, non-transactional read. // If it has been updated since that time the lock should not be acquired. if (currentLockOwner.updateTime != unclaimedTask.updateTime) return false @@ -143,10 +93,7 @@ class ClaimWork( ownerId, writeTime.plus(lockDuration) ) - val stageLongValue = - protocolStagesEnumHelper - .computationStageEnumToLongValues(unclaimedTask.computationStage) - .stage + val stageLongValue = unclaimedTask.computationStage insertComputationStageAttempt( unclaimedTask.computationId, @@ -188,23 +135,4 @@ class ClaimWork( // The lock was acquired. return true } - - private suspend fun TransactionScope.readLockOwner(computationId: Long): LockOwnerQueryResult { - val readLockOwnerSql = - boundStatement( - """ - SELECT LockOwner, UpdateTime - FROM Computations - WHERE - ComputationId = $1; - """ - ) { - bind("$1", computationId) - } - return transactionContext - .executeQuery(readLockOwnerSql) - .consume(::buildLockOwnerQueryResult) - .firstOrNull() - ?: throw ComputationNotFoundException(computationId) - } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/CreateComputation.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/CreateComputation.kt index 6051ba6bc8e..b9550443d45 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/CreateComputation.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/CreateComputation.kt @@ -20,11 +20,14 @@ import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStageDetailsHelper import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnumHelper import org.wfanet.measurement.duchy.db.computation.ComputationTypeEnumHelper +import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.service.internal.ComputationAlreadyExistsException import org.wfanet.measurement.duchy.service.internal.ComputationInitialStageInvalidException import org.wfanet.measurement.duchy.service.internal.DuchyInternalException +import org.wfanet.measurement.duchy.service.internal.UnknownDataError import org.wfanet.measurement.internal.duchy.ComputationDetails import org.wfanet.measurement.internal.duchy.ComputationStageDetails +import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.RequisitionEntry /** @@ -46,7 +49,7 @@ import org.wfanet.measurement.internal.duchy.RequisitionEntry * * [ComputationAlreadyExistsException] when there exists a computation with this * globalComputationId */ -class CreateComputation( +class CreateComputation( private val globalId: String, private val protocol: ProtocolT, private val initialStage: StageT, @@ -59,9 +62,10 @@ class CreateComputation, private val computationProtocolStageDetailsHelper: ComputationProtocolStageDetailsHelper, -) : PostgresWriter() { + private val computationReader: ComputationReader, +) : PostgresWriter() { - override suspend fun TransactionScope.runTransaction() { + override suspend fun TransactionScope.runTransaction(): ComputationToken { if (!computationProtocolStagesEnumHelper.validInitialStage(protocol, initialStage)) { throw ComputationInitialStageInvalidException(protocol.toString(), initialStage.toString()) } @@ -107,5 +111,8 @@ class CreateComputation, private val protocolStageDetailsHelper: ComputationProtocolStageDetailsHelper, -) : PostgresWriter() { - override suspend fun TransactionScope.runTransaction() { + private val computationReader: ComputationReader, +) : PostgresWriter() { + override suspend fun TransactionScope.runTransaction(): ComputationToken { val protocol = token.protocol val localId = token.localId val editVersion = token.editVersion @@ -132,6 +136,9 @@ class FinishComputation( - private val localId: Long, - private val editVersion: Long, - private val stage: StageT, + private val token: ComputationEditToken, private val blobRef: BlobRef, private val clock: Clock, private val protocolStagesEnumHelper: ComputationProtocolStagesEnumHelper, -) : PostgresWriter() { - override suspend fun TransactionScope.runTransaction() { + private val computationReader: ComputationReader, +) : PostgresWriter() { + override suspend fun TransactionScope.runTransaction(): ComputationToken { require(blobRef.key.isNotBlank()) { "Cannot insert blank path to blob. $blobRef" } - checkComputationUnmodified(localId, editVersion) + val localId = token.localId + val stage = token.stage - val stageLongValue = protocolStagesEnumHelper.computationStageEnumToLongValues(stage).stage + checkComputationUnmodified(localId, token.editVersion) + + val stageLongValue = + protocolStagesEnumHelper.computationStageEnumToLongValues(token.stage).stage val type: ComputationBlobDependency = ComputationBlobReferenceReader() .readBlobDependency( @@ -67,5 +72,8 @@ class RecordOutputBlobPath( blobId = blobRef.idInRelationalDatabase, pathToBlob = blobRef.key ) + + return computationReader.readComputationToken(transactionContext, token.globalId) + ?: throw UnknownDataError() } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/RecordRequisitionBlobPath.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/RecordRequisitionBlobPath.kt index be8b44b9c47..dac8223d2d1 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/RecordRequisitionBlobPath.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/RecordRequisitionBlobPath.kt @@ -16,7 +16,10 @@ package org.wfanet.measurement.duchy.deploy.common.postgres.writers import java.time.Clock import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter +import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.RequisitionReader +import org.wfanet.measurement.duchy.service.internal.UnknownDataError +import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.ExternalRequisitionKey /** @@ -35,8 +38,9 @@ class RecordRequisitionBlobPath( private val externalRequisitionKey: ExternalRequisitionKey, private val pathToBlob: String, private val clock: Clock, -) : PostgresWriter() { - override suspend fun TransactionScope.runTransaction() { + private val computationReader: ComputationReader, +) : PostgresWriter() { + override suspend fun TransactionScope.runTransaction(): ComputationToken { require(pathToBlob.isNotBlank()) { "Cannot insert blank path to blob. $externalRequisitionKey" } val requisition: RequisitionReader.RequisitionResult = RequisitionReader().readRequisitionByExternalKey(transactionContext, externalRequisitionKey) @@ -54,5 +58,8 @@ class RecordRequisitionBlobPath( pathToBlob = pathToBlob, updateTime = writeTime ) + + return computationReader.readComputationToken(transactionContext, externalRequisitionKey) + ?: throw UnknownDataError() } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/UpdateComputationDetails.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/UpdateComputationDetails.kt index b237a969015..6d9bc1be617 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/UpdateComputationDetails.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/UpdateComputationDetails.kt @@ -17,7 +17,11 @@ package org.wfanet.measurement.duchy.deploy.common.postgres.writers import com.google.protobuf.Message import java.time.Clock import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter +import org.wfanet.measurement.duchy.db.computation.ComputationEditToken +import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.RequisitionReader +import org.wfanet.measurement.duchy.service.internal.UnknownDataError +import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.RequisitionEntry /** @@ -32,15 +36,15 @@ import org.wfanet.measurement.internal.duchy.RequisitionEntry * Throws following exceptions on [execute]: * * [IllegalStateException] when arguments does not meet requirement */ -class UpdateComputationDetails( - private val localId: Long, - private val editVersion: Long, +class UpdateComputationDetails( + private val token: ComputationEditToken, private val computationDetails: ComputationDT, private val requisitionEntries: List, private val clock: Clock, -) : PostgresWriter() { - override suspend fun TransactionScope.runTransaction() { - checkComputationUnmodified(localId, editVersion) + private val computationReader: ComputationReader, +) : PostgresWriter() { + override suspend fun TransactionScope.runTransaction(): ComputationToken { + checkComputationUnmodified(token.localId, token.editVersion) val writeTime = clock.instant() requisitionEntries.forEach { @@ -56,6 +60,9 @@ class UpdateComputationDetails( updateTime = writeTime ) } - updateComputation(localId = localId, updateTime = writeTime, details = computationDetails) + updateComputation(localId = token.localId, updateTime = writeTime, details = computationDetails) + + return computationReader.readComputationToken(transactionContext, token.globalId) + ?: throw UnknownDataError() } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt index 42e43bca799..7a50731b6b1 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt @@ -101,3 +101,10 @@ class ComputationAlreadyExistsException( override val context get() = mapOf("global_computation_id" to globalComputationId) } + +class UnknownDataError( + message: String = "Data corrupted for unknown reasons", +) : DuchyInternalException(ErrorCode.UNKNOWN_DATA_ERROR, message) { + override val context: Map + get() = emptyMap() +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt index 05f4351c05e..24ccfe8d6dc 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt @@ -163,11 +163,15 @@ class ComputationsService( override suspend fun purgeComputations( request: PurgeComputationsRequest ): PurgeComputationsResponse { + val terminalStages = + request.stagesList.filter { + computationsDatabase.validTerminalStage(computationsDatabase.stageToProtocol(it), it) + } var deleted = 0 try { val globalIds = computationsDatabase.readGlobalComputationIds( - request.stagesList.toSet(), + terminalStages.toSet(), request.updatedBefore.toInstant() ) if (!request.force) { diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt index 584ef2a547c..be9596328f3 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt @@ -186,31 +186,44 @@ abstract class ComputationsServiceTest { @Test fun `purgeComputations returns the matched computation IDs when force is false`() = runBlocking { - val computation1 = service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST) - val computation2 = - service.createComputation( - DEFAULT_CREATE_COMPUTATION_REQUEST.copy { - globalComputationId = "5678" - requisitions[0] = - requisitions[0].copy { key = key.copy { externalRequisitionId = "5678" } } - } - ) + val token1 = service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST).token + val token2 = + service + .createComputation( + DEFAULT_CREATE_COMPUTATION_REQUEST.copy { + globalComputationId = "5678" + requisitions[0] = + requisitions[0].copy { key = key.copy { externalRequisitionId = "5678" } } + } + ) + .token + service.finishComputation( + finishComputationRequest { + token = token1 + endingComputationStage = Stage.COMPLETE.toProtocolStage() + reason = ComputationDetails.CompletedReason.SUCCEEDED + } + ) + service.finishComputation( + finishComputationRequest { + token = token2 + endingComputationStage = Stage.COMPLETE.toProtocolStage() + reason = ComputationDetails.CompletedReason.SUCCEEDED + } + ) val currentTime = clock.last() val purgeComputationsResp = service.purgeComputations( purgeComputationsRequest { updatedBefore = currentTime.plusSeconds(1000L).toProtoTime() - stages += Stage.INITIALIZATION_PHASE.toProtocolStage() + stages += Stage.COMPLETE.toProtocolStage() force = false } ) assertThat(purgeComputationsResp.purgeSampleList) - .containsExactly( - computation1.token.globalComputationId, - computation2.token.globalComputationId - ) + .containsExactly(token1.globalComputationId, token2.globalComputationId) assertThat(purgeComputationsResp.purgeCount).isEqualTo(2) } @@ -218,11 +231,21 @@ abstract class ComputationsServiceTest { fun `purgeComputations does not delete Computations when force is false`() = runBlocking { val createTime = clock.last() val createdToken = service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST).token + val finishedToken = + service + .finishComputation( + finishComputationRequest { + token = createdToken + endingComputationStage = Stage.COMPLETE.toProtocolStage() + reason = ComputationDetails.CompletedReason.SUCCEEDED + } + ) + .token service.purgeComputations( purgeComputationsRequest { updatedBefore = createTime.plusSeconds(1000L).toProtoTime() - stages += Stage.INITIALIZATION_PHASE.toProtocolStage() + stages += Stage.COMPLETE.toProtocolStage() force = false } ) @@ -234,16 +257,34 @@ abstract class ComputationsServiceTest { ) .token ) - .isEqualTo(createdToken) + .isEqualTo(finishedToken) } @Test fun `purgeComputations only returns the deleted count when force is true`() = runBlocking { - service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST) - service.createComputation( - DEFAULT_CREATE_COMPUTATION_REQUEST.copy { - globalComputationId = "5678" - requisitions[0] = requisitions[0].copy { key = key.copy { externalRequisitionId = "5678" } } + val token1 = service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST).token + val token2 = + service + .createComputation( + DEFAULT_CREATE_COMPUTATION_REQUEST.copy { + globalComputationId = "5678" + requisitions[0] = + requisitions[0].copy { key = key.copy { externalRequisitionId = "5678" } } + } + ) + .token + service.finishComputation( + finishComputationRequest { + token = token1 + endingComputationStage = Stage.COMPLETE.toProtocolStage() + reason = ComputationDetails.CompletedReason.SUCCEEDED + } + ) + service.finishComputation( + finishComputationRequest { + token = token2 + endingComputationStage = Stage.COMPLETE.toProtocolStage() + reason = ComputationDetails.CompletedReason.SUCCEEDED } ) @@ -252,7 +293,7 @@ abstract class ComputationsServiceTest { service.purgeComputations( purgeComputationsRequest { updatedBefore = currentTime.plusSeconds(1000L).toProtoTime() - stages += Stage.INITIALIZATION_PHASE.toProtocolStage() + stages += Stage.COMPLETE.toProtocolStage() force = true } ) @@ -262,7 +303,7 @@ abstract class ComputationsServiceTest { } @Test - fun `purgeComputations only deletes computations of target stages`() = runBlocking { + fun `purgeComputations only deletes computations of terminal target stages`() = runBlocking { // Creates a computation in WAIT_REQUISITIONS_AND_KEY_SET stage service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST) val claimWorkResponse = service.claimWork(DEFAULT_CLAIM_WORK_REQUEST) @@ -275,49 +316,73 @@ abstract class ComputationsServiceTest { afterTransition = AfterTransition.RETAIN_AND_EXTEND_LOCK } service.advanceComputationStage(advanceComputationStageRequest) - // Creates two computations in INITIALIZATION_PHASE stage to be purged - val computationInInitPhase1 = - service.createComputation( - DEFAULT_CREATE_COMPUTATION_REQUEST.copy { - globalComputationId = "3456" - requisitions[0] = - requisitions[0].copy { key = key.copy { externalRequisitionId = "3456" } } - } - ) - val computationInInitPhase2 = - service.createComputation( - DEFAULT_CREATE_COMPUTATION_REQUEST.copy { - globalComputationId = "5678" - requisitions[0] = - requisitions[0].copy { key = key.copy { externalRequisitionId = "5678" } } - } - ) + // Creates two computations in COMPLETE stage to be purged + val token1 = + service + .createComputation( + DEFAULT_CREATE_COMPUTATION_REQUEST.copy { + globalComputationId = "3456" + requisitions[0] = + requisitions[0].copy { key = key.copy { externalRequisitionId = "3456" } } + } + ) + .token + val token2 = + service + .createComputation( + DEFAULT_CREATE_COMPUTATION_REQUEST.copy { + globalComputationId = "5678" + requisitions[0] = + requisitions[0].copy { key = key.copy { externalRequisitionId = "5678" } } + } + ) + .token + service.finishComputation( + finishComputationRequest { + token = token1 + endingComputationStage = Stage.COMPLETE.toProtocolStage() + reason = ComputationDetails.CompletedReason.SUCCEEDED + } + ) + service.finishComputation( + finishComputationRequest { + token = token2 + endingComputationStage = Stage.COMPLETE.toProtocolStage() + reason = ComputationDetails.CompletedReason.SUCCEEDED + } + ) val currentTime = clock.last() val purgeComputationsResp = service.purgeComputations( purgeComputationsRequest { updatedBefore = currentTime.plusSeconds(1000L).toProtoTime() - stages += Stage.INITIALIZATION_PHASE.toProtocolStage() + stages += Stage.COMPLETE.toProtocolStage() + stages += Stage.WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage() force = false } ) assertThat(purgeComputationsResp.purgeSampleList) - .containsExactly( - computationInInitPhase1.token.globalComputationId, - computationInInitPhase2.token.globalComputationId - ) + .containsExactly(token1.globalComputationId, token2.globalComputationId) assertThat(purgeComputationsResp.purgeCount).isEqualTo(2) } @Test fun `getComputationToken throws NOT_FOUND when computation is purged`() = runBlocking { - service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST) + val createdToken = service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST).token + service.finishComputation( + finishComputationRequest { + token = createdToken + endingComputationStage = Stage.COMPLETE.toProtocolStage() + reason = ComputationDetails.CompletedReason.SUCCEEDED + } + ) + val currentTime = clock.last() val purgeComputationsRequest = purgeComputationsRequest { updatedBefore = currentTime.plusSeconds(1000L).toProtoTime() - stages += Stage.INITIALIZATION_PHASE.toProtocolStage() + stages += Stage.COMPLETE.toProtocolStage() force = true } service.purgeComputations(purgeComputationsRequest) diff --git a/src/main/proto/wfa/measurement/internal/duchy/error_code.proto b/src/main/proto/wfa/measurement/internal/duchy/error_code.proto index 8d3530c27c4..4ed1d7a1269 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/error_code.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/error_code.proto @@ -39,4 +39,7 @@ enum ErrorCode { /** Computation with the same global ID already exists. */ COMPUTATION_ALREADY_EXISTS = 6; + + /** Data corrupted for unknown reasons. */ + UNKNOWN_DATA_ERROR = 7; } diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationStatsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationStatsServiceTest.kt index a2d3e280064..3c15218f797 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationStatsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationStatsServiceTest.kt @@ -75,8 +75,8 @@ class PostgresComputationStatsServiceTest : client = client, idGenerator = idGenerator, duchyName = ALSACE, - computationStorageClient = ComputationStore(storageClient), - requisitionStorageClient = RequisitionStore(storageClient), + computationStore = ComputationStore(storageClient), + requisitionStore = RequisitionStore(storageClient), computationLogEntriesClient = systemComputationLogEntriesClient, ) } diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsServiceTest.kt index 86f27a1f822..7939b4f4807 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsServiceTest.kt @@ -69,8 +69,8 @@ class PostgresComputationsServiceTest : ComputationsServiceTest Date: Mon, 31 Jul 2023 19:50:47 +0000 Subject: [PATCH 05/12] feedbacks --- .../postgres/PostgresComputationsService.kt | 9 ++-- .../computations/ComputationsService.kt | 9 ++-- .../testing/ComputationsServiceTest.kt | 20 +++++++- ...annerComputationsDatabaseTransactorTest.kt | 50 ++++++++++++------- 4 files changed, 61 insertions(+), 27 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt index aaeacbae5d7..737b2342d44 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt @@ -251,19 +251,22 @@ class PostgresComputationsService( override suspend fun purgeComputations( request: PurgeComputationsRequest ): PurgeComputationsResponse { - val terminalStages = - request.stagesList.filter { + grpcRequire( + request.stagesList.all { protocolStagesEnumHelper.validTerminalStage( protocolStagesEnumHelper.stageToProtocol(it), it ) } + ) { + "Requested stage list contains non terminal stage." + } var deleted = 0 try { val globalIds: Set = computationReader.readGlobalComputationIds( client.singleUse(), - terminalStages, + request.stagesList, request.updatedBefore.toInstant() ) if (!request.force) { diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt index 24ccfe8d6dc..f962855e041 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt @@ -163,15 +163,18 @@ class ComputationsService( override suspend fun purgeComputations( request: PurgeComputationsRequest ): PurgeComputationsResponse { - val terminalStages = - request.stagesList.filter { + grpcRequire( + request.stagesList.all { computationsDatabase.validTerminalStage(computationsDatabase.stageToProtocol(it), it) } + ) { + "Requested stage list contains non terminal stage." + } var deleted = 0 try { val globalIds = computationsDatabase.readGlobalComputationIds( - terminalStages.toSet(), + request.stagesList.toSet(), request.updatedBefore.toInstant() ) if (!request.force) { diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt index be9596328f3..ebe05842c70 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing/ComputationsServiceTest.kt @@ -303,7 +303,24 @@ abstract class ComputationsServiceTest { } @Test - fun `purgeComputations only deletes computations of terminal target stages`() = runBlocking { + fun `purgeComputations throws INVALID_ARGUMENT exception when target stage is non-terminal`() = + runBlocking { + val currentTime = clock.last() + val purgeComputationsRequest = purgeComputationsRequest { + updatedBefore = currentTime.plusSeconds(1000L).toProtoTime() + stages += Stage.COMPLETE.toProtocolStage() + stages += Stage.INITIALIZATION_PHASE.toProtocolStage() + force = true + } + val exception = + assertFailsWith { + service.purgeComputations(purgeComputationsRequest) + } + assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) + } + + @Test + fun `purgeComputations only deletes computations of target stages`() = runBlocking { // Creates a computation in WAIT_REQUISITIONS_AND_KEY_SET stage service.createComputation(DEFAULT_CREATE_COMPUTATION_REQUEST) val claimWorkResponse = service.claimWork(DEFAULT_CLAIM_WORK_REQUEST) @@ -358,7 +375,6 @@ abstract class ComputationsServiceTest { purgeComputationsRequest { updatedBefore = currentTime.plusSeconds(1000L).toProtoTime() stages += Stage.COMPLETE.toProtocolStage() - stages += Stage.WAIT_REQUISITIONS_AND_KEY_SET.toProtocolStage() force = false } ) diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactorTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactorTest.kt index 92ac6159cb3..02fe0baff82 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactorTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/gcloud/spanner/computation/GcpSpannerComputationsDatabaseTransactorTest.kt @@ -427,7 +427,8 @@ class GcpSpannerComputationsDatabaseTransactorTest : protocol = FakeProtocol.ZERO, stage = C, attempt = 1, - editVersion = lastUpdated.toEpochMilli() + editVersion = lastUpdated.toEpochMilli(), + globalId = "0" ) val computation = @@ -435,7 +436,7 @@ class GcpSpannerComputationsDatabaseTransactorTest : localId = token.localId, creationTime = lastUpdated.toGcloudTimestamp(), updateTime = lastUpdated.toGcloudTimestamp(), - globalId = "0", + globalId = token.globalId, protocol = token.protocol, stage = token.stage, lockOwner = "PeterSpacemen", @@ -501,7 +502,8 @@ class GcpSpannerComputationsDatabaseTransactorTest : protocol = FakeProtocol.ZERO, stage = C, attempt = 1, - editVersion = 0 + editVersion = 0, + globalId = "0" ) assertFailsWith { database.enqueue(token, 0) } } @@ -517,7 +519,8 @@ class GcpSpannerComputationsDatabaseTransactorTest : protocol = FakeProtocol.ZERO, stage = C, attempt = 1, - editVersion = lastUpdated.minusSeconds(200).toEpochMilli() + editVersion = lastUpdated.minusSeconds(200).toEpochMilli(), + globalId = "1234" ) val computation = @@ -525,7 +528,7 @@ class GcpSpannerComputationsDatabaseTransactorTest : localId = token.localId, creationTime = lastUpdated.toGcloudTimestamp(), updateTime = lastUpdated.toGcloudTimestamp(), - globalId = "1234", + globalId = token.globalId, protocol = FakeProtocol.ONE, stage = token.stage, lockOwner = "AnOwnedLock", @@ -782,7 +785,8 @@ class GcpSpannerComputationsDatabaseTransactorTest : protocol = FakeProtocol.ONE, stage = B, attempt = 2, - editVersion = testClock["last_updated"].toEpochMilli() + editVersion = testClock["last_updated"].toEpochMilli(), + globalId = globalId ) val computation = computationMutations.insertComputation( @@ -1041,7 +1045,8 @@ class GcpSpannerComputationsDatabaseTransactorTest : attempt = 1, protocol = FakeProtocol.ZERO, stage = A, - editVersion = 0 + editVersion = 0, + globalId = "0" ) assertFailsWith { database.updateComputationStage( @@ -1066,7 +1071,8 @@ class GcpSpannerComputationsDatabaseTransactorTest : protocol = FakeProtocol.ZERO, stage = B, attempt = 1, - editVersion = testClock.last().toEpochMilli() + editVersion = testClock.last().toEpochMilli(), + globalId = "2002" ) val computation = computationMutations.insertComputation( @@ -1075,7 +1081,7 @@ class GcpSpannerComputationsDatabaseTransactorTest : stage = B, creationTime = testClock.last().toGcloudTimestamp(), updateTime = testClock.last().toGcloudTimestamp(), - globalId = "2002", + globalId = token.globalId, lockOwner = WRITE_NULL_STRING, lockExpirationTime = WRITE_NULL_TIMESTAMP, details = FAKE_COMPUTATION_DETAILS @@ -1162,7 +1168,8 @@ class GcpSpannerComputationsDatabaseTransactorTest : protocol = FakeProtocol.ZERO, stage = C, attempt = 2, - editVersion = lastUpdated.toEpochMilli() + editVersion = lastUpdated.toEpochMilli(), + globalId = "55" ) val computation = computationMutations.insertComputation( @@ -1171,7 +1178,7 @@ class GcpSpannerComputationsDatabaseTransactorTest : updateTime = lastUpdated.toGcloudTimestamp(), protocol = token.protocol, stage = token.stage, - globalId = "55", + globalId = token.globalId, lockOwner = "PeterSpacemen", lockExpirationTime = lockExpires.toGcloudTimestamp(), details = FAKE_COMPUTATION_DETAILS @@ -1224,7 +1231,8 @@ class GcpSpannerComputationsDatabaseTransactorTest : protocol = FakeProtocol.ZERO, stage = C, attempt = 2, - editVersion = lastUpdated.toEpochMilli() + editVersion = lastUpdated.toEpochMilli(), + globalId = "55" ) val computation = computationMutations.insertComputation( @@ -1233,7 +1241,7 @@ class GcpSpannerComputationsDatabaseTransactorTest : updateTime = lastUpdated.toGcloudTimestamp(), protocol = token.protocol, stage = token.stage, - globalId = "55", + globalId = token.globalId, lockOwner = "PeterSpacemen", lockExpirationTime = lockExpires.toGcloudTimestamp(), details = FAKE_COMPUTATION_DETAILS @@ -1351,7 +1359,8 @@ class GcpSpannerComputationsDatabaseTransactorTest : protocol = FakeProtocol.ZERO, stage = C, attempt = 2, - editVersion = lastUpdated.toEpochMilli() + editVersion = lastUpdated.toEpochMilli(), + globalId = "55" ) val computation = computationMutations.insertComputation( @@ -1360,7 +1369,7 @@ class GcpSpannerComputationsDatabaseTransactorTest : updateTime = lastUpdated.toGcloudTimestamp(), protocol = token.protocol, stage = token.stage, - globalId = "55", + globalId = token.globalId, lockOwner = "PeterSpacemen", lockExpirationTime = lockExpires.toGcloudTimestamp(), details = FAKE_COMPUTATION_DETAILS @@ -1456,7 +1465,8 @@ class GcpSpannerComputationsDatabaseTransactorTest : protocol = FakeProtocol.ZERO, stage = C, attempt = 2, - editVersion = testClock.last().toEpochMilli() + editVersion = testClock.last().toEpochMilli(), + globalId = "55" ) val computation = computationMutations.insertComputation( @@ -1465,7 +1475,7 @@ class GcpSpannerComputationsDatabaseTransactorTest : updateTime = testClock.last().toGcloudTimestamp(), protocol = token.protocol, stage = C, - globalId = "55", + globalId = token.globalId, lockOwner = WRITE_NULL_STRING, lockExpirationTime = testClock.last().toGcloudTimestamp(), details = FAKE_COMPUTATION_DETAILS @@ -1574,7 +1584,8 @@ class GcpSpannerComputationsDatabaseTransactorTest : protocol = FakeProtocol.ZERO, stage = C, attempt = 1, - editVersion = testClock.last().toEpochMilli() + editVersion = testClock.last().toEpochMilli(), + globalId = globalId ) val computation = computationMutations.insertComputation( @@ -1668,7 +1679,8 @@ class GcpSpannerComputationsDatabaseTransactorTest : protocol = FakeProtocol.ZERO, stage = C, attempt = 1, - editVersion = testClock.last().toEpochMilli() + editVersion = testClock.last().toEpochMilli(), + globalId = "55" ) assertFailsWith(IllegalArgumentException::class, "Invalid initial stage") { database.endComputation(token, B, EndComputationReason.CANCELED, FAKE_COMPUTATION_DETAILS) From d19b2cd46a5e6cbc991fd09a407b0a2f41f9fe9a Mon Sep 17 00:00:00 2001 From: Yuhong Wang Date: Tue, 1 Aug 2023 04:45:53 +0000 Subject: [PATCH 06/12] feedbacks --- .../postgres/PostgresComputationsService.kt | 54 ++++--------------- .../postgres/readers/ComputationReader.kt | 11 +++- .../writers/AdvanceComputationStage.kt | 4 +- .../common/postgres/writers/ClaimWork.kt | 7 +-- .../postgres/writers/CreateComputation.kt | 4 +- .../postgres/writers/FinishComputation.kt | 4 +- .../postgres/writers/RecordOutputBlobPath.kt | 4 +- .../writers/RecordRequisitionBlobPath.kt | 4 +- .../writers/UpdateComputationDetails.kt | 4 +- .../common/server/ComputationsServer.kt | 4 +- .../internal/DuchyInternalException.kt | 4 +- .../computations/ComputationsService.kt | 35 ++---------- .../internal/duchy/error_code.proto | 2 +- 13 files changed, 43 insertions(+), 98 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt index 737b2342d44..f9adcfad175 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt @@ -16,7 +16,6 @@ package org.wfanet.measurement.duchy.deploy.common.postgres import com.google.protobuf.Empty import io.grpc.Status -import io.grpc.StatusException import java.time.Clock import java.time.Duration import java.util.logging.Level @@ -53,7 +52,7 @@ import org.wfanet.measurement.duchy.service.internal.ComputationAlreadyExistsExc import org.wfanet.measurement.duchy.service.internal.ComputationDetailsNotFoundException import org.wfanet.measurement.duchy.service.internal.ComputationInitialStageInvalidException import org.wfanet.measurement.duchy.service.internal.ComputationNotFoundException -import org.wfanet.measurement.duchy.service.internal.UnknownDataError +import org.wfanet.measurement.duchy.service.internal.DataCorruptedException import org.wfanet.measurement.duchy.service.internal.computations.toAdvanceComputationStageResponse import org.wfanet.measurement.duchy.service.internal.computations.toClaimWorkResponse import org.wfanet.measurement.duchy.service.internal.computations.toCreateComputationResponse @@ -96,7 +95,6 @@ import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsRequest import org.wfanet.measurement.internal.duchy.UpdateComputationDetailsResponse import org.wfanet.measurement.internal.duchy.getComputationIdsResponse import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2 -import org.wfanet.measurement.internal.duchy.protocol.LiquidLegionsSketchAggregationV2.Stage import org.wfanet.measurement.internal.duchy.purgeComputationsResponse import org.wfanet.measurement.system.v1alpha.ComputationLogEntriesGrpcKt.ComputationLogEntriesCoroutineStub import org.wfanet.measurement.system.v1alpha.ComputationParticipantKey @@ -155,7 +153,7 @@ class PostgresComputationsService( throw ex.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) } catch (ex: ComputationAlreadyExistsException) { throw ex.asStatusRuntimeException(Status.Code.ALREADY_EXISTS) - } catch (ex: UnknownDataError) { + } catch (ex: DataCorruptedException) { throw ex.asStatusRuntimeException(Status.Code.INTERNAL) } @@ -223,25 +221,13 @@ class PostgresComputationsService( request.localComputationId ) for (blobKey in computationBlobKeys) { - try { - computationStore.get(blobKey)?.delete() - } catch (e: StatusException) { - if (e.status.code != Status.Code.NOT_FOUND) { - throw e - } - } + computationStore.get(blobKey)?.delete() } val requisitionBlobKeys = requisitionReader.readRequisitionBlobKeys(client.singleUse(), request.localComputationId) for (blobKey in requisitionBlobKeys) { - try { - requisitionStore.get(blobKey)?.delete() - } catch (e: StatusException) { - if (e.status.code != Status.NOT_FOUND.code) { - throw e - } - } + requisitionStore.get(blobKey)?.delete() } DeleteComputation(request.localComputationId).execute(client, idGenerator) @@ -277,28 +263,6 @@ class PostgresComputationsService( } for (globalId in globalIds) { val token = computationReader.readComputationToken(client, globalId) ?: continue - val computationStageEnum = token.computationStage - val endComputationStage = getEndingComputationStage(computationStageEnum) - - if (!isTerminated(computationStageEnum)) { - FinishComputation( - token.toDatabaseEditToken(), - endingStage = endComputationStage, - endComputationReason = EndComputationReason.FAILED, - computationDetails = token.computationDetails, - clock = clock, - protocolStagesEnumHelper = protocolStagesEnumHelper, - protocolStageDetailsHelper = computationProtocolStageDetailsHelper, - computationReader = computationReader, - ) - .execute(client, idGenerator) - sendStatusUpdateToKingdom( - newCreateComputationLogEntryRequest( - token.globalComputationId, - endComputationStage, - ) - ) - } DeleteComputation(token.localComputationId).execute(client, idGenerator) deleted += 1 } @@ -330,7 +294,7 @@ class PostgresComputationsService( computationReader = computationReader, ) .execute(client, idGenerator) - } catch (ex: UnknownDataError) { + } catch (ex: DataCorruptedException) { throw ex.asStatusRuntimeException(Status.Code.INTERNAL) } @@ -361,7 +325,7 @@ class PostgresComputationsService( computationReader = computationReader ) .execute(client, idGenerator) - } catch (ex: UnknownDataError) { + } catch (ex: DataCorruptedException) { throw ex.asStatusRuntimeException(Status.Code.INTERNAL) } @@ -381,7 +345,7 @@ class PostgresComputationsService( computationReader = computationReader, ) .execute(client, idGenerator) - } catch (ex: UnknownDataError) { + } catch (ex: DataCorruptedException) { throw ex.asStatusRuntimeException(Status.Code.INTERNAL) } @@ -423,7 +387,7 @@ class PostgresComputationsService( computationReader = computationReader, ) .execute(client, idGenerator) - } catch (ex: UnknownDataError) { + } catch (ex: DataCorruptedException) { throw ex.asStatusRuntimeException(Status.Code.INTERNAL) } @@ -473,7 +437,7 @@ class PostgresComputationsService( computationReader = computationReader, ) .execute(client, idGenerator) - } catch (ex: UnknownDataError) { + } catch (ex: DataCorruptedException) { throw ex.asStatusRuntimeException(Status.Code.INTERNAL) } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt index cf076bb5835..3fa208a801a 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/readers/ComputationReader.kt @@ -219,7 +219,7 @@ class ComputationReader( } /** - * Reads a [ComputationToken] by globalComputationId. + * Reads a [ComputationToken] by globalComputationId in a new transaction. * * @param client The [DatabaseClient] to the Postgres database. * @param globalComputationId A global identifier for a computation. @@ -263,7 +263,7 @@ class ComputationReader( } /** - * Reads a [ComputationToken] by externalRequisitionKey. + * Reads a [ComputationToken] by externalRequisitionKey in a new transaction. * * @param client The [DatabaseClient] to the Postgres database. * @param externalRequisitionKey The [ExternalRequisitionKey] for a computation. @@ -374,6 +374,13 @@ class ComputationReader( return readContext.executeQuery(listUnclaimedTasksSql).consume(::buildUnclaimedTaskQueryResult) } + /** + * Reads the LockOwner and UpdateTime of a computation. + * + * @param readContext The transaction context for reading from the Postgres database. + * @param computationId The local identifier for a computation. + * @return [LockOwnerQueryResult] if computation is found, otherwise null. + */ suspend fun readLockOwner(readContext: ReadContext, computationId: Long): LockOwnerQueryResult? { val readLockOwnerSql = boundStatement( diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt index 69b0db28fd7..eb14f11861a 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt @@ -25,7 +25,7 @@ import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnum import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationBlobReferenceReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationStageAttemptReader -import org.wfanet.measurement.duchy.service.internal.UnknownDataError +import org.wfanet.measurement.duchy.service.internal.DataCorruptedException import org.wfanet.measurement.internal.duchy.ComputationBlobDependency import org.wfanet.measurement.internal.duchy.ComputationStageAttemptDetails import org.wfanet.measurement.internal.duchy.ComputationToken @@ -192,6 +192,6 @@ class AdvanceComputationStage( } return computationReader.readComputationToken(transactionContext, token.globalId) - ?: throw UnknownDataError() + ?: throw DataCorruptedException() } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt index ef1e784a356..ef8be8b0336 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt @@ -24,7 +24,7 @@ import org.wfanet.measurement.duchy.db.computation.ComputationTypeEnumHelper import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationStageAttemptReader import org.wfanet.measurement.duchy.service.internal.ComputationNotFoundException -import org.wfanet.measurement.duchy.service.internal.UnknownDataError +import org.wfanet.measurement.duchy.service.internal.DataCorruptedException import org.wfanet.measurement.internal.duchy.ComputationStageAttemptDetails import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.copy @@ -44,6 +44,7 @@ import org.wfanet.measurement.internal.duchy.copy * Throws following exceptions on [execute]: * * [ComputationNotFoundException] when computation could not be found * * [IllegalStateException] when computation details could not be found + * * [DataCorruptedException] when data is corrupted */ class ClaimWork( private val protocol: ProtocolT, @@ -68,7 +69,7 @@ class ClaimWork( .firstOrNull() ?.let { computationReader.readComputationToken(transactionContext, it.globalId) - ?: throw UnknownDataError() + ?: throw DataCorruptedException() } } @@ -81,7 +82,7 @@ class ClaimWork( ): Boolean { val currentLockOwner = computationReader.readLockOwner(transactionContext, unclaimedTask.computationId) - ?: throw UnknownDataError() + ?: throw DataCorruptedException() // Verify that the row hasn't been updated since the previous, non-transactional read. // If it has been updated since that time the lock should not be acquired. if (currentLockOwner.updateTime != unclaimedTask.updateTime) return false diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/CreateComputation.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/CreateComputation.kt index b9550443d45..a3e89ea4b1f 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/CreateComputation.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/CreateComputation.kt @@ -23,8 +23,8 @@ import org.wfanet.measurement.duchy.db.computation.ComputationTypeEnumHelper import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.service.internal.ComputationAlreadyExistsException import org.wfanet.measurement.duchy.service.internal.ComputationInitialStageInvalidException +import org.wfanet.measurement.duchy.service.internal.DataCorruptedException import org.wfanet.measurement.duchy.service.internal.DuchyInternalException -import org.wfanet.measurement.duchy.service.internal.UnknownDataError import org.wfanet.measurement.internal.duchy.ComputationDetails import org.wfanet.measurement.internal.duchy.ComputationStageDetails import org.wfanet.measurement.internal.duchy.ComputationToken @@ -113,6 +113,6 @@ class CreateComputation( ) return computationReader.readComputationToken(transactionContext, token.globalId) - ?: throw UnknownDataError() + ?: throw DataCorruptedException() } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/RecordRequisitionBlobPath.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/RecordRequisitionBlobPath.kt index dac8223d2d1..02e9b660d03 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/RecordRequisitionBlobPath.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/RecordRequisitionBlobPath.kt @@ -18,7 +18,7 @@ import java.time.Clock import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.RequisitionReader -import org.wfanet.measurement.duchy.service.internal.UnknownDataError +import org.wfanet.measurement.duchy.service.internal.DataCorruptedException import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.ExternalRequisitionKey @@ -60,6 +60,6 @@ class RecordRequisitionBlobPath( ) return computationReader.readComputationToken(transactionContext, externalRequisitionKey) - ?: throw UnknownDataError() + ?: throw DataCorruptedException() } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/UpdateComputationDetails.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/UpdateComputationDetails.kt index 6d9bc1be617..b8da9f9bd57 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/UpdateComputationDetails.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/UpdateComputationDetails.kt @@ -20,7 +20,7 @@ import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter import org.wfanet.measurement.duchy.db.computation.ComputationEditToken import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.RequisitionReader -import org.wfanet.measurement.duchy.service.internal.UnknownDataError +import org.wfanet.measurement.duchy.service.internal.DataCorruptedException import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.RequisitionEntry @@ -63,6 +63,6 @@ class UpdateComputationDetails( updateComputation(localId = token.localId, updateTime = writeTime, details = computationDetails) return computationReader.readComputationToken(transactionContext, token.globalId) - ?: throw UnknownDataError() + ?: throw DataCorruptedException() } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/server/ComputationsServer.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/server/ComputationsServer.kt index f496273c77c..f3f8a2b3498 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/server/ComputationsServer.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/server/ComputationsServer.kt @@ -84,8 +84,8 @@ abstract class ComputationsServer : Runnable { ComputationsService( computationsDatabase = computationsDatabase, computationLogEntriesClient = computationLogEntriesClient, - computationStorageClient = ComputationStore(storageClient), - requisitionStorageClient = RequisitionStore(storageClient), + computationStore = ComputationStore(storageClient), + requisitionStore = RequisitionStore(storageClient), duchyName = flags.duchy.duchyName ) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt index 7a50731b6b1..f12217508b4 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt @@ -102,9 +102,9 @@ class ComputationAlreadyExistsException( get() = mapOf("global_computation_id" to globalComputationId) } -class UnknownDataError( +class DataCorruptedException( message: String = "Data corrupted for unknown reasons", -) : DuchyInternalException(ErrorCode.UNKNOWN_DATA_ERROR, message) { +) : DuchyInternalException(ErrorCode.DATA_CORRUPTED, message) { override val context: Map get() = emptyMap() } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt index f962855e041..69197043ffd 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt @@ -16,7 +16,6 @@ package org.wfanet.measurement.duchy.service.internal.computations import com.google.protobuf.Empty import io.grpc.Status -import io.grpc.StatusException import java.time.Clock import java.time.Duration import java.util.logging.Level @@ -73,8 +72,8 @@ import org.wfanet.measurement.system.v1alpha.CreateComputationLogEntryRequest class ComputationsService( private val computationsDatabase: ComputationsDatabase, private val computationLogEntriesClient: ComputationLogEntriesCoroutineStub, - private val computationStorageClient: ComputationStore, - private val requisitionStorageClient: RequisitionStore, + private val computationStore: ComputationStore, + private val requisitionStore: RequisitionStore, private val duchyName: String, private val clock: Clock = Clock.systemUTC(), private val defaultLockDuration: Duration = Duration.ofMinutes(5), @@ -134,23 +133,11 @@ class ComputationsService( private suspend fun deleteComputation(localId: Long) { val computationBlobKeys = computationsDatabase.readComputationBlobKeys(localId) for (blobKey in computationBlobKeys) { - try { - computationStorageClient.get(blobKey)?.delete() - } catch (e: StatusException) { - if (e.status.code != Status.Code.NOT_FOUND) { - throw e - } - } + computationStore.get(blobKey)?.delete() } val requisitionBlobKeys = computationsDatabase.readRequisitionBlobKeys(localId) for (blobKey in requisitionBlobKeys) { - try { - requisitionStorageClient.get(blobKey)?.delete() - } catch (e: StatusException) { - if (e.status.code != Status.NOT_FOUND.code) { - throw e - } - } + requisitionStore.get(blobKey)?.delete() } computationsDatabase.deleteComputation(localId) } @@ -185,20 +172,6 @@ class ComputationsService( } for (globalId in globalIds) { val token = computationsDatabase.readComputationToken(globalId) ?: continue - if (!isTerminated(token)) { - computationsDatabase.endComputation( - token.toDatabaseEditToken(), - getEndingComputationStage(token), - EndComputationReason.FAILED, - token.computationDetails - ) - sendStatusUpdateToKingdom( - newCreateComputationLogEntryRequest( - token.globalComputationId, - getEndingComputationStage(token), - ) - ) - } deleteComputation(token.localComputationId) deleted += 1 } diff --git a/src/main/proto/wfa/measurement/internal/duchy/error_code.proto b/src/main/proto/wfa/measurement/internal/duchy/error_code.proto index 4ed1d7a1269..f65970465e4 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/error_code.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/error_code.proto @@ -41,5 +41,5 @@ enum ErrorCode { COMPUTATION_ALREADY_EXISTS = 6; /** Data corrupted for unknown reasons. */ - UNKNOWN_DATA_ERROR = 7; + DATA_CORRUPTED = 7; } From 205be8580753e5ad62881f9c949edc24fc51050d Mon Sep 17 00:00:00 2001 From: Yuhong Wang Date: Wed, 2 Aug 2023 15:18:40 +0000 Subject: [PATCH 07/12] feedbacks --- .../postgres/PostgresComputationsService.kt | 189 ++++++++---------- .../writers/AdvanceComputationStage.kt | 4 +- .../common/postgres/writers/ClaimWork.kt | 7 +- .../postgres/writers/ComputationMutations.kt | 44 ++++ .../postgres/writers/CreateComputation.kt | 4 +- .../postgres/writers/DeleteComputation.kt | 42 ++-- .../postgres/writers/FinishComputation.kt | 4 +- .../postgres/writers/RecordOutputBlobPath.kt | 10 +- .../writers/RecordRequisitionBlobPath.kt | 6 +- .../writers/UpdateComputationDetails.kt | 4 +- .../internal/DuchyInternalException.kt | 7 - 11 files changed, 163 insertions(+), 158 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt index f9adcfad175..7f6a6caf105 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt @@ -52,7 +52,6 @@ import org.wfanet.measurement.duchy.service.internal.ComputationAlreadyExistsExc import org.wfanet.measurement.duchy.service.internal.ComputationDetailsNotFoundException import org.wfanet.measurement.duchy.service.internal.ComputationInitialStageInvalidException import org.wfanet.measurement.duchy.service.internal.ComputationNotFoundException -import org.wfanet.measurement.duchy.service.internal.DataCorruptedException import org.wfanet.measurement.duchy.service.internal.computations.toAdvanceComputationStageResponse import org.wfanet.measurement.duchy.service.internal.computations.toClaimWorkResponse import org.wfanet.measurement.duchy.service.internal.computations.toCreateComputationResponse @@ -110,7 +109,10 @@ class PostgresComputationsService( ComputationProtocolStagesEnumHelper, private val computationProtocolStageDetailsHelper: ComputationProtocolStageDetailsHelper< - ComputationType, ComputationStage, ComputationStageDetails, ComputationDetails + ComputationType, + ComputationStage, + ComputationStageDetails, + ComputationDetails, >, private val client: DatabaseClient, private val idGenerator: IdGenerator, @@ -127,7 +129,7 @@ class PostgresComputationsService( private val requisitionReader = RequisitionReader() override suspend fun createComputation( - request: CreateComputationRequest + request: CreateComputationRequest, ): CreateComputationResponse { grpcRequire(request.globalComputationId.isNotEmpty()) { "global_computation_id is not specified." @@ -146,15 +148,13 @@ class PostgresComputationsService( computationTypeEnumHelper, protocolStagesEnumHelper, computationProtocolStageDetailsHelper, - computationReader + computationReader, ) .execute(client, idGenerator) } catch (ex: ComputationInitialStageInvalidException) { throw ex.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) } catch (ex: ComputationAlreadyExistsException) { throw ex.asStatusRuntimeException(Status.Code.ALREADY_EXISTS) - } catch (ex: DataCorruptedException) { - throw ex.asStatusRuntimeException(Status.Code.INTERNAL) } return token.toCreateComputationResponse() @@ -188,8 +188,8 @@ class PostgresComputationsService( newCreateComputationLogEntryRequest( claimedToken.globalComputationId, claimedToken.computationStage, - claimedToken.attempt.toLong() - ) + claimedToken.attempt.toLong(), + ), ) return claimedToken.toClaimWorkResponse() } @@ -198,7 +198,7 @@ class PostgresComputationsService( } override suspend fun getComputationToken( - request: GetComputationTokenRequest + request: GetComputationTokenRequest, ): GetComputationTokenResponse { val token = @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. @@ -218,7 +218,7 @@ class PostgresComputationsService( val computationBlobKeys = computationBlobReferenceReader.readComputationBlobKeys( client.singleUse(), - request.localComputationId + request.localComputationId, ) for (blobKey in computationBlobKeys) { computationStore.get(blobKey)?.delete() @@ -235,15 +235,15 @@ class PostgresComputationsService( } override suspend fun purgeComputations( - request: PurgeComputationsRequest + request: PurgeComputationsRequest, ): PurgeComputationsResponse { grpcRequire( request.stagesList.all { protocolStagesEnumHelper.validTerminalStage( protocolStagesEnumHelper.stageToProtocol(it), - it + it, ) - } + }, ) { "Requested stage list contains non terminal stage." } @@ -253,7 +253,7 @@ class PostgresComputationsService( computationReader.readGlobalComputationIds( client.singleUse(), request.stagesList, - request.updatedBefore.toInstant() + request.updatedBefore.toInstant(), ) if (!request.force) { return purgeComputationsResponse { @@ -262,9 +262,8 @@ class PostgresComputationsService( } } for (globalId in globalIds) { - val token = computationReader.readComputationToken(client, globalId) ?: continue - DeleteComputation(token.localComputationId).execute(client, idGenerator) - deleted += 1 + val numOfRowsDeleted = DeleteComputation(globalId).execute(client, idGenerator) + deleted += numOfRowsDeleted.toInt() } } catch (e: Exception) { logger.log(Level.WARNING, "Exception during Computations cleaning. $e") @@ -273,87 +272,75 @@ class PostgresComputationsService( } override suspend fun finishComputation( - request: FinishComputationRequest + request: FinishComputationRequest, ): FinishComputationResponse { val token = - try { - FinishComputation( - request.token.toDatabaseEditToken(), - endingStage = request.endingComputationStage, - endComputationReason = - when (request.reason) { - ComputationDetails.CompletedReason.SUCCEEDED -> EndComputationReason.SUCCEEDED - ComputationDetails.CompletedReason.FAILED -> EndComputationReason.FAILED - ComputationDetails.CompletedReason.CANCELED -> EndComputationReason.CANCELED - else -> error("Unknown CompletedReason ${request.reason}") - }, - computationDetails = request.token.computationDetails, - clock = clock, - protocolStagesEnumHelper = protocolStagesEnumHelper, - protocolStageDetailsHelper = computationProtocolStageDetailsHelper, - computationReader = computationReader, - ) - .execute(client, idGenerator) - } catch (ex: DataCorruptedException) { - throw ex.asStatusRuntimeException(Status.Code.INTERNAL) - } + FinishComputation( + request.token.toDatabaseEditToken(), + endingStage = request.endingComputationStage, + endComputationReason = + when (request.reason) { + ComputationDetails.CompletedReason.SUCCEEDED -> EndComputationReason.SUCCEEDED + ComputationDetails.CompletedReason.FAILED -> EndComputationReason.FAILED + ComputationDetails.CompletedReason.CANCELED -> EndComputationReason.CANCELED + else -> error("Unknown CompletedReason ${request.reason}") + }, + computationDetails = request.token.computationDetails, + clock = clock, + protocolStagesEnumHelper = protocolStagesEnumHelper, + protocolStageDetailsHelper = computationProtocolStageDetailsHelper, + computationReader = computationReader, + ) + .execute(client, idGenerator) sendStatusUpdateToKingdom( newCreateComputationLogEntryRequest( request.token.globalComputationId, - request.endingComputationStage - ) + request.endingComputationStage, + ), ) return token.toFinishComputationResponse() } override suspend fun updateComputationDetails( - request: UpdateComputationDetailsRequest + request: UpdateComputationDetailsRequest, ): UpdateComputationDetailsResponse { require(request.token.computationDetails.protocolCase == request.details.protocolCase) { "The protocol type cannot change." } val token = - try { - UpdateComputationDetails( - token = request.token.toDatabaseEditToken(), - clock = clock, - computationDetails = request.details, - requisitionEntries = request.requisitionsList, - computationReader = computationReader - ) - .execute(client, idGenerator) - } catch (ex: DataCorruptedException) { - throw ex.asStatusRuntimeException(Status.Code.INTERNAL) - } + UpdateComputationDetails( + token = request.token.toDatabaseEditToken(), + clock = clock, + computationDetails = request.details, + requisitionEntries = request.requisitionsList, + computationReader = computationReader, + ) + .execute(client, idGenerator) return token.toUpdateComputationDetailsResponse() } override suspend fun recordOutputBlobPath( - request: RecordOutputBlobPathRequest + request: RecordOutputBlobPathRequest, ): RecordOutputBlobPathResponse { val token = - try { - RecordOutputBlobPath( - token = request.token.toDatabaseEditToken(), - clock = clock, - blobRef = BlobRef(request.outputBlobId, request.blobPath), - protocolStagesEnumHelper = protocolStagesEnumHelper, - computationReader = computationReader, - ) - .execute(client, idGenerator) - } catch (ex: DataCorruptedException) { - throw ex.asStatusRuntimeException(Status.Code.INTERNAL) - } + RecordOutputBlobPath( + token = request.token.toDatabaseEditToken(), + clock = clock, + blobRef = BlobRef(request.outputBlobId, request.blobPath), + protocolStagesEnumHelper = protocolStagesEnumHelper, + computationReader = computationReader, + ) + .execute(client, idGenerator) return token.toRecordOutputBlobPathResponse() } override suspend fun advanceComputationStage( - request: AdvanceComputationStageRequest + request: AdvanceComputationStageRequest, ): AdvanceComputationStageResponse { val lockExtension: Duration = if (request.hasLockExtension()) request.lockExtension.toDuration() else defaultLockDuration @@ -367,49 +354,45 @@ class PostgresComputationsService( AfterTransition.CONTINUE_WORKING else -> error( - "Unsupported AdvanceComputationStageRequest.AfterTransition '${request.afterTransition}'. " + "Unsupported AdvanceComputationStageRequest.AfterTransition '${request.afterTransition}'. ", ) } val token = - try { - AdvanceComputationStage( - request.token.toDatabaseEditToken(), - nextStage = request.nextComputationStage, - nextStageDetails = request.stageDetails, - inputBlobPaths = request.inputBlobsList, - passThroughBlobPaths = request.passThroughBlobsList, - outputBlobs = request.outputBlobs, - afterTransition = afterTransition, - lockExtension = lockExtension, - clock = clock, - protocolStagesEnumHelper = protocolStagesEnumHelper, - computationReader = computationReader, - ) - .execute(client, idGenerator) - } catch (ex: DataCorruptedException) { - throw ex.asStatusRuntimeException(Status.Code.INTERNAL) - } + AdvanceComputationStage( + request.token.toDatabaseEditToken(), + nextStage = request.nextComputationStage, + nextStageDetails = request.stageDetails, + inputBlobPaths = request.inputBlobsList, + passThroughBlobPaths = request.passThroughBlobsList, + outputBlobs = request.outputBlobs, + afterTransition = afterTransition, + lockExtension = lockExtension, + clock = clock, + protocolStagesEnumHelper = protocolStagesEnumHelper, + computationReader = computationReader, + ) + .execute(client, idGenerator) sendStatusUpdateToKingdom( newCreateComputationLogEntryRequest( request.token.globalComputationId, - request.nextComputationStage - ) + request.nextComputationStage, + ), ) return token.toAdvanceComputationStageResponse() } override suspend fun getComputationIds( - request: GetComputationIdsRequest + request: GetComputationIdsRequest, ): GetComputationIdsResponse { val ids = computationReader.readGlobalComputationIds(client.singleUse(), request.stagesList) return getComputationIdsResponse { globalIds += ids } } override suspend fun enqueueComputation( - request: EnqueueComputationRequest + request: EnqueueComputationRequest, ): EnqueueComputationResponse { grpcRequire(request.delaySecond >= 0) { "DelaySecond ${request.delaySecond} should be non-negative." @@ -425,21 +408,17 @@ class PostgresComputationsService( } override suspend fun recordRequisitionBlobPath( - request: RecordRequisitionBlobPathRequest + request: RecordRequisitionBlobPathRequest, ): RecordRequisitionBlobPathResponse { val token = - try { - RecordRequisitionBlobPath( - clock = clock, - localId = request.token.localComputationId, - externalRequisitionKey = request.key, - pathToBlob = request.blobPath, - computationReader = computationReader, - ) - .execute(client, idGenerator) - } catch (ex: DataCorruptedException) { - throw ex.asStatusRuntimeException(Status.Code.INTERNAL) - } + RecordRequisitionBlobPath( + clock = clock, + localId = request.token.localComputationId, + externalRequisitionKey = request.key, + pathToBlob = request.blobPath, + computationReader = computationReader, + ) + .execute(client, idGenerator) return token.toRecordRequisitionBlobPathResponse() } @@ -447,7 +426,7 @@ class PostgresComputationsService( private fun newCreateComputationLogEntryRequest( globalId: String, computationStage: ComputationStage, - attempt: Long = 0L + attempt: Long = 0L, ): CreateComputationLogEntryRequest { return createComputationLogEntryRequest { parent = ComputationParticipantKey(globalId, duchyName).toName() diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt index eb14f11861a..34e787da1e1 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/AdvanceComputationStage.kt @@ -25,7 +25,6 @@ import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnum import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationBlobReferenceReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationStageAttemptReader -import org.wfanet.measurement.duchy.service.internal.DataCorruptedException import org.wfanet.measurement.internal.duchy.ComputationBlobDependency import org.wfanet.measurement.internal.duchy.ComputationStageAttemptDetails import org.wfanet.measurement.internal.duchy.ComputationToken @@ -191,7 +190,6 @@ class AdvanceComputationStage( ) } - return computationReader.readComputationToken(transactionContext, token.globalId) - ?: throw DataCorruptedException() + return checkNotNull(computationReader.readComputationToken(transactionContext, token.globalId)) } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt index ef8be8b0336..e0b781b5718 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ClaimWork.kt @@ -24,7 +24,6 @@ import org.wfanet.measurement.duchy.db.computation.ComputationTypeEnumHelper import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationStageAttemptReader import org.wfanet.measurement.duchy.service.internal.ComputationNotFoundException -import org.wfanet.measurement.duchy.service.internal.DataCorruptedException import org.wfanet.measurement.internal.duchy.ComputationStageAttemptDetails import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.copy @@ -68,8 +67,7 @@ class ClaimWork( // If the value is null, no tasks were claimed. .firstOrNull() ?.let { - computationReader.readComputationToken(transactionContext, it.globalId) - ?: throw DataCorruptedException() + checkNotNull(computationReader.readComputationToken(transactionContext, it.globalId)) } } @@ -81,8 +79,7 @@ class ClaimWork( unclaimedTask: ComputationReader.UnclaimedTaskQueryResult ): Boolean { val currentLockOwner = - computationReader.readLockOwner(transactionContext, unclaimedTask.computationId) - ?: throw DataCorruptedException() + checkNotNull(computationReader.readLockOwner(transactionContext, unclaimedTask.computationId)) // Verify that the row hasn't been updated since the previous, non-transactional read. // If it has been updated since that time the lock should not be acquired. if (currentLockOwner.updateTime != unclaimedTask.updateTime) return false diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ComputationMutations.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ComputationMutations.kt index a4f304a472a..01bfb167813 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ComputationMutations.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/ComputationMutations.kt @@ -240,3 +240,47 @@ suspend fun PostgresWriter.TransactionScope.checkComputationUnmodified( ) } } + +/** + * Deletes a computation by local identifier + * + * @param localId local identifier of a computation + * @return number of rows deleted + */ +suspend fun PostgresWriter.TransactionScope.deleteComputationByLocalId( + localId: Long, +): Long { + val sql = + boundStatement( + """ + DELETE FROM Computations + WHERE ComputationId = $1 + """ + .trimIndent() + ) { + bind("$1", localId) + } + return transactionContext.executeStatement(sql).numRowsUpdated +} + +/** + * Deletes a computation by local identifier + * + * @param globalId global identifier of a computation + * @return number of rows deleted + */ +suspend fun PostgresWriter.TransactionScope.deleteComputationByGlobalId( + globalId: String, +): Long { + val sql = + boundStatement( + """ + DELETE FROM Computations + WHERE GlobalComputationId = $1 + """ + .trimIndent() + ) { + bind("$1", globalId) + } + return transactionContext.executeStatement(sql).numRowsUpdated +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/CreateComputation.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/CreateComputation.kt index a3e89ea4b1f..4f1a236a784 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/CreateComputation.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/CreateComputation.kt @@ -23,7 +23,6 @@ import org.wfanet.measurement.duchy.db.computation.ComputationTypeEnumHelper import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.service.internal.ComputationAlreadyExistsException import org.wfanet.measurement.duchy.service.internal.ComputationInitialStageInvalidException -import org.wfanet.measurement.duchy.service.internal.DataCorruptedException import org.wfanet.measurement.duchy.service.internal.DuchyInternalException import org.wfanet.measurement.internal.duchy.ComputationDetails import org.wfanet.measurement.internal.duchy.ComputationStageDetails @@ -112,7 +111,6 @@ class CreateComputation() { - - override suspend fun TransactionScope.runTransaction() { - val statement = - boundStatement( - """ - DELETE FROM Computations - WHERE ComputationId = $1 - """ - .trimIndent() - ) { - bind("$1", localId) - } - transactionContext.executeStatement(statement) +/** [PostgresWriter] to delete a computation by its localComputationId or globalComputationId. */ +class DeleteComputation : PostgresWriter { + + private var globalId: String? = null + private var localId: Long? = null + + constructor(globalId: String) : super() { + this.globalId = globalId + } + + constructor(localId: Long) : super() { + this.localId = localId + } + + override suspend fun TransactionScope.runTransaction(): Long { + if (localId != null) { + return deleteComputationByLocalId(localId!!) + } + + if (globalId != null) { + return deleteComputationByGlobalId(globalId!!) + } } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/FinishComputation.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/FinishComputation.kt index 6359770471c..cfe49a20d5c 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/FinishComputation.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/FinishComputation.kt @@ -24,7 +24,6 @@ import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStagesEnum import org.wfanet.measurement.duchy.db.computation.EndComputationReason import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationStageAttemptReader -import org.wfanet.measurement.duchy.service.internal.DataCorruptedException import org.wfanet.measurement.internal.duchy.ComputationStageAttemptDetails import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.copy @@ -137,8 +136,7 @@ class FinishComputation( transactionContext, localId, stageLongValue, - blobRef.idInRelationalDatabase + blobRef.idInRelationalDatabase, ) ?: error( "No ComputationBlobReferences row for " + - "($localId, $stage, ${blobRef.idInRelationalDatabase})" + "($localId, $stage, ${blobRef.idInRelationalDatabase})", ) require(type == ComputationBlobDependency.OUTPUT) { "Cannot write to $type blob" } @@ -70,10 +69,9 @@ class RecordOutputBlobPath( localId = localId, stage = stageLongValue, blobId = blobRef.idInRelationalDatabase, - pathToBlob = blobRef.key + pathToBlob = blobRef.key, ) - return computationReader.readComputationToken(transactionContext, token.globalId) - ?: throw DataCorruptedException() + return checkNotNull(computationReader.readComputationToken(transactionContext, token.globalId)) } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/RecordRequisitionBlobPath.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/RecordRequisitionBlobPath.kt index 02e9b660d03..2421427b732 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/RecordRequisitionBlobPath.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/RecordRequisitionBlobPath.kt @@ -18,7 +18,6 @@ import java.time.Clock import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.RequisitionReader -import org.wfanet.measurement.duchy.service.internal.DataCorruptedException import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.ExternalRequisitionKey @@ -59,7 +58,8 @@ class RecordRequisitionBlobPath( updateTime = writeTime ) - return computationReader.readComputationToken(transactionContext, externalRequisitionKey) - ?: throw DataCorruptedException() + return checkNotNull( + computationReader.readComputationToken(transactionContext, externalRequisitionKey) + ) } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/UpdateComputationDetails.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/UpdateComputationDetails.kt index b8da9f9bd57..7af43e26410 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/UpdateComputationDetails.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/UpdateComputationDetails.kt @@ -20,7 +20,6 @@ import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter import org.wfanet.measurement.duchy.db.computation.ComputationEditToken import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader import org.wfanet.measurement.duchy.deploy.common.postgres.readers.RequisitionReader -import org.wfanet.measurement.duchy.service.internal.DataCorruptedException import org.wfanet.measurement.internal.duchy.ComputationToken import org.wfanet.measurement.internal.duchy.RequisitionEntry @@ -62,7 +61,6 @@ class UpdateComputationDetails( } updateComputation(localId = token.localId, updateTime = writeTime, details = computationDetails) - return computationReader.readComputationToken(transactionContext, token.globalId) - ?: throw DataCorruptedException() + return checkNotNull(computationReader.readComputationToken(transactionContext, token.globalId)) } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt index f12217508b4..42e43bca799 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt @@ -101,10 +101,3 @@ class ComputationAlreadyExistsException( override val context get() = mapOf("global_computation_id" to globalComputationId) } - -class DataCorruptedException( - message: String = "Data corrupted for unknown reasons", -) : DuchyInternalException(ErrorCode.DATA_CORRUPTED, message) { - override val context: Map - get() = emptyMap() -} From 3b74c02322307381c5159a6057a4bd86b0a57da3 Mon Sep 17 00:00:00 2001 From: Yuhong Wang Date: Wed, 2 Aug 2023 16:42:57 +0000 Subject: [PATCH 08/12] lint --- .../duchy/service/internal/computations/ComputationsService.kt | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt index 0d21c8edb16..b1e3fca2ab2 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/computations/ComputationsService.kt @@ -33,7 +33,6 @@ import org.wfanet.measurement.duchy.name import org.wfanet.measurement.duchy.number import org.wfanet.measurement.duchy.storage.ComputationStore import org.wfanet.measurement.duchy.storage.RequisitionStore -import org.wfanet.measurement.duchy.toProtocolStage import org.wfanet.measurement.internal.duchy.AdvanceComputationStageRequest import org.wfanet.measurement.internal.duchy.AdvanceComputationStageResponse import org.wfanet.measurement.internal.duchy.ClaimWorkRequest From 94a9c4bcf8f84b26d98a8d49a52c6d5fdad1dc5e Mon Sep 17 00:00:00 2001 From: Yuhong Wang Date: Wed, 2 Aug 2023 23:19:47 +0000 Subject: [PATCH 09/12] update purge into a transaction --- .../postgres/PostgresComputationsService.kt | 190 +++++++++--------- .../postgres/writers/DeleteComputation.kt | 25 +-- .../postgres/writers/PurgeComputations.kt | 50 +++++ 3 files changed, 146 insertions(+), 119 deletions(-) create mode 100644 src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/PurgeComputations.kt diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt index c80ec74f7f1..2397c7e5012 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt @@ -43,6 +43,7 @@ import org.wfanet.measurement.duchy.deploy.common.postgres.writers.CreateComputa import org.wfanet.measurement.duchy.deploy.common.postgres.writers.DeleteComputation import org.wfanet.measurement.duchy.deploy.common.postgres.writers.EnqueueComputation import org.wfanet.measurement.duchy.deploy.common.postgres.writers.FinishComputation +import org.wfanet.measurement.duchy.deploy.common.postgres.writers.PurgeComputations import org.wfanet.measurement.duchy.deploy.common.postgres.writers.RecordOutputBlobPath import org.wfanet.measurement.duchy.deploy.common.postgres.writers.RecordRequisitionBlobPath import org.wfanet.measurement.duchy.deploy.common.postgres.writers.UpdateComputationDetails @@ -104,13 +105,13 @@ import org.wfanet.measurement.system.v1alpha.stageAttempt class PostgresComputationsService( private val computationTypeEnumHelper: ComputationTypeEnumHelper, private val protocolStagesEnumHelper: - ComputationProtocolStagesEnumHelper, + ComputationProtocolStagesEnumHelper, private val computationProtocolStageDetailsHelper: - ComputationProtocolStageDetailsHelper< - ComputationType, - ComputationStage, - ComputationStageDetails, - ComputationDetails, + ComputationProtocolStageDetailsHelper< + ComputationType, + ComputationStage, + ComputationStageDetails, + ComputationDetails, >, private val client: DatabaseClient, private val idGenerator: IdGenerator, @@ -136,18 +137,18 @@ class PostgresComputationsService( val token = try { CreateComputation( - request.globalComputationId, - request.computationType, - protocolStagesEnumHelper.getValidInitialStage(request.computationType).first(), - request.stageDetails, - request.computationDetails, - request.requisitionsList, - clock, - computationTypeEnumHelper, - protocolStagesEnumHelper, - computationProtocolStageDetailsHelper, - computationReader, - ) + request.globalComputationId, + request.computationType, + protocolStagesEnumHelper.getValidInitialStage(request.computationType).first(), + request.stageDetails, + request.computationDetails, + request.requisitionsList, + clock, + computationTypeEnumHelper, + protocolStagesEnumHelper, + computationProtocolStageDetailsHelper, + computationReader, + ) .execute(client, idGenerator) } catch (ex: ComputationInitialStageInvalidException) { throw ex.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) @@ -166,14 +167,14 @@ class PostgresComputationsService( val claimedToken = try { ClaimWork( - request.computationType, - request.owner, - lockDuration, - clock, - computationTypeEnumHelper, - protocolStagesEnumHelper, - computationReader, - ) + request.computationType, + request.owner, + lockDuration, + clock, + computationTypeEnumHelper, + protocolStagesEnumHelper, + computationReader, + ) .execute(client, idGenerator) } catch (e: ComputationNotFoundException) { throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) @@ -203,8 +204,10 @@ class PostgresComputationsService( when (request.keyCase) { KeyCase.GLOBAL_COMPUTATION_ID -> computationReader.readComputationToken(client, request.globalComputationId) + KeyCase.REQUISITION_KEY -> computationReader.readComputationToken(client, request.requisitionKey) + KeyCase.KEY_NOT_SET -> failGrpc(Status.INVALID_ARGUMENT) { "key not set" } } ?: failGrpc(Status.NOT_FOUND) { "Computation not found" } @@ -245,28 +248,18 @@ class PostgresComputationsService( ) { "Requested stage list contains non terminal stage." } - var deleted = 0 - try { - val globalIds: Set = - computationReader.readGlobalComputationIds( - client.singleUse(), - request.stagesList, - request.updatedBefore.toInstant(), - ) - if (!request.force) { - return purgeComputationsResponse { - purgeCount = globalIds.size - purgeSample += globalIds - } - } - for (globalId in globalIds) { - val numOfRowsDeleted = DeleteComputation(globalId).execute(client, idGenerator) - deleted += numOfRowsDeleted.toInt() - } - } catch (e: Exception) { - logger.log(Level.WARNING, "Exception during Computations cleaning. $e") + val purgeResult = + PurgeComputations( + request.stagesList, + request.updatedBefore.toInstant(), + request.force, + computationReader + ).execute(client, idGenerator) + + return purgeComputationsResponse { + purgeCount = purgeResult.purgeCount + purgeResult.purgeSamples?.forEach { purgeSample += it } } - return purgeComputationsResponse { this.purgeCount = deleted } } override suspend fun finishComputation( @@ -274,21 +267,21 @@ class PostgresComputationsService( ): FinishComputationResponse { val token = FinishComputation( - request.token.toDatabaseEditToken(), - endingStage = request.endingComputationStage, - endComputationReason = - when (request.reason) { - ComputationDetails.CompletedReason.SUCCEEDED -> EndComputationReason.SUCCEEDED - ComputationDetails.CompletedReason.FAILED -> EndComputationReason.FAILED - ComputationDetails.CompletedReason.CANCELED -> EndComputationReason.CANCELED - else -> error("Unknown CompletedReason ${request.reason}") - }, - computationDetails = request.token.computationDetails, - clock = clock, - protocolStagesEnumHelper = protocolStagesEnumHelper, - protocolStageDetailsHelper = computationProtocolStageDetailsHelper, - computationReader = computationReader, - ) + request.token.toDatabaseEditToken(), + endingStage = request.endingComputationStage, + endComputationReason = + when (request.reason) { + ComputationDetails.CompletedReason.SUCCEEDED -> EndComputationReason.SUCCEEDED + ComputationDetails.CompletedReason.FAILED -> EndComputationReason.FAILED + ComputationDetails.CompletedReason.CANCELED -> EndComputationReason.CANCELED + else -> error("Unknown CompletedReason ${request.reason}") + }, + computationDetails = request.token.computationDetails, + clock = clock, + protocolStagesEnumHelper = protocolStagesEnumHelper, + protocolStageDetailsHelper = computationProtocolStageDetailsHelper, + computationReader = computationReader, + ) .execute(client, idGenerator) sendStatusUpdateToKingdom( @@ -310,12 +303,12 @@ class PostgresComputationsService( val token = UpdateComputationDetails( - token = request.token.toDatabaseEditToken(), - clock = clock, - computationDetails = request.details, - requisitionEntries = request.requisitionsList, - computationReader = computationReader, - ) + token = request.token.toDatabaseEditToken(), + clock = clock, + computationDetails = request.details, + requisitionEntries = request.requisitionsList, + computationReader = computationReader, + ) .execute(client, idGenerator) return token.toUpdateComputationDetailsResponse() @@ -326,12 +319,12 @@ class PostgresComputationsService( ): RecordOutputBlobPathResponse { val token = RecordOutputBlobPath( - token = request.token.toDatabaseEditToken(), - clock = clock, - blobRef = BlobRef(request.outputBlobId, request.blobPath), - protocolStagesEnumHelper = protocolStagesEnumHelper, - computationReader = computationReader, - ) + token = request.token.toDatabaseEditToken(), + clock = clock, + blobRef = BlobRef(request.outputBlobId, request.blobPath), + protocolStagesEnumHelper = protocolStagesEnumHelper, + computationReader = computationReader, + ) .execute(client, idGenerator) return token.toRecordOutputBlobPathResponse() @@ -346,10 +339,13 @@ class PostgresComputationsService( when (request.afterTransition) { AdvanceComputationStageRequest.AfterTransition.ADD_UNCLAIMED_TO_QUEUE -> AfterTransition.ADD_UNCLAIMED_TO_QUEUE + AdvanceComputationStageRequest.AfterTransition.DO_NOT_ADD_TO_QUEUE -> AfterTransition.DO_NOT_ADD_TO_QUEUE + AdvanceComputationStageRequest.AfterTransition.RETAIN_AND_EXTEND_LOCK -> AfterTransition.CONTINUE_WORKING + else -> error( "Unsupported AdvanceComputationStageRequest.AfterTransition '${request.afterTransition}'. ", @@ -358,18 +354,18 @@ class PostgresComputationsService( val token = AdvanceComputationStage( - request.token.toDatabaseEditToken(), - nextStage = request.nextComputationStage, - nextStageDetails = request.stageDetails, - inputBlobPaths = request.inputBlobsList, - passThroughBlobPaths = request.passThroughBlobsList, - outputBlobs = request.outputBlobs, - afterTransition = afterTransition, - lockExtension = lockExtension, - clock = clock, - protocolStagesEnumHelper = protocolStagesEnumHelper, - computationReader = computationReader, - ) + request.token.toDatabaseEditToken(), + nextStage = request.nextComputationStage, + nextStageDetails = request.stageDetails, + inputBlobPaths = request.inputBlobsList, + passThroughBlobPaths = request.passThroughBlobsList, + outputBlobs = request.outputBlobs, + afterTransition = afterTransition, + lockExtension = lockExtension, + clock = clock, + protocolStagesEnumHelper = protocolStagesEnumHelper, + computationReader = computationReader, + ) .execute(client, idGenerator) sendStatusUpdateToKingdom( @@ -396,11 +392,11 @@ class PostgresComputationsService( "DelaySecond ${request.delaySecond} should be non-negative." } EnqueueComputation( - request.token.localComputationId, - request.token.version, - request.delaySecond.toLong(), - clock, - ) + request.token.localComputationId, + request.token.version, + request.delaySecond.toLong(), + clock, + ) .execute(client, idGenerator) return EnqueueComputationResponse.getDefaultInstance() } @@ -410,12 +406,12 @@ class PostgresComputationsService( ): RecordRequisitionBlobPathResponse { val token = RecordRequisitionBlobPath( - clock = clock, - localId = request.token.localComputationId, - externalRequisitionKey = request.key, - pathToBlob = request.blobPath, - computationReader = computationReader, - ) + clock = clock, + localId = request.token.localComputationId, + externalRequisitionKey = request.key, + pathToBlob = request.blobPath, + computationReader = computationReader, + ) .execute(client, idGenerator) return token.toRecordRequisitionBlobPathResponse() diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/DeleteComputation.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/DeleteComputation.kt index c719857b7a5..797aca39f90 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/DeleteComputation.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/DeleteComputation.kt @@ -16,29 +16,10 @@ package org.wfanet.measurement.duchy.deploy.common.postgres.writers import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter -/** [PostgresWriter] to delete a computation by its localComputationId or globalComputationId. */ -class DeleteComputation : PostgresWriter { - - private var globalId: String? = null - private var localId: Long? = null - - constructor(globalId: String) : super() { - this.globalId = globalId - } - - constructor(localId: Long) : super() { - this.localId = localId - } +/** [PostgresWriter] to delete a computation by its localComputationId. */ +class DeleteComputation(private val localId: Long) : PostgresWriter() { override suspend fun TransactionScope.runTransaction(): Long { - if (localId != null) { - return deleteComputationByLocalId(localId!!) - } - - if (globalId != null) { - return deleteComputationByGlobalId(globalId!!) - } - - return 0 + return deleteComputationByLocalId(localId) } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/PurgeComputations.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/PurgeComputations.kt new file mode 100644 index 00000000000..c84ed15ad13 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/PurgeComputations.kt @@ -0,0 +1,50 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.deploy.common.postgres.writers + +import java.time.Instant +import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter +import org.wfanet.measurement.duchy.deploy.common.postgres.readers.ComputationReader +import org.wfanet.measurement.internal.duchy.ComputationStage + +class PurgeComputations( + private val stages: List, + private val updatedBefore: Instant, + private val force: Boolean, + private val computationReader: ComputationReader +) : PostgresWriter() { + data class PurgeResult( + val purgeCount: Int, + val purgeSamples: Set? = null, + ) + + override suspend fun TransactionScope.runTransaction(): PurgeResult { + var deleted = 0 + val globalIds: Set = + computationReader.readGlobalComputationIds( + transactionContext, + stages, + updatedBefore, + ) + if (!force) { + return PurgeResult(globalIds.size, globalIds) + } + for (globalId in globalIds) { + val numOfRowsDeleted = deleteComputationByGlobalId(globalId) + deleted += numOfRowsDeleted.toInt() + } + return PurgeResult(deleted) + } +} From 1fbe58155920cd68881ed09834c6c0ea2b30d82d Mon Sep 17 00:00:00 2001 From: Yuhong Wang Date: Wed, 2 Aug 2023 23:22:40 +0000 Subject: [PATCH 10/12] lint --- .../postgres/PostgresComputationsService.kt | 168 +++++++++--------- 1 file changed, 82 insertions(+), 86 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt index 2397c7e5012..7c0854674ae 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt @@ -105,13 +105,13 @@ import org.wfanet.measurement.system.v1alpha.stageAttempt class PostgresComputationsService( private val computationTypeEnumHelper: ComputationTypeEnumHelper, private val protocolStagesEnumHelper: - ComputationProtocolStagesEnumHelper, + ComputationProtocolStagesEnumHelper, private val computationProtocolStageDetailsHelper: - ComputationProtocolStageDetailsHelper< - ComputationType, - ComputationStage, - ComputationStageDetails, - ComputationDetails, + ComputationProtocolStageDetailsHelper< + ComputationType, + ComputationStage, + ComputationStageDetails, + ComputationDetails, >, private val client: DatabaseClient, private val idGenerator: IdGenerator, @@ -137,18 +137,18 @@ class PostgresComputationsService( val token = try { CreateComputation( - request.globalComputationId, - request.computationType, - protocolStagesEnumHelper.getValidInitialStage(request.computationType).first(), - request.stageDetails, - request.computationDetails, - request.requisitionsList, - clock, - computationTypeEnumHelper, - protocolStagesEnumHelper, - computationProtocolStageDetailsHelper, - computationReader, - ) + request.globalComputationId, + request.computationType, + protocolStagesEnumHelper.getValidInitialStage(request.computationType).first(), + request.stageDetails, + request.computationDetails, + request.requisitionsList, + clock, + computationTypeEnumHelper, + protocolStagesEnumHelper, + computationProtocolStageDetailsHelper, + computationReader, + ) .execute(client, idGenerator) } catch (ex: ComputationInitialStageInvalidException) { throw ex.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) @@ -167,14 +167,14 @@ class PostgresComputationsService( val claimedToken = try { ClaimWork( - request.computationType, - request.owner, - lockDuration, - clock, - computationTypeEnumHelper, - protocolStagesEnumHelper, - computationReader, - ) + request.computationType, + request.owner, + lockDuration, + clock, + computationTypeEnumHelper, + protocolStagesEnumHelper, + computationReader, + ) .execute(client, idGenerator) } catch (e: ComputationNotFoundException) { throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) @@ -204,10 +204,8 @@ class PostgresComputationsService( when (request.keyCase) { KeyCase.GLOBAL_COMPUTATION_ID -> computationReader.readComputationToken(client, request.globalComputationId) - KeyCase.REQUISITION_KEY -> computationReader.readComputationToken(client, request.requisitionKey) - KeyCase.KEY_NOT_SET -> failGrpc(Status.INVALID_ARGUMENT) { "key not set" } } ?: failGrpc(Status.NOT_FOUND) { "Computation not found" } @@ -250,11 +248,12 @@ class PostgresComputationsService( } val purgeResult = PurgeComputations( - request.stagesList, - request.updatedBefore.toInstant(), - request.force, - computationReader - ).execute(client, idGenerator) + request.stagesList, + request.updatedBefore.toInstant(), + request.force, + computationReader + ) + .execute(client, idGenerator) return purgeComputationsResponse { purgeCount = purgeResult.purgeCount @@ -267,21 +266,21 @@ class PostgresComputationsService( ): FinishComputationResponse { val token = FinishComputation( - request.token.toDatabaseEditToken(), - endingStage = request.endingComputationStage, - endComputationReason = - when (request.reason) { - ComputationDetails.CompletedReason.SUCCEEDED -> EndComputationReason.SUCCEEDED - ComputationDetails.CompletedReason.FAILED -> EndComputationReason.FAILED - ComputationDetails.CompletedReason.CANCELED -> EndComputationReason.CANCELED - else -> error("Unknown CompletedReason ${request.reason}") - }, - computationDetails = request.token.computationDetails, - clock = clock, - protocolStagesEnumHelper = protocolStagesEnumHelper, - protocolStageDetailsHelper = computationProtocolStageDetailsHelper, - computationReader = computationReader, - ) + request.token.toDatabaseEditToken(), + endingStage = request.endingComputationStage, + endComputationReason = + when (request.reason) { + ComputationDetails.CompletedReason.SUCCEEDED -> EndComputationReason.SUCCEEDED + ComputationDetails.CompletedReason.FAILED -> EndComputationReason.FAILED + ComputationDetails.CompletedReason.CANCELED -> EndComputationReason.CANCELED + else -> error("Unknown CompletedReason ${request.reason}") + }, + computationDetails = request.token.computationDetails, + clock = clock, + protocolStagesEnumHelper = protocolStagesEnumHelper, + protocolStageDetailsHelper = computationProtocolStageDetailsHelper, + computationReader = computationReader, + ) .execute(client, idGenerator) sendStatusUpdateToKingdom( @@ -303,12 +302,12 @@ class PostgresComputationsService( val token = UpdateComputationDetails( - token = request.token.toDatabaseEditToken(), - clock = clock, - computationDetails = request.details, - requisitionEntries = request.requisitionsList, - computationReader = computationReader, - ) + token = request.token.toDatabaseEditToken(), + clock = clock, + computationDetails = request.details, + requisitionEntries = request.requisitionsList, + computationReader = computationReader, + ) .execute(client, idGenerator) return token.toUpdateComputationDetailsResponse() @@ -319,12 +318,12 @@ class PostgresComputationsService( ): RecordOutputBlobPathResponse { val token = RecordOutputBlobPath( - token = request.token.toDatabaseEditToken(), - clock = clock, - blobRef = BlobRef(request.outputBlobId, request.blobPath), - protocolStagesEnumHelper = protocolStagesEnumHelper, - computationReader = computationReader, - ) + token = request.token.toDatabaseEditToken(), + clock = clock, + blobRef = BlobRef(request.outputBlobId, request.blobPath), + protocolStagesEnumHelper = protocolStagesEnumHelper, + computationReader = computationReader, + ) .execute(client, idGenerator) return token.toRecordOutputBlobPathResponse() @@ -339,13 +338,10 @@ class PostgresComputationsService( when (request.afterTransition) { AdvanceComputationStageRequest.AfterTransition.ADD_UNCLAIMED_TO_QUEUE -> AfterTransition.ADD_UNCLAIMED_TO_QUEUE - AdvanceComputationStageRequest.AfterTransition.DO_NOT_ADD_TO_QUEUE -> AfterTransition.DO_NOT_ADD_TO_QUEUE - AdvanceComputationStageRequest.AfterTransition.RETAIN_AND_EXTEND_LOCK -> AfterTransition.CONTINUE_WORKING - else -> error( "Unsupported AdvanceComputationStageRequest.AfterTransition '${request.afterTransition}'. ", @@ -354,18 +350,18 @@ class PostgresComputationsService( val token = AdvanceComputationStage( - request.token.toDatabaseEditToken(), - nextStage = request.nextComputationStage, - nextStageDetails = request.stageDetails, - inputBlobPaths = request.inputBlobsList, - passThroughBlobPaths = request.passThroughBlobsList, - outputBlobs = request.outputBlobs, - afterTransition = afterTransition, - lockExtension = lockExtension, - clock = clock, - protocolStagesEnumHelper = protocolStagesEnumHelper, - computationReader = computationReader, - ) + request.token.toDatabaseEditToken(), + nextStage = request.nextComputationStage, + nextStageDetails = request.stageDetails, + inputBlobPaths = request.inputBlobsList, + passThroughBlobPaths = request.passThroughBlobsList, + outputBlobs = request.outputBlobs, + afterTransition = afterTransition, + lockExtension = lockExtension, + clock = clock, + protocolStagesEnumHelper = protocolStagesEnumHelper, + computationReader = computationReader, + ) .execute(client, idGenerator) sendStatusUpdateToKingdom( @@ -392,11 +388,11 @@ class PostgresComputationsService( "DelaySecond ${request.delaySecond} should be non-negative." } EnqueueComputation( - request.token.localComputationId, - request.token.version, - request.delaySecond.toLong(), - clock, - ) + request.token.localComputationId, + request.token.version, + request.delaySecond.toLong(), + clock, + ) .execute(client, idGenerator) return EnqueueComputationResponse.getDefaultInstance() } @@ -406,12 +402,12 @@ class PostgresComputationsService( ): RecordRequisitionBlobPathResponse { val token = RecordRequisitionBlobPath( - clock = clock, - localId = request.token.localComputationId, - externalRequisitionKey = request.key, - pathToBlob = request.blobPath, - computationReader = computationReader, - ) + clock = clock, + localId = request.token.localComputationId, + externalRequisitionKey = request.key, + pathToBlob = request.blobPath, + computationReader = computationReader, + ) .execute(client, idGenerator) return token.toRecordRequisitionBlobPathResponse() From ccf5922a52a3865d9b146a21bb9d4c341ae44d38 Mon Sep 17 00:00:00 2001 From: Yuhong Wang Date: Wed, 2 Aug 2023 23:23:41 +0000 Subject: [PATCH 11/12] revert --- src/main/proto/wfa/measurement/internal/duchy/error_code.proto | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/main/proto/wfa/measurement/internal/duchy/error_code.proto b/src/main/proto/wfa/measurement/internal/duchy/error_code.proto index f65970465e4..8d3530c27c4 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/error_code.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/error_code.proto @@ -39,7 +39,4 @@ enum ErrorCode { /** Computation with the same global ID already exists. */ COMPUTATION_ALREADY_EXISTS = 6; - - /** Data corrupted for unknown reasons. */ - DATA_CORRUPTED = 7; } From 388ae1c1986efa40a7afa72c28a96f37d0d42494 Mon Sep 17 00:00:00 2001 From: Yuhong Wang Date: Thu, 3 Aug 2023 18:28:50 +0000 Subject: [PATCH 12/12] update --- .../duchy/deploy/common/postgres/PostgresComputationsService.kt | 2 +- .../duchy/deploy/common/postgres/writers/PurgeComputations.kt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt index 7c0854674ae..ec9d7d5ecee 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/PostgresComputationsService.kt @@ -257,7 +257,7 @@ class PostgresComputationsService( return purgeComputationsResponse { purgeCount = purgeResult.purgeCount - purgeResult.purgeSamples?.forEach { purgeSample += it } + purgeSample += purgeResult.purgeSamples } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/PurgeComputations.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/PurgeComputations.kt index c84ed15ad13..e07c9ed5868 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/PurgeComputations.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/common/postgres/writers/PurgeComputations.kt @@ -27,7 +27,7 @@ class PurgeComputations( ) : PostgresWriter() { data class PurgeResult( val purgeCount: Int, - val purgeSamples: Set? = null, + val purgeSamples: Set = emptySet(), ) override suspend fun TransactionScope.runTransaction(): PurgeResult {