diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/BUILD.bazel new file mode 100644 index 00000000000..e8b450de507 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/BUILD.bazel @@ -0,0 +1,20 @@ +load("@io_bazel_rules_kotlin//kotlin:jvm.bzl", "kt_jvm_library") + +kt_jvm_library( + name = "services", + srcs = glob(["*Service.kt"]), + visibility = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres:__pkg__", + "//src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres:__pkg__", + ], + deps = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/readers", + "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/writers", + "//src/main/proto/wfa/measurement/internal/duchy:continuation_tokens_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/system/v1alpha:computations_service_kt_jvm_grpc_proto", + "@wfa_common_jvm//imports/java/io/grpc:api", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/PostgresContinuationTokensService.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/PostgresContinuationTokensService.kt new file mode 100644 index 00000000000..9f705329774 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/PostgresContinuationTokensService.kt @@ -0,0 +1,57 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.deploy.postgres + +import io.grpc.Status +import org.wfanet.measurement.common.db.r2dbc.DatabaseClient +import org.wfanet.measurement.common.identity.IdGenerator +import org.wfanet.measurement.duchy.deploy.postgres.readers.ContinuationTokenReader +import org.wfanet.measurement.duchy.deploy.postgres.writers.SetContinuationToken +import org.wfanet.measurement.duchy.service.internal.ContinuationTokenInvalidException +import org.wfanet.measurement.duchy.service.internal.ContinuationTokenMalformedException +import org.wfanet.measurement.internal.duchy.ContinuationTokensGrpcKt.ContinuationTokensCoroutineImplBase +import org.wfanet.measurement.internal.duchy.GetContinuationTokenRequest +import org.wfanet.measurement.internal.duchy.GetContinuationTokenResponse +import org.wfanet.measurement.internal.duchy.SetContinuationTokenRequest +import org.wfanet.measurement.internal.duchy.SetContinuationTokenResponse +import org.wfanet.measurement.internal.duchy.getContinuationTokenResponse + +class PostgresContinuationTokensService( + private val client: DatabaseClient, + private val idGenerator: IdGenerator, +) : ContinuationTokensCoroutineImplBase() { + + override suspend fun getContinuationToken( + request: GetContinuationTokenRequest + ): GetContinuationTokenResponse { + val result: ContinuationTokenReader.Result = + ContinuationTokenReader().getContinuationToken(client.singleUse()) + ?: return GetContinuationTokenResponse.getDefaultInstance() + return getContinuationTokenResponse { token = result.continuationToken } + } + + override suspend fun setContinuationToken( + request: SetContinuationTokenRequest + ): SetContinuationTokenResponse { + try { + SetContinuationToken(request.token).execute(client, idGenerator) + } catch (e: ContinuationTokenInvalidException) { + throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION) + } catch (e: ContinuationTokenMalformedException) { + throw e.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT) + } + return SetContinuationTokenResponse.getDefaultInstance() + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/readers/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/readers/BUILD.bazel new file mode 100644 index 00000000000..fb3a5571afb --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/readers/BUILD.bazel @@ -0,0 +1,18 @@ +load("@io_bazel_rules_kotlin//kotlin:jvm.bzl", "kt_jvm_library") + +package(default_visibility = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres:__subpackages__", +]) + +kt_jvm_library( + name = "readers", + srcs = glob(["*.kt"]), + deps = [ + "//src/main/proto/wfa/measurement/internal/duchy:continuation_tokens_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/system/v1alpha:computations_service_kt_jvm_grpc_proto", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/readers/ContinuationTokenReader.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/readers/ContinuationTokenReader.kt new file mode 100644 index 00000000000..27471dcafe9 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/readers/ContinuationTokenReader.kt @@ -0,0 +1,46 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.deploy.postgres.readers + +import kotlinx.coroutines.flow.firstOrNull +import org.wfanet.measurement.common.db.r2dbc.ReadContext +import org.wfanet.measurement.common.db.r2dbc.ResultRow +import org.wfanet.measurement.common.db.r2dbc.boundStatement + +class ContinuationTokenReader { + companion object { + private const val parameterizedQueryString = + """ + SELECT ContinuationToken + FROM HeraldContinuationTokens + Limit 1 + """ + } + + data class Result(val continuationToken: String) + + fun translate(row: ResultRow): Result = Result(row["ContinuationToken"]) + + /** + * Reads a ContinuationToken from the HeraldContinuationTokens table. + * + * @return [Result] when a ContinuationToken is found. + * @return null when there is no ContinuationToken. + */ + suspend fun getContinuationToken(readContext: ReadContext): Result? { + val statement = boundStatement(parameterizedQueryString) + return readContext.executeQuery(statement).consume(::translate).firstOrNull() + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/writers/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/writers/BUILD.bazel new file mode 100644 index 00000000000..c0473c47d6c --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/writers/BUILD.bazel @@ -0,0 +1,22 @@ +load("@io_bazel_rules_kotlin//kotlin:jvm.bzl", "kt_jvm_library") + +package(default_visibility = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres:__subpackages__", +]) + +kt_jvm_library( + name = "writers", + srcs = glob(["*.kt"]), + deps = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/readers", + "//src/main/kotlin/org/wfanet/measurement/duchy/service/internal:internal_exception", + "//src/main/proto/wfa/measurement/internal/duchy:continuation_tokens_service_kt_jvm_grpc_proto", + "//src/main/proto/wfa/measurement/system/v1alpha:computations_service_kt_jvm_grpc_proto", + "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/writers/SetContinutationToken.kt b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/writers/SetContinutationToken.kt new file mode 100644 index 00000000000..63a23e631d3 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/writers/SetContinutationToken.kt @@ -0,0 +1,85 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.deploy.postgres.writers + +import com.google.protobuf.InvalidProtocolBufferException +import com.google.protobuf.util.Timestamps +import java.time.Instant +import org.wfanet.measurement.common.base64UrlDecode +import org.wfanet.measurement.common.db.r2dbc.boundStatement +import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter +import org.wfanet.measurement.duchy.deploy.postgres.readers.ContinuationTokenReader +import org.wfanet.measurement.duchy.service.internal.ContinuationTokenInvalidException +import org.wfanet.measurement.duchy.service.internal.ContinuationTokenMalformedException +import org.wfanet.measurement.duchy.service.internal.DuchyInternalException +import org.wfanet.measurement.system.v1alpha.StreamActiveComputationsContinuationToken + +/** + * [PostgresWriter] for setting continuation tokens. + * + * Throws a subclass of [DuchyInternalException] on [execute]: + * * [ContinuationTokenMalformedException] when the new token is malformed + * * [ContinuationTokenInvalidException] when the new token is invalid + */ +class SetContinuationToken(private val continuationToken: String) : PostgresWriter() { + private fun decodeContinuationToken(token: String): StreamActiveComputationsContinuationToken = + try { + StreamActiveComputationsContinuationToken.parseFrom(token.base64UrlDecode()) + } catch (e: InvalidProtocolBufferException) { + throw ContinuationTokenMalformedException( + continuationToken, + "ContinuationToken is malformed." + ) + } + + override suspend fun TransactionScope.runTransaction() { + val statement = + boundStatement( + """ + INSERT INTO HeraldContinuationTokens (Presence, ContinuationToken, UpdateTime) + VALUES ($1, $2, $3) + ON CONFLICT (Presence) + DO + UPDATE SET ContinuationToken = EXCLUDED.ContinuationToken, UpdateTime = EXCLUDED.UpdateTime; + """ + ) { + bind("$1", true) + bind("$2", continuationToken) + bind("$3", Instant.now()) + } + + transactionContext.run { + val newContinuationToken = decodeContinuationToken(continuationToken) + val oldContinuationToken = + ContinuationTokenReader().getContinuationToken(transactionContext)?.continuationToken?.let { + decodeContinuationToken(it) + } + + if ( + oldContinuationToken != null && + Timestamps.compare( + newContinuationToken.updateTimeSince, + oldContinuationToken.updateTimeSince + ) < 0 + ) { + throw ContinuationTokenInvalidException( + continuationToken, + "ContinuationToken to set cannot have older timestamp." + ) + } + transactionContext.executeStatement(statement) + } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/BUILD.bazel new file mode 100644 index 00000000000..f91312c9fae --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/BUILD.bazel @@ -0,0 +1,18 @@ +load("@io_bazel_rules_kotlin//kotlin:jvm.bzl", "kt_jvm_library") + +package(default_visibility = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/deploy:__subpackages__", +]) + +kt_jvm_library( + name = "internal_exception", + srcs = ["DuchyInternalException.kt"], + deps = [ + "//src/main/proto/google/rpc:error_details_kt_jvm_proto", + "//src/main/proto/google/rpc:status_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/duchy:error_code_kt_jvm_proto", + "@wfa_common_jvm//imports/java/com/google/protobuf", + "@wfa_common_jvm//imports/java/io/grpc/protobuf", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt new file mode 100644 index 00000000000..3d1f1a44c1f --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/duchy/service/internal/DuchyInternalException.kt @@ -0,0 +1,63 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.service.internal + +import com.google.protobuf.Any +import com.google.rpc.errorInfo +import com.google.rpc.status +import io.grpc.Status +import io.grpc.StatusRuntimeException +import io.grpc.protobuf.StatusProto +import org.wfanet.measurement.internal.duchy.ErrorCode + +sealed class DuchyInternalException(val code: ErrorCode, override val message: String) : + Exception() { + protected abstract val context: Map + + fun asStatusRuntimeException( + statusCode: Status.Code, + message: String = this.message + ): StatusRuntimeException { + val statusProto = status { + code = statusCode.value() + this.message = message + details += + Any.pack( + errorInfo { + reason = this@DuchyInternalException.code.toString() + domain = ErrorCode.getDescriptor().fullName + metadata.putAll(context) + } + ) + } + return StatusProto.toStatusRuntimeException(statusProto) + } +} + +class ContinuationTokenInvalidException( + val continuationToken: String, + message: String, +) : DuchyInternalException(ErrorCode.CONTINUATION_TOKEN_INVALID, message) { + override val context + get() = mapOf("continuation_token" to continuationToken) +} + +class ContinuationTokenMalformedException( + val continuationToken: String, + message: String, +) : DuchyInternalException(ErrorCode.CONTINUATION_TOKEN_MALFORMED, message) { + override val context + get() = mapOf("continuation_token" to continuationToken) +} diff --git a/src/main/proto/wfa/measurement/internal/duchy/BUILD.bazel b/src/main/proto/wfa/measurement/internal/duchy/BUILD.bazel index 62a547cf37d..cbb0c4a6f47 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/BUILD.bazel +++ b/src/main/proto/wfa/measurement/internal/duchy/BUILD.bazel @@ -208,3 +208,19 @@ cc_proto_library( name = "differential_privacy_cc_proto", deps = [":differential_privacy_proto"], ) + +proto_library( + name = "error_code_proto", + srcs = ["error_code.proto"], +) + +java_proto_library( + name = "error_code_java_proto", + deps = [":error_code_proto"], +) + +kt_jvm_proto_library( + name = "error_code_kt_jvm_proto", + srcs = [":error_code_proto"], + deps = [":error_code_java_proto"], +) diff --git a/src/main/proto/wfa/measurement/internal/duchy/error_code.proto b/src/main/proto/wfa/measurement/internal/duchy/error_code.proto new file mode 100644 index 00000000000..602f976f335 --- /dev/null +++ b/src/main/proto/wfa/measurement/internal/duchy/error_code.proto @@ -0,0 +1,30 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.internal.duchy; + +option java_package = "org.wfanet.measurement.internal.duchy"; +option java_multiple_files = true; + +enum ErrorCode { + ERROR_CODE_UNSPECIFIED = 0; + + /** ContinuationToken is invalid. */ + CONTINUATION_TOKEN_INVALID = 1; + + /** ContinuationToken is malformed. */ + CONTINUATION_TOKEN_MALFORMED = 2; +} diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/BUILD.bazel index e7aeea5e21f..7de51530aca 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/BUILD.bazel +++ b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/BUILD.bazel @@ -12,3 +12,18 @@ kt_jvm_test( "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:embedded_postgres", ], ) + +kt_jvm_test( + name = "PostgresContinuationTokensServiceTest", + srcs = ["PostgresContinuationTokensServiceTest.kt"], + test_class = "org.wfanet.measurement.duchy.deploy.postgres.PostgresContinuationTokensServiceTest", + deps = [ + "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres:services", + "//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/testing", + "//src/main/kotlin/org/wfanet/measurement/duchy/service/internal/testing", + "@wfa_common_jvm//imports/java/com/google/common/truth", + "@wfa_common_jvm//imports/java/com/opentable/db/postgres:pg_embedded", + "@wfa_common_jvm//imports/java/org/junit", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres/testing:embedded_postgres", + ], +) diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/PostgresContinuationTokensServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/PostgresContinuationTokensServiceTest.kt new file mode 100644 index 00000000000..4c2c6505d39 --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres/PostgresContinuationTokensServiceTest.kt @@ -0,0 +1,34 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.duchy.deploy.postgres + +import java.time.Clock +import kotlin.random.Random +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 +import org.wfanet.measurement.common.db.r2dbc.postgres.testing.EmbeddedPostgresDatabaseProvider +import org.wfanet.measurement.common.identity.RandomIdGenerator +import org.wfanet.measurement.duchy.deploy.postgres.testing.Schemata.DUCHY_CHANGELOG_PATH +import org.wfanet.measurement.duchy.service.internal.testing.ContinuationTokensServiceTest + +@RunWith(JUnit4::class) +class PostgresContinuationTokensServiceTest : + ContinuationTokensServiceTest() { + override fun newService(): PostgresContinuationTokensService { + val client = EmbeddedPostgresDatabaseProvider(DUCHY_CHANGELOG_PATH).createNewDatabase() + val idGenerator = RandomIdGenerator(Clock.systemUTC(), Random(1)) + return PostgresContinuationTokensService(client, idGenerator) + } +}