Skip to content

Commit

Permalink
Implement postgres continuation token service (#1059)
Browse files Browse the repository at this point in the history
This implements the Postgres continuation token service and its dependent Postgres readers/writers.
  • Loading branch information
YuhongWang-Amazon authored and ple13 committed Aug 16, 2024
1 parent ede9e77 commit 14e17f4
Show file tree
Hide file tree
Showing 12 changed files with 424 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
load("@io_bazel_rules_kotlin//kotlin:jvm.bzl", "kt_jvm_library")

kt_jvm_library(
name = "services",
srcs = glob(["*Service.kt"]),
visibility = [
"//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres:__pkg__",
"//src/test/kotlin/org/wfanet/measurement/duchy/deploy/postgres:__pkg__",
],
deps = [
"//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/readers",
"//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/writers",
"//src/main/proto/wfa/measurement/internal/duchy:continuation_tokens_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/system/v1alpha:computations_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//imports/java/io/grpc:api",
"@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright 2023 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.wfanet.measurement.duchy.deploy.postgres

import io.grpc.Status
import org.wfanet.measurement.common.db.r2dbc.DatabaseClient
import org.wfanet.measurement.common.identity.IdGenerator
import org.wfanet.measurement.duchy.deploy.postgres.readers.ContinuationTokenReader
import org.wfanet.measurement.duchy.deploy.postgres.writers.SetContinuationToken
import org.wfanet.measurement.duchy.service.internal.ContinuationTokenInvalidException
import org.wfanet.measurement.duchy.service.internal.ContinuationTokenMalformedException
import org.wfanet.measurement.internal.duchy.ContinuationTokensGrpcKt.ContinuationTokensCoroutineImplBase
import org.wfanet.measurement.internal.duchy.GetContinuationTokenRequest
import org.wfanet.measurement.internal.duchy.GetContinuationTokenResponse
import org.wfanet.measurement.internal.duchy.SetContinuationTokenRequest
import org.wfanet.measurement.internal.duchy.SetContinuationTokenResponse
import org.wfanet.measurement.internal.duchy.getContinuationTokenResponse

class PostgresContinuationTokensService(
private val client: DatabaseClient,
private val idGenerator: IdGenerator,
) : ContinuationTokensCoroutineImplBase() {

override suspend fun getContinuationToken(
request: GetContinuationTokenRequest
): GetContinuationTokenResponse {
val result: ContinuationTokenReader.Result =
ContinuationTokenReader().getContinuationToken(client.singleUse())
?: return GetContinuationTokenResponse.getDefaultInstance()
return getContinuationTokenResponse { token = result.continuationToken }
}

override suspend fun setContinuationToken(
request: SetContinuationTokenRequest
): SetContinuationTokenResponse {
try {
SetContinuationToken(request.token).execute(client, idGenerator)
} catch (e: ContinuationTokenInvalidException) {
throw e.asStatusRuntimeException(Status.Code.FAILED_PRECONDITION)
} catch (e: ContinuationTokenMalformedException) {
throw e.asStatusRuntimeException(Status.Code.INVALID_ARGUMENT)
}
return SetContinuationTokenResponse.getDefaultInstance()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
load("@io_bazel_rules_kotlin//kotlin:jvm.bzl", "kt_jvm_library")

package(default_visibility = [
"//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres:__subpackages__",
])

kt_jvm_library(
name = "readers",
srcs = glob(["*.kt"]),
deps = [
"//src/main/proto/wfa/measurement/internal/duchy:continuation_tokens_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/system/v1alpha:computations_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2023 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.wfanet.measurement.duchy.deploy.postgres.readers

import kotlinx.coroutines.flow.firstOrNull
import org.wfanet.measurement.common.db.r2dbc.ReadContext
import org.wfanet.measurement.common.db.r2dbc.ResultRow
import org.wfanet.measurement.common.db.r2dbc.boundStatement

class ContinuationTokenReader {
companion object {
private const val parameterizedQueryString =
"""
SELECT ContinuationToken
FROM HeraldContinuationTokens
Limit 1
"""
}

data class Result(val continuationToken: String)

fun translate(row: ResultRow): Result = Result(row["ContinuationToken"])

/**
* Reads a ContinuationToken from the HeraldContinuationTokens table.
*
* @return [Result] when a ContinuationToken is found.
* @return null when there is no ContinuationToken.
*/
suspend fun getContinuationToken(readContext: ReadContext): Result? {
val statement = boundStatement(parameterizedQueryString)
return readContext.executeQuery(statement).consume(::translate).firstOrNull()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
load("@io_bazel_rules_kotlin//kotlin:jvm.bzl", "kt_jvm_library")

package(default_visibility = [
"//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres:__subpackages__",
])

kt_jvm_library(
name = "writers",
srcs = glob(["*.kt"]),
deps = [
"//src/main/kotlin/org/wfanet/measurement/duchy/deploy/postgres/readers",
"//src/main/kotlin/org/wfanet/measurement/duchy/service/internal:internal_exception",
"//src/main/proto/wfa/measurement/internal/duchy:continuation_tokens_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/system/v1alpha:computations_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/db/r2dbc/postgres",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/identity",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2023 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.wfanet.measurement.duchy.deploy.postgres.writers

import com.google.protobuf.InvalidProtocolBufferException
import com.google.protobuf.util.Timestamps
import java.time.Instant
import org.wfanet.measurement.common.base64UrlDecode
import org.wfanet.measurement.common.db.r2dbc.boundStatement
import org.wfanet.measurement.common.db.r2dbc.postgres.PostgresWriter
import org.wfanet.measurement.duchy.deploy.postgres.readers.ContinuationTokenReader
import org.wfanet.measurement.duchy.service.internal.ContinuationTokenInvalidException
import org.wfanet.measurement.duchy.service.internal.ContinuationTokenMalformedException
import org.wfanet.measurement.duchy.service.internal.DuchyInternalException
import org.wfanet.measurement.system.v1alpha.StreamActiveComputationsContinuationToken

/**
* [PostgresWriter] for setting continuation tokens.
*
* Throws a subclass of [DuchyInternalException] on [execute]:
* * [ContinuationTokenMalformedException] when the new token is malformed
* * [ContinuationTokenInvalidException] when the new token is invalid
*/
class SetContinuationToken(private val continuationToken: String) : PostgresWriter<Unit>() {
private fun decodeContinuationToken(token: String): StreamActiveComputationsContinuationToken =
try {
StreamActiveComputationsContinuationToken.parseFrom(token.base64UrlDecode())
} catch (e: InvalidProtocolBufferException) {
throw ContinuationTokenMalformedException(
continuationToken,
"ContinuationToken is malformed."
)
}

override suspend fun TransactionScope.runTransaction() {
val statement =
boundStatement(
"""
INSERT INTO HeraldContinuationTokens (Presence, ContinuationToken, UpdateTime)
VALUES ($1, $2, $3)
ON CONFLICT (Presence)
DO
UPDATE SET ContinuationToken = EXCLUDED.ContinuationToken, UpdateTime = EXCLUDED.UpdateTime;
"""
) {
bind("$1", true)
bind("$2", continuationToken)
bind("$3", Instant.now())
}

transactionContext.run {
val newContinuationToken = decodeContinuationToken(continuationToken)
val oldContinuationToken =
ContinuationTokenReader().getContinuationToken(transactionContext)?.continuationToken?.let {
decodeContinuationToken(it)
}

if (
oldContinuationToken != null &&
Timestamps.compare(
newContinuationToken.updateTimeSince,
oldContinuationToken.updateTimeSince
) < 0
) {
throw ContinuationTokenInvalidException(
continuationToken,
"ContinuationToken to set cannot have older timestamp."
)
}
transactionContext.executeStatement(statement)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
load("@io_bazel_rules_kotlin//kotlin:jvm.bzl", "kt_jvm_library")

package(default_visibility = [
"//src/main/kotlin/org/wfanet/measurement/duchy/deploy:__subpackages__",
])

kt_jvm_library(
name = "internal_exception",
srcs = ["DuchyInternalException.kt"],
deps = [
"//src/main/proto/google/rpc:error_details_kt_jvm_proto",
"//src/main/proto/google/rpc:status_kt_jvm_proto",
"//src/main/proto/wfa/measurement/internal/duchy:error_code_kt_jvm_proto",
"@wfa_common_jvm//imports/java/com/google/protobuf",
"@wfa_common_jvm//imports/java/io/grpc/protobuf",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright 2023 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.wfanet.measurement.duchy.service.internal

import com.google.protobuf.Any
import com.google.rpc.errorInfo
import com.google.rpc.status
import io.grpc.Status
import io.grpc.StatusRuntimeException
import io.grpc.protobuf.StatusProto
import org.wfanet.measurement.internal.duchy.ErrorCode

sealed class DuchyInternalException(val code: ErrorCode, override val message: String) :
Exception() {
protected abstract val context: Map<String, String>

fun asStatusRuntimeException(
statusCode: Status.Code,
message: String = this.message
): StatusRuntimeException {
val statusProto = status {
code = statusCode.value()
this.message = message
details +=
Any.pack(
errorInfo {
reason = this@DuchyInternalException.code.toString()
domain = ErrorCode.getDescriptor().fullName
metadata.putAll(context)
}
)
}
return StatusProto.toStatusRuntimeException(statusProto)
}
}

class ContinuationTokenInvalidException(
val continuationToken: String,
message: String,
) : DuchyInternalException(ErrorCode.CONTINUATION_TOKEN_INVALID, message) {
override val context
get() = mapOf("continuation_token" to continuationToken)
}

class ContinuationTokenMalformedException(
val continuationToken: String,
message: String,
) : DuchyInternalException(ErrorCode.CONTINUATION_TOKEN_MALFORMED, message) {
override val context
get() = mapOf("continuation_token" to continuationToken)
}
16 changes: 16 additions & 0 deletions src/main/proto/wfa/measurement/internal/duchy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,19 @@ cc_proto_library(
name = "differential_privacy_cc_proto",
deps = [":differential_privacy_proto"],
)

proto_library(
name = "error_code_proto",
srcs = ["error_code.proto"],
)

java_proto_library(
name = "error_code_java_proto",
deps = [":error_code_proto"],
)

kt_jvm_proto_library(
name = "error_code_kt_jvm_proto",
srcs = [":error_code_proto"],
deps = [":error_code_java_proto"],
)
30 changes: 30 additions & 0 deletions src/main/proto/wfa/measurement/internal/duchy/error_code.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright 2023 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package wfa.measurement.internal.duchy;

option java_package = "org.wfanet.measurement.internal.duchy";
option java_multiple_files = true;

enum ErrorCode {
ERROR_CODE_UNSPECIFIED = 0;

/** ContinuationToken is invalid. */
CONTINUATION_TOKEN_INVALID = 1;

/** ContinuationToken is malformed. */
CONTINUATION_TOKEN_MALFORMED = 2;
}
Loading

0 comments on commit 14e17f4

Please sign in to comment.