Skip to content

Commit

Permalink
Add DuchyInfo for duchy-related information (#106)
Browse files Browse the repository at this point in the history
Add DuchyInfo for duchy-related information
  • Loading branch information
oliver-amzn authored Jun 22, 2021
1 parent 8399fe4 commit f82b5b7
Show file tree
Hide file tree
Showing 16 changed files with 183 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ kt_jvm_library(
"//src/main/kotlin/org/wfanet/measurement/common",
"//src/main/kotlin/org/wfanet/measurement/common/crypto:signing_certs",
"//src/main/kotlin/org/wfanet/measurement/common/throttler",
"//src/main/proto/wfa/measurement/config:duchy_rpc_config_java_proto",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package org.wfanet.measurement.common.identity
import io.grpc.BindableService
import io.grpc.Context
import io.grpc.Contexts
import io.grpc.Grpc
import io.grpc.Metadata
import io.grpc.ServerCall
import io.grpc.ServerCallHandler
Expand All @@ -26,6 +27,8 @@ import io.grpc.ServerServiceDefinition
import io.grpc.Status
import io.grpc.stub.AbstractStub
import io.grpc.stub.MetadataUtils
import java.security.cert.X509Certificate
import javax.net.ssl.SSLSession

/**
* Details about an authenticated Duchy.
Expand All @@ -34,7 +37,9 @@ import io.grpc.stub.MetadataUtils
*/
data class DuchyIdentity(val id: String) {
init {
require(id in DuchyIds.ALL) { "Duchy $id is unknown; known Duchies are ${DuchyIds.ALL}" }
requireNotNull(DuchyInfo.getByDuchyId(id)) {
"Duchy $id is unknown; known Duchies are ${DuchyInfo.ALL_DUCHY_IDS}"
}
}
}

Expand All @@ -60,34 +65,44 @@ private val DUCHY_ID_METADATA_KEY = Metadata.Key.of(KEY_NAME, Metadata.ASCII_STR
* ```
* On the client side, use [withDuchyId].
*/
class DuchyServerIdentityInterceptor : ServerInterceptor {
class DuchyTlsIdentityInterceptor() : ServerInterceptor {
override fun <ReqT, RespT> interceptCall(
call: ServerCall<ReqT, RespT>,
headers: Metadata,
next: ServerCallHandler<ReqT, RespT>
): ServerCall.Listener<ReqT> {
val duchyId: String? = headers.get(DUCHY_ID_METADATA_KEY)

if (duchyId == null) {
val sslSession: SSLSession? = call.attributes[Grpc.TRANSPORT_ATTR_SSL_SESSION]
if (sslSession == null) {
call.close(
Status.UNAUTHENTICATED.withDescription("gRPC metadata missing 'duchy_id' key"),
Status.UNAUTHENTICATED.withDescription("gRPC metadata missing sslSession"),
Metadata()
)
return object : ServerCall.Listener<ReqT>() {}
}

val context = Context.current().withValue(DUCHY_IDENTITY_CONTEXT_KEY, DuchyIdentity(duchyId))
return Contexts.interceptCall(context, call, headers, next)
for (cert in sslSession.peerCertificates) {
if (cert !is X509Certificate) {
continue
}

val duchyInfo =
DuchyInfo.getByRootCertificateSkid(
String(cert.getExtensionValue("X509v3 Authority Key Identifier"))
)
?: continue

val context =
Context.current().withValue(DUCHY_IDENTITY_CONTEXT_KEY, DuchyIdentity(duchyInfo.duchyId))
return Contexts.interceptCall(context, call, headers, next)
}

return Contexts.interceptCall(Context.current(), call, headers, next)
}
}

/** Convenience helper for [DuchyServerIdentityInterceptor]. */
/** Convenience helper for [DuchyTlsIdentityInterceptor]. */
fun BindableService.withDuchyIdentities(): ServerServiceDefinition =
ServerInterceptors.interceptForward(this, DuchyServerIdentityInterceptor())

/** Convenience helper for [DuchyServerIdentityInterceptor]. */
fun ServerServiceDefinition.withDuchyIdentities(): ServerServiceDefinition =
ServerInterceptors.interceptForward(this, DuchyServerIdentityInterceptor())
ServerInterceptors.interceptForward(this, DuchyTlsIdentityInterceptor())

/**
* Sets metadata key "duchy_id" on all outgoing requests.
Expand Down
58 changes: 0 additions & 58 deletions src/main/kotlin/org/wfanet/measurement/common/identity/DuchyIds.kt

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright 2021 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.wfanet.measurement.common.identity

import org.wfanet.measurement.common.parseTextProto
import org.wfanet.measurement.config.DuchyRpcConfig
import picocli.CommandLine

object DuchyInfo {
lateinit var entries: Array<Entry>
val count: Int
get() = DuchyInfo.entries.size
val ALL_DUCHY_IDS: Set<String>
get() = DuchyInfo.entries.map { it.duchyId }.toSet()

fun initializeFromFlags(flags: DuchyInfoFlags) {
require(!DuchyInfo::entries.isInitialized)
val configMessage =
flags.config.reader().use { parseTextProto(it, DuchyRpcConfig.getDefaultInstance()) }
require(configMessage.duchiesCount > 0) { "Duchy info config has no entries" }
entries = configMessage.duchiesList.map { it.toDuchyInfoEntry() }.toTypedArray()
}

/** Returns the [Entry] for the specified root cert key ID. */
fun getByRootCertificateSkid(rootCertificateSkid: String): Entry? {
return entries.firstOrNull { it.rootCertificateSkid == rootCertificateSkid }
}

/** Returns the [Entry] for the specified Duchy ID. */
fun getByDuchyId(duchyId: String): Entry? {
return entries.firstOrNull { it.duchyId == duchyId }
}

fun setForTest(duchyIds: Set<String>) {
entries = duchyIds.map { DuchyInfo.Entry(it, "hostname-$it", "cert-id-$it") }.toTypedArray()
}

data class Entry(
val duchyId: String,
val computationControlServiceTarget: String,
val rootCertificateSkid: String
)
}

class DuchyInfoFlags {
@CommandLine.Option(
names = ["--duchy-info-config"],
description = ["DuchyRpcConfig proto message in text format."],
required = true
)
lateinit var config: String
private set
}

private fun DuchyRpcConfig.Duchy.toDuchyInfoEntry(): DuchyInfo.Entry {
return DuchyInfo.Entry(duchyId, computationControlServiceTarget, rootCertificateSkid)
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ package org.wfanet.measurement.common.identity.testing
import org.junit.rules.TestRule
import org.junit.runner.Description
import org.junit.runners.model.Statement
import org.wfanet.measurement.common.identity.DuchyIds
import org.wfanet.measurement.common.identity.DuchyInfo

/** JUnit rule that sets the global list of all valid Duchy ids to [duchyIds]. */
class DuchyIdSetter(val duchyIds: Set<String>) : TestRule {
Expand All @@ -27,7 +27,7 @@ class DuchyIdSetter(val duchyIds: Set<String>) : TestRule {
override fun apply(base: Statement, description: Description): Statement {
return object : Statement() {
override fun evaluate() {
DuchyIds.setDuchyIdsForTest(duchyIds)
DuchyInfo.setForTest(duchyIds)
base.evaluate()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ import io.grpc.ManagedChannel
import org.wfanet.measurement.common.commandLineMain
import org.wfanet.measurement.common.grpc.CommonServer
import org.wfanet.measurement.common.grpc.buildChannel
import org.wfanet.measurement.common.identity.DuchyIdFlags
import org.wfanet.measurement.common.identity.DuchyIds
import org.wfanet.measurement.common.identity.DuchyInfo
import org.wfanet.measurement.common.identity.DuchyInfoFlags
import org.wfanet.measurement.duchy.DuchyPublicKeys
import org.wfanet.measurement.duchy.db.computation.ComputationProtocolStageDetails
import org.wfanet.measurement.duchy.deploy.common.CommonDuchyFlags
Expand All @@ -44,7 +44,7 @@ class AsyncComputationControlServiceFlags {
private set

@CommandLine.Mixin
lateinit var duchyIdFlags: DuchyIdFlags
lateinit var duchyInfo: DuchyInfoFlags
private set

@CommandLine.Option(
Expand All @@ -68,8 +68,8 @@ private fun run(@CommandLine.Mixin flags: AsyncComputationControlServiceFlags) {
require(latestDuchyPublicKeys.containsKey(duchyName)) {
"Public key not specified for Duchy $duchyName"
}
DuchyIds.setDuchyIdsFromFlags(flags.duchyIdFlags)
require(latestDuchyPublicKeys.keys.toSet() == DuchyIds.ALL)
DuchyInfo.initializeFromFlags(flags.duchyInfo)
require(latestDuchyPublicKeys.keys.toSet() == DuchyInfo.ALL_DUCHY_IDS)

val otherDuchyNames = latestDuchyPublicKeys.keys.filter { it != duchyName }
val channel: ManagedChannel = buildChannel(flags.computationsServiceTarget)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ package org.wfanet.measurement.duchy.deploy.common.server
import io.grpc.ManagedChannel
import org.wfanet.measurement.common.grpc.CommonServer
import org.wfanet.measurement.common.grpc.buildChannel
import org.wfanet.measurement.common.identity.DuchyIdFlags
import org.wfanet.measurement.common.identity.DuchyIds
import org.wfanet.measurement.common.identity.DuchyInfo
import org.wfanet.measurement.common.identity.DuchyInfoFlags
import org.wfanet.measurement.common.identity.withDuchyIdentities
import org.wfanet.measurement.duchy.DuchyPublicKeys
import org.wfanet.measurement.duchy.deploy.common.CommonDuchyFlags
Expand All @@ -33,7 +33,7 @@ abstract class ComputationControlServer : Runnable {
private set

@CommandLine.Mixin
protected lateinit var duchyIdFlags: DuchyIdFlags
protected lateinit var duchyInfoFlags: DuchyInfoFlags
private set

protected fun run(storageClient: StorageClient) {
Expand All @@ -42,8 +42,8 @@ abstract class ComputationControlServer : Runnable {
require(latestDuchyPublicKeys.containsKey(duchyName)) {
"Public key not specified for Duchy $duchyName"
}
DuchyIds.setDuchyIdsFromFlags(duchyIdFlags)
require(latestDuchyPublicKeys.keys.toSet() == DuchyIds.ALL)
DuchyInfo.initializeFromFlags(duchyInfoFlags)
require(latestDuchyPublicKeys.keys.toSet() == DuchyInfo.ALL_DUCHY_IDS)

val channel: ManagedChannel = buildChannel(flags.asyncComputationControlServiceTarget)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package org.wfanet.measurement.kingdom.deploy.common.server

import org.wfanet.measurement.common.commandLineMain
import org.wfanet.measurement.common.grpc.CommonServer
import org.wfanet.measurement.common.identity.DuchyIdFlags
import org.wfanet.measurement.common.identity.DuchyInfoFlags
import org.wfanet.measurement.internal.kingdom.ReportLogEntriesGrpcKt.ReportLogEntriesCoroutineStub
import org.wfanet.measurement.internal.kingdom.ReportsGrpcKt.ReportsCoroutineStub
import org.wfanet.measurement.kingdom.service.system.v1alpha.GlobalComputationService
Expand All @@ -30,10 +30,10 @@ import picocli.CommandLine
)
private fun run(
@CommandLine.Mixin kingdomApiServerFlags: KingdomApiServerFlags,
@CommandLine.Mixin duchyIdFlags: DuchyIdFlags,
@CommandLine.Mixin duchyInfoFlags: DuchyInfoFlags,
@CommandLine.Mixin commonServerFlags: CommonServer.Flags
) {
runKingdomApiServer(kingdomApiServerFlags, duchyIdFlags, commonServerFlags) { channel ->
runKingdomApiServer(kingdomApiServerFlags, duchyInfoFlags, commonServerFlags) { channel ->
GlobalComputationService(ReportsCoroutineStub(channel), ReportLogEntriesCoroutineStub(channel))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ import kotlin.properties.Delegates
import org.wfanet.measurement.common.grpc.CommonServer
import org.wfanet.measurement.common.grpc.buildChannel
import org.wfanet.measurement.common.grpc.withVerboseLogging
import org.wfanet.measurement.common.identity.DuchyIdFlags
import org.wfanet.measurement.common.identity.DuchyIds
import org.wfanet.measurement.common.identity.DuchyInfo
import org.wfanet.measurement.common.identity.DuchyInfoFlags
import org.wfanet.measurement.common.identity.withDuchyIdentities
import picocli.CommandLine

Expand All @@ -44,11 +44,11 @@ class KingdomApiServerFlags {

fun runKingdomApiServer(
kingdomApiServerFlags: KingdomApiServerFlags,
duchyIdFlags: DuchyIdFlags,
duchyInfoFlags: DuchyInfoFlags,
commonServerFlags: CommonServer.Flags,
serviceFactory: (Channel) -> BindableService
) {
DuchyIds.setDuchyIdsFromFlags(duchyIdFlags)
DuchyInfo.initializeFromFlags(duchyInfoFlags)

val channel: Channel =
buildChannel(kingdomApiServerFlags.internalApiTarget)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ package org.wfanet.measurement.kingdom.deploy.common.server

import kotlinx.coroutines.runInterruptible
import org.wfanet.measurement.common.grpc.CommonServer
import org.wfanet.measurement.common.identity.DuchyIdFlags
import org.wfanet.measurement.common.identity.DuchyIds
import org.wfanet.measurement.common.identity.DuchyInfo
import org.wfanet.measurement.common.identity.DuchyInfoFlags
import org.wfanet.measurement.kingdom.deploy.common.service.DataServices
import picocli.CommandLine

abstract class KingdomDataServer : Runnable {
@CommandLine.Mixin private lateinit var serverFlags: CommonServer.Flags

@CommandLine.Mixin private lateinit var duchyIdFlags: DuchyIdFlags
@CommandLine.Mixin private lateinit var duchyInfoFlags: DuchyInfoFlags

protected suspend fun run(dataServices: DataServices) {
DuchyIds.setDuchyIdsFromFlags(duchyIdFlags)
DuchyInfo.initializeFromFlags(duchyInfoFlags)

val services = dataServices.buildDataServices()
val server = CommonServer.fromFlags(serverFlags, this::class.simpleName!!, services)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ package org.wfanet.measurement.kingdom.deploy.common.server

import kotlinx.coroutines.runInterruptible
import org.wfanet.measurement.common.grpc.CommonServer
import org.wfanet.measurement.common.identity.DuchyIdFlags
import org.wfanet.measurement.common.identity.DuchyIds
import org.wfanet.measurement.common.identity.DuchyInfo
import org.wfanet.measurement.common.identity.DuchyInfoFlags
import org.wfanet.measurement.kingdom.db.ReportDatabase
import org.wfanet.measurement.kingdom.db.RequisitionDatabase
import org.wfanet.measurement.kingdom.service.internal.buildLegacyDataServices
Expand All @@ -26,13 +26,13 @@ import picocli.CommandLine
abstract class LegacyKingdomDataServer : Runnable {
@CommandLine.Mixin private lateinit var serverFlags: CommonServer.Flags

@CommandLine.Mixin private lateinit var duchyIdFlags: DuchyIdFlags
@CommandLine.Mixin private lateinit var duchyInfoFlags: DuchyInfoFlags

protected suspend fun run(
reportDatabase: ReportDatabase,
requisitionDatabase: RequisitionDatabase
) {
DuchyIds.setDuchyIdsFromFlags(duchyIdFlags)
DuchyInfo.initializeFromFlags(duchyInfoFlags)

val services = buildLegacyDataServices(reportDatabase, requisitionDatabase)
val server = CommonServer.fromFlags(serverFlags, this::class.simpleName!!, services)
Expand Down
Loading

0 comments on commit f82b5b7

Please sign in to comment.