Skip to content

Commit

Permalink
Implement Event Groups Service Reporting V2Alpha. (#1088)
Browse files Browse the repository at this point in the history
  • Loading branch information
tristanvuong2021 authored and ple13 committed Aug 16, 2024
1 parent c86eace commit cfb24af
Show file tree
Hide file tree
Showing 5 changed files with 980 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import io.grpc.StatusException
import java.io.File
import java.security.SecureRandom
import java.security.cert.X509Certificate
import java.time.Duration
import java.util.logging.Logger
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.runBlocking
Expand All @@ -31,9 +32,12 @@ import org.junit.runner.Description
import org.junit.runners.model.Statement
import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub as PublicKingdomCertificatesCoroutineStub
import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineStub as PublicKingdomDataProvidersCoroutineStub
import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub as PublicKingdomEventGroupMetadataDescriptorsCoroutineStub
import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub as PublicKingdomEventGroupsCoroutineStub
import org.wfanet.measurement.api.v2alpha.MeasurementConsumerCertificateKey
import org.wfanet.measurement.api.v2alpha.MeasurementConsumersGrpcKt.MeasurementConsumersCoroutineStub as PublicKingdomMeasurementConsumersCoroutineStub
import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineStub as PublicKingdomMeasurementsCoroutineStub
import org.wfanet.measurement.api.withAuthenticationKey
import org.wfanet.measurement.common.crypto.tink.loadPrivateKey
import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule
import org.wfanet.measurement.common.grpc.withVerboseLogging
Expand All @@ -52,6 +56,7 @@ import org.wfanet.measurement.internal.reporting.v2.ReportsGrpcKt.ReportsCorouti
import org.wfanet.measurement.internal.reporting.v2.measurementConsumer
import org.wfanet.measurement.reporting.deploy.v2.common.server.InternalReportingServer
import org.wfanet.measurement.reporting.deploy.v2.common.server.InternalReportingServer.Companion.toList
import org.wfanet.measurement.reporting.service.api.CelEnvCacheProvider
import org.wfanet.measurement.reporting.service.api.InMemoryEncryptionKeyPairStore
import org.wfanet.measurement.reporting.service.api.v2alpha.EventGroupsService
import org.wfanet.measurement.reporting.service.api.v2alpha.MetricsService
Expand Down Expand Up @@ -81,6 +86,12 @@ class InProcessReportingServer(
private val publicKingdomDataProvidersClient by lazy {
PublicKingdomDataProvidersCoroutineStub(publicKingdomChannelGenerator())
}
private val publicKingdomEventGroupMetadataDescriptorsClient by lazy {
PublicKingdomEventGroupMetadataDescriptorsCoroutineStub(publicKingdomChannelGenerator())
}
private val publicKingdomEventGroupsClient by lazy {
PublicKingdomEventGroupsCoroutineStub(publicKingdomChannelGenerator())
}

private val internalApiChannel by lazy { internalReportingServer.channel }
private val internalMeasurementConsumersClient by lazy {
Expand Down Expand Up @@ -148,8 +159,22 @@ class InProcessReportingServer(
}
}

val celEnvCacheProvider =
CelEnvCacheProvider(
publicKingdomEventGroupMetadataDescriptorsClient.withAuthenticationKey(
measurementConsumerConfig.apiKey
),
Duration.ofSeconds(5),
Dispatchers.Default,
)

listOf(
EventGroupsService().withMetadataPrincipalIdentities(measurementConsumerConfig),
EventGroupsService(
publicKingdomEventGroupsClient,
encryptionKeyPairStore,
celEnvCacheProvider
)
.withMetadataPrincipalIdentities(measurementConsumerConfig),
MetricsService(
METRIC_SPEC_CONFIG,
internalReportingSetsClient,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,21 @@ kt_jvm_library(
name = "event_groups_service",
srcs = ["EventGroupsService.kt"],
deps = [
":resource_key",
"//imports/java/org/projectnessie/cel",
"//src/main/kotlin/org/wfanet/measurement/api:api_key_constants",
"//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider",
"//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store",
"//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:principal_server_interceptor",
"//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:reporting_principal",
"//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:event_groups_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:page_token_kt_jvm_proto",
"//src/main/proto/wfa/measurement/reporting/v2alpha:event_group_kt_jvm_proto",
"//src/main/proto/wfa/measurement/reporting/v2alpha:event_groups_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto:key_storage",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc",
"@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/measurementconsumer",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,237 @@

package org.wfanet.measurement.reporting.service.api.v2alpha

import com.google.protobuf.DynamicMessage
import io.grpc.Status
import io.grpc.StatusException
import java.security.GeneralSecurityException
import org.projectnessie.cel.common.types.Err
import org.projectnessie.cel.common.types.ref.Val
import org.wfanet.measurement.api.v2alpha.DataProviderKey
import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey
import org.wfanet.measurement.api.v2alpha.EventGroup as CmmsEventGroup
import org.wfanet.measurement.api.v2alpha.EventGroupKey as CmmsEventGroupKey
import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub as CmmsEventGroupsCoroutineStub
import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey
import org.wfanet.measurement.api.v2alpha.listEventGroupsRequest
import org.wfanet.measurement.api.withAuthenticationKey
import org.wfanet.measurement.common.crypto.PrivateKeyHandle
import org.wfanet.measurement.common.grpc.grpcRequire
import org.wfanet.measurement.common.grpc.grpcRequireNotNull
import org.wfanet.measurement.consent.client.measurementconsumer.decryptMetadata
import org.wfanet.measurement.reporting.service.api.CelEnvProvider
import org.wfanet.measurement.reporting.service.api.EncryptionKeyPairStore
import org.wfanet.measurement.reporting.v2alpha.EventGroup
import org.wfanet.measurement.reporting.v2alpha.EventGroupKt
import org.wfanet.measurement.reporting.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineImplBase
import org.wfanet.measurement.reporting.v2alpha.ListEventGroupsRequest
import org.wfanet.measurement.reporting.v2alpha.ListEventGroupsResponse
import org.wfanet.measurement.reporting.v2alpha.eventGroup
import org.wfanet.measurement.reporting.v2alpha.listEventGroupsResponse

/** TODO(@tristanvuong2021): Implement methods. */
class EventGroupsService : EventGroupsCoroutineImplBase()
class EventGroupsService(
private val cmmsEventGroupsStub: CmmsEventGroupsCoroutineStub,
private val encryptionKeyPairStore: EncryptionKeyPairStore,
private val celEnvProvider: CelEnvProvider,
) : EventGroupsCoroutineImplBase() {
override suspend fun listEventGroups(request: ListEventGroupsRequest): ListEventGroupsResponse {
val parentKey =
grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) {
"Parent is either unspecified or invalid."
}

val principal: ReportingPrincipal = principalFromCurrentContext
when (principal) {
is MeasurementConsumerPrincipal -> {
if (parentKey != principal.resourceKey) {
throw Status.PERMISSION_DENIED.withDescription(
"Cannot list event groups for another MeasurementConsumer"
)
.asRuntimeException()
}
}
}

val apiAuthenticationKey: String = principal.config.apiKey

grpcRequire(request.pageSize >= 0) { "page_size cannot be negative" }

val pageSize =
when {
request.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE
request.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE
else -> request.pageSize
}

val cmmsListEventGroupResponse =
try {
cmmsEventGroupsStub
.withAuthenticationKey(apiAuthenticationKey)
.listEventGroups(
listEventGroupsRequest {
parent = parentKey.toName()
this.pageSize = pageSize
pageToken = request.pageToken
}
)
} catch (e: StatusException) {
throw when (e.status.code) {
Status.Code.DEADLINE_EXCEEDED -> Status.DEADLINE_EXCEEDED
Status.Code.CANCELLED -> Status.CANCELLED
else -> Status.UNKNOWN
}
.withCause(e)
.asRuntimeException()
}
val cmmsEventGroups = cmmsListEventGroupResponse.eventGroupsList

val eventGroups =
cmmsEventGroups.map {
val cmmsMetadata: CmmsEventGroup.Metadata? =
if (it.encryptedMetadata.isEmpty) {
null
} else {
decryptMetadata(it, principal.resourceKey.toName())
}

it.toEventGroup(cmmsMetadata)
}

return listEventGroupsResponse {
this.eventGroups += filterEventGroups(eventGroups, request.filter)
nextPageToken = cmmsListEventGroupResponse.nextPageToken
}
}

private suspend fun filterEventGroups(
eventGroups: List<EventGroup>,
filter: String,
): List<EventGroup> {
if (filter.isEmpty()) {
return eventGroups
}

val typeRegistryAndEnv = celEnvProvider.getTypeRegistryAndEnv()
val env = typeRegistryAndEnv.env
val typeRegistry = typeRegistryAndEnv.typeRegistry

val astAndIssues =
try {
env.compile(filter)
} catch (_: NullPointerException) {
// NullPointerException is thrown when an operator in the filter is not a CEL operator.
throw Status.INVALID_ARGUMENT.withDescription("filter is not a valid CEL expression")
.asRuntimeException()
}
if (astAndIssues.hasIssues()) {
throw Status.INVALID_ARGUMENT.withDescription(
"filter is not a valid CEL expression: ${astAndIssues.issues}"
)
.asRuntimeException()
}
val program = env.program(astAndIssues.ast)

eventGroups
.filter { it.hasMetadata() }
.distinctBy { it.metadata.metadata.typeUrl }
.forEach {
val typeUrl = it.metadata.metadata.typeUrl
typeRegistry.getDescriptorForTypeUrl(typeUrl)
?: throw Status.FAILED_PRECONDITION.withDescription(
"${it.metadata.eventGroupMetadataDescriptor} does not contain descriptor for $typeUrl"
)
.asRuntimeException()
}

return eventGroups.filter { eventGroup ->
val variables: Map<String, Any> =
mutableMapOf<String, Any>().apply {
for (fieldDescriptor in eventGroup.descriptorForType.fields) {
put(fieldDescriptor.name, eventGroup.getField(fieldDescriptor))
}
// TODO(projectnessie/cel-java#295): Remove when fixed.
if (eventGroup.hasMetadata()) {
val metadata: com.google.protobuf.Any = eventGroup.metadata.metadata
put(
METADATA_FIELD,
DynamicMessage.parseFrom(
typeRegistry.getDescriptorForTypeUrl(metadata.typeUrl),
metadata.value
)
)
}
}
val result: Val = program.eval(variables).`val`
if (result is Err) {
throw result.toRuntimeException()
}

if (result.value() !is Boolean) {
throw Status.INVALID_ARGUMENT.withDescription("filter does not evaluate to boolean")
.asRuntimeException()
}

result.booleanValue()
}
}

private suspend fun decryptMetadata(
cmmsEventGroup: CmmsEventGroup,
principalName: String,
): CmmsEventGroup.Metadata {
if (!cmmsEventGroup.hasMeasurementConsumerPublicKey()) {
throw Status.FAILED_PRECONDITION.withDescription(
"EventGroup ${cmmsEventGroup.name} has encrypted metadata but no encryption public key"
)
.asRuntimeException()
}
val encryptionKey =
EncryptionPublicKey.parseFrom(cmmsEventGroup.measurementConsumerPublicKey.data)
val decryptionKeyHandle: PrivateKeyHandle =
encryptionKeyPairStore.getPrivateKeyHandle(principalName, encryptionKey.data)
?: throw Status.FAILED_PRECONDITION.withDescription(
"Public key does not have corresponding private key"
)
.asRuntimeException()

return try {
decryptMetadata(cmmsEventGroup.encryptedMetadata, decryptionKeyHandle)
} catch (e: GeneralSecurityException) {
throw Status.FAILED_PRECONDITION.withCause(e)
.withDescription("Metadata cannot be decrypted")
.asRuntimeException()
}
}

private fun CmmsEventGroup.toEventGroup(cmmsMetadata: CmmsEventGroup.Metadata?): EventGroup {
val source = this
val cmmsEventGroupKey = requireNotNull(CmmsEventGroupKey.fromName(name))
val measurementConsumerKey =
requireNotNull(MeasurementConsumerKey.fromName(measurementConsumer))
return eventGroup {
name =
EventGroupKey(measurementConsumerKey.measurementConsumerId, cmmsEventGroupKey.eventGroupId)
.toName()
cmmsEventGroup = source.name
cmmsDataProvider = DataProviderKey(cmmsEventGroupKey.dataProviderId).toName()
eventGroupReferenceId = source.eventGroupReferenceId
eventTemplates +=
source.eventTemplatesList.map { EventGroupKt.eventTemplate { type = it.type } }
if (cmmsMetadata != null) {
metadata =
EventGroupKt.metadata {
eventGroupMetadataDescriptor = cmmsMetadata.eventGroupMetadataDescriptor
metadata = cmmsMetadata.metadata
}
}
}
}

companion object {
private const val METADATA_FIELD = "metadata.metadata"

private const val MIN_PAGE_SIZE = 1
private const val DEFAULT_PAGE_SIZE = 50
private const val MAX_PAGE_SIZE = 1000
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,44 @@
load("@io_bazel_rules_kotlin//kotlin:jvm.bzl", "kt_jvm_test")

kt_jvm_test(
name = "EventGroupsServiceTest",
srcs = ["EventGroupsServiceTest.kt"],
associates = [
"//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:event_groups_service",
],
data = [
"//src/main/k8s/testing/secretfiles:all_der_files",
"//src/main/k8s/testing/secretfiles:all_tink_keysets",
],
test_class = "org.wfanet.measurement.reporting.service.api.v2alpha.EventGroupsServiceTest",
deps = [
"//src/main/kotlin/org/wfanet/measurement/api/v2alpha:principal_server_interceptor",
"//src/main/kotlin/org/wfanet/measurement/reporting/service/api:cel_env_provider",
"//src/main/kotlin/org/wfanet/measurement/reporting/service/api:encryption_key_pair_store",
"//src/main/kotlin/org/wfanet/measurement/reporting/service/api/v2alpha:principal_server_interceptor",
"//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptors_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:event_groups_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing:test_metadata_messages_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha/event_templates/testing:test_event_kt_jvm_proto",
"//src/main/proto/wfa/measurement/config/reporting:measurement_consumer_config_kt_jvm_proto",
"//src/main/proto/wfa/measurement/reporting/v2alpha:event_groups_service_kt_jvm_grpc_proto",
"@wfa_common_jvm//imports/java/com/google/common/truth",
"@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto",
"@wfa_common_jvm//imports/java/com/google/protobuf",
"@wfa_common_jvm//imports/kotlin/kotlin/test",
"@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core",
"@wfa_common_jvm//imports/kotlin/org/mockito/kotlin",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/testing",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/crypto/tink",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing",
"@wfa_consent_signaling_client//src/main/kotlin/org/wfanet/measurement/consent/client/dataprovider",
],
)

kt_jvm_test(
name = "MetricSpecDefaultsTest",
srcs = ["MetricSpecDefaultsTest.kt"],
Expand Down
Loading

0 comments on commit cfb24af

Please sign in to comment.