Skip to content

Commit

Permalink
Pass Reporting EventGroup message descriptor to CelEnvProvider impls. (
Browse files Browse the repository at this point in the history
…#1417)

This allows CelEnvProvider to be used for multiple Reporting API versions.

Fixes #1406
  • Loading branch information
SanjayVas authored Jan 12, 2024
1 parent e2c025e commit 2306a16
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import org.wfanet.measurement.reporting.service.api.InMemoryEncryptionKeyPairSto
import org.wfanet.measurement.reporting.service.api.v1alpha.EventGroupsService
import org.wfanet.measurement.reporting.service.api.v1alpha.ReportingSetsService
import org.wfanet.measurement.reporting.service.api.v1alpha.ReportsService
import org.wfanet.measurement.reporting.v1alpha.EventGroup

/** TestRule that starts and stops all Reporting Server gRPC services. */
class InProcessReportingServer(
Expand Down Expand Up @@ -123,6 +124,7 @@ class InProcessReportingServer(
publicKingdomEventGroupMetadataDescriptorsClient.withAuthenticationKey(
measurementConsumerConfig.apiKey
),
EventGroup.getDescriptor(),
Duration.ofSeconds(5),
Dispatchers.Default,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ import org.wfanet.measurement.reporting.service.api.v2alpha.MetricCalculationSpe
import org.wfanet.measurement.reporting.service.api.v2alpha.MetricsService
import org.wfanet.measurement.reporting.service.api.v2alpha.ReportingSetsService
import org.wfanet.measurement.reporting.service.api.v2alpha.ReportsService
import org.wfanet.measurement.reporting.v2alpha.EventGroup
import org.wfanet.measurement.reporting.v2alpha.MetricsGrpcKt.MetricsCoroutineStub as PublicMetricsCoroutineStub

/** TestRule that starts and stops all Reporting Server gRPC services. */
Expand Down Expand Up @@ -184,6 +185,7 @@ class InProcessReportingServer(
publicKingdomEventGroupMetadataDescriptorsClient.withAuthenticationKey(
measurementConsumerConfig.apiKey
),
EventGroup.getDescriptor(),
Duration.ofSeconds(5),
Dispatchers.Default,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import org.wfanet.measurement.reporting.service.api.v1alpha.ReportingPrincipal
import org.wfanet.measurement.reporting.service.api.v1alpha.ReportingSetsService
import org.wfanet.measurement.reporting.service.api.v1alpha.ReportsService
import org.wfanet.measurement.reporting.service.api.v1alpha.withPrincipalsFromX509AuthorityKeyIdentifiers
import org.wfanet.measurement.reporting.v1alpha.EventGroup
import picocli.CommandLine

private const val SERVER_NAME = "V1AlphaPublicApiServer"
Expand Down Expand Up @@ -108,6 +109,7 @@ private fun run(
CelEnvCacheProvider(
KingdomEventGroupMetadataDescriptorsCoroutineStub(kingdomChannel)
.withAuthenticationKey(apiKey),
EventGroup.getDescriptor(),
reportingApiServerFlags.eventGroupMetadataDescriptorCacheDuration,
Dispatchers.Default,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ import org.wfanet.measurement.reporting.service.api.v2alpha.ReportingPrincipal
import org.wfanet.measurement.reporting.service.api.v2alpha.ReportingSetsService
import org.wfanet.measurement.reporting.service.api.v2alpha.ReportsService
import org.wfanet.measurement.reporting.service.api.v2alpha.withPrincipalsFromX509AuthorityKeyIdentifiers
import org.wfanet.measurement.reporting.v2alpha.EventGroup
import org.wfanet.measurement.reporting.v2alpha.MetricsGrpcKt.MetricsCoroutineStub
import picocli.CommandLine

Expand Down Expand Up @@ -156,6 +157,7 @@ private fun run(
CelEnvCacheProvider(
KingdomEventGroupMetadataDescriptorsCoroutineStub(kingdomChannel)
.withAuthenticationKey(apiKey),
EventGroup.getDescriptor(),
reportingApiServerFlags.eventGroupMetadataDescriptorCacheDuration,
Dispatchers.Default,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ kt_jvm_library(
"//imports/java/org/projectnessie/cel",
"//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptor_kt_jvm_proto",
"//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptors_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/reporting/v1alpha:event_group_kt_jvm_proto",
"@wfa_common_jvm//imports/java/com/google/protobuf",
"@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core",
"@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:jdk8",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptor
import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt
import org.wfanet.measurement.api.v2alpha.listEventGroupMetadataDescriptorsRequest
import org.wfanet.measurement.common.ProtoReflection
import org.wfanet.measurement.reporting.v1alpha.EventGroup

private const val METADATA_FIELD = "metadata.metadata"
private const val MAX_PAGE_SIZE = 1000
Expand All @@ -66,6 +65,8 @@ interface CelEnvProvider {
class CelEnvCacheProvider(
private val eventGroupsMetadataDescriptorsStub:
EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub,
/** Protobuf descriptor of Reporting EventGroup message type. */
private val reportingEventGroupDescriptor: Descriptors.Descriptor,
private val cacheRefreshInterval: Duration,
coroutineContext: CoroutineContext,
private val numRetriesInitialSync: Int = 3,
Expand Down Expand Up @@ -142,22 +143,20 @@ class CelEnvCacheProvider(
// Build CEL ProtoTypeRegistry.
val celTypeRegistry = ProtoTypeRegistry.newRegistry()
descriptors.forEach { celTypeRegistry.registerDescriptor(it.file) }

celTypeRegistry.registerMessage(EventGroup.getDefaultInstance())
celTypeRegistry.registerDescriptor(reportingEventGroupDescriptor.file)

// Build CEL Env.
val eventGroupDescriptor = EventGroup.getDescriptor()
val env =
Env.newEnv(
EnvOption.container(eventGroupDescriptor.fullName),
EnvOption.container(reportingEventGroupDescriptor.fullName),
EnvOption.customTypeProvider(celTypeRegistry),
EnvOption.customTypeAdapter(celTypeRegistry),
EnvOption.declarations(
eventGroupDescriptor.fields
reportingEventGroupDescriptor.fields
.map {
Decls.newVar(
it.name,
celTypeRegistry.findFieldType(eventGroupDescriptor.fullName, it.name).type
celTypeRegistry.findFieldType(reportingEventGroupDescriptor.fullName, it.name).type
)
}
// TODO(projectnessie/cel-java#295): Remove when fixed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ kt_jvm_test(
deps = [
"//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptors_service_kt_jvm_grpc_proto",
"//src/main/proto/wfa/measurement/api/v2alpha/event_group_metadata/testing:test_metadata_messages_kt_jvm_proto",
"//src/main/proto/wfa/measurement/reporting/v1alpha:event_group_kt_jvm_proto",
"@wfa_common_jvm//imports/java/com/google/common/truth",
"@wfa_common_jvm//imports/kotlin/kotlin/test",
"@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class CelEnvProviderTest {
EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub(
grpcTestServerRule.channel
),
REPORTING_EVENT_GROUP_DESCRIPTOR,
Duration.ofMinutes(5),
coroutineContext
)
Expand Down Expand Up @@ -132,6 +133,7 @@ class CelEnvProviderTest {
EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub(
grpcTestServerRule.channel
),
REPORTING_EVENT_GROUP_DESCRIPTOR,
Duration.ofMinutes(5),
coroutineContext,
1
Expand Down Expand Up @@ -171,6 +173,7 @@ class CelEnvProviderTest {
EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub(
grpcTestServerRule.channel
),
REPORTING_EVENT_GROUP_DESCRIPTOR,
Duration.ofMinutes(5),
coroutineContext,
numRetries
Expand Down Expand Up @@ -201,6 +204,7 @@ class CelEnvProviderTest {
EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub(
grpcTestServerRule.channel
),
REPORTING_EVENT_GROUP_DESCRIPTOR,
Duration.ofMinutes(5),
coroutineContext,
numRetries
Expand Down Expand Up @@ -245,6 +249,7 @@ class CelEnvProviderTest {
EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub(
grpcTestServerRule.channel
),
REPORTING_EVENT_GROUP_DESCRIPTOR,
cacheRefreshInterval,
coroutineContext
)
Expand All @@ -270,6 +275,8 @@ class CelEnvProviderTest {
}

companion object {
private val REPORTING_EVENT_GROUP_DESCRIPTOR = EventGroup.getDescriptor()

private fun verifyTypeRegistryAndEnv(typeRegistryAndEnv: CelEnvProvider.TypeRegistryAndEnv) {
val eventGroup = eventGroup {
metadata =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey
import org.wfanet.measurement.consent.client.dataprovider.encryptMetadata
import org.wfanet.measurement.reporting.service.api.CelEnvCacheProvider
import org.wfanet.measurement.reporting.service.api.InMemoryEncryptionKeyPairStore
import org.wfanet.measurement.reporting.v1alpha.EventGroup
import org.wfanet.measurement.reporting.v1alpha.EventGroupKt.metadata
import org.wfanet.measurement.reporting.v1alpha.eventGroup
import org.wfanet.measurement.reporting.v1alpha.listEventGroupsRequest
Expand Down Expand Up @@ -187,6 +188,7 @@ class EventGroupsServiceTest {
val celEnvCacheProvider =
CelEnvCacheProvider(
EventGroupMetadataDescriptorsCoroutineStub(grpcTestServerRule.channel),
EventGroup.getDescriptor(),
Duration.ofSeconds(5),
Dispatchers.Default,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey
import org.wfanet.measurement.consent.client.dataprovider.encryptMetadata
import org.wfanet.measurement.reporting.service.api.CelEnvCacheProvider
import org.wfanet.measurement.reporting.service.api.InMemoryEncryptionKeyPairStore
import org.wfanet.measurement.reporting.v2alpha.EventGroup
import org.wfanet.measurement.reporting.v2alpha.EventGroupKt
import org.wfanet.measurement.reporting.v2alpha.eventGroup
import org.wfanet.measurement.reporting.v2alpha.listEventGroupsRequest
Expand Down Expand Up @@ -108,6 +109,7 @@ class EventGroupsServiceTest {
val celEnvCacheProvider =
CelEnvCacheProvider(
EventGroupMetadataDescriptorsCoroutineStub(grpcTestServerRule.channel),
EventGroup.getDescriptor(),
Duration.ofSeconds(5),
Dispatchers.Default,
)
Expand Down Expand Up @@ -508,6 +510,7 @@ class EventGroupsServiceTest {
val celEnvCacheProvider =
CelEnvCacheProvider(
EventGroupMetadataDescriptorsCoroutineStub(grpcTestServerRule.channel),
EventGroup.getDescriptor(),
Duration.ofSeconds(5),
Dispatchers.Default,
)
Expand Down

0 comments on commit 2306a16

Please sign in to comment.