Skip to content

Commit

Permalink
fix: Include known types when building EventGroup filter registry (#1925
Browse files Browse the repository at this point in the history
)

Closes #1924
  • Loading branch information
SanjayVas authored Nov 18, 2024
1 parent f705789 commit cb75127
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,15 @@ class CelEnvCacheProvider(
val eventGroupMetadataDescriptors: List<EventGroupMetadataDescriptor> =
getEventGroupMetadataDescriptors()

val fileDescriptors: List<Descriptors.FileDescriptor> =
ProtoReflection.buildFileDescriptors(
eventGroupMetadataDescriptors.map { it.descriptorSet },
allKnownMetadataTypes,
)
val fileDescriptors: Set<Descriptors.FileDescriptor> =
allKnownMetadataTypes.toMutableSet().apply {
addAll(
ProtoReflection.buildFileDescriptors(
eventGroupMetadataDescriptors.map { it.descriptorSet },
allKnownMetadataTypes,
)
)
}

val env = buildCelEnvironment(fileDescriptors)
val typeRegistry: TypeRegistry = buildTypeRegistry(fileDescriptors)
Expand Down Expand Up @@ -205,10 +209,6 @@ class CelEnvCacheProvider(
}
}

private fun buildTypeRegistry(fileDescriptors: List<Descriptors.FileDescriptor>): TypeRegistry {
return TypeRegistry.newBuilder().add(fileDescriptors.flatMap { it.messageTypes }).build()
}

/** Suspends until any in-flight sync operations are complete. */
suspend fun waitForSync() {
initialSyncJob.join()
Expand All @@ -232,5 +232,11 @@ class CelEnvCacheProvider(
@VisibleForTesting internal val RETRY_DELAY: Duration = Duration.ofMillis(100)

private val logger: Logger = Logger.getLogger(this::class.java.name)

private fun buildTypeRegistry(
fileDescriptors: Iterable<Descriptors.FileDescriptor>
): TypeRegistry {
return TypeRegistry.newBuilder().add(fileDescriptors.flatMap { it.messageTypes }).build()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ import java.nio.file.Paths
import java.time.Duration
import kotlin.test.assertFailsWith
import kotlinx.coroutines.runBlocking
import org.junit.After
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.mockito.kotlin.any
import org.mockito.kotlin.argumentCaptor
import org.mockito.kotlin.stub
import org.mockito.kotlin.times
import org.mockito.kotlin.verify
import org.mockito.kotlin.whenever
Expand All @@ -50,6 +52,7 @@ import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey
import org.wfanet.measurement.api.v2alpha.copy
import org.wfanet.measurement.api.v2alpha.eventGroup as cmmsEventGroup
import org.wfanet.measurement.api.v2alpha.eventGroupMetadataDescriptor
import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.TestMetadataMessage
import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.testMetadataMessage
import org.wfanet.measurement.api.v2alpha.event_templates.testing.TestEvent
import org.wfanet.measurement.api.v2alpha.listEventGroupMetadataDescriptorsResponse
Expand Down Expand Up @@ -105,11 +108,12 @@ class EventGroupsServiceTest {
addService(publicKingdomEventGroupMetadataDescriptorsMock)
}

private lateinit var celEnvCacheProvider: CelEnvCacheProvider
private lateinit var service: EventGroupsService

@Before
fun initService() {
val celEnvCacheProvider =
celEnvCacheProvider =
CelEnvCacheProvider(
EventGroupMetadataDescriptorsCoroutineStub(grpcTestServerRule.channel),
EventGroup.getDescriptor(),
Expand All @@ -125,6 +129,11 @@ class EventGroupsServiceTest {
)
}

@After
fun closeCelEnvCacheProvider() {
celEnvCacheProvider.close()
}

@Test
fun `listEventGroups returns events groups after multiple calls to kingdom`() = runBlocking {
val testMessage = testMetadataMessage { publisherId = 5 }
Expand Down Expand Up @@ -276,6 +285,76 @@ class EventGroupsServiceTest {
)
}

@Test
fun `listEventGroups returns only event groups that match filter when filter has metadata using a known type`() {
celEnvCacheProvider.close()
celEnvCacheProvider =
CelEnvCacheProvider(
EventGroupMetadataDescriptorsCoroutineStub(grpcTestServerRule.channel),
EventGroup.getDescriptor(),
Duration.ofSeconds(5),
listOf(TestMetadataMessage.getDescriptor().file),
)
service =
EventGroupsService(
EventGroupsCoroutineStub(grpcTestServerRule.channel),
ENCRYPTION_KEY_PAIR_STORE,
celEnvCacheProvider,
)
publicKingdomEventGroupMetadataDescriptorsMock.stub {
onBlocking { listEventGroupMetadataDescriptors(any()) }
.thenReturn(
listEventGroupMetadataDescriptorsResponse {
eventGroupMetadataDescriptors +=
EVENT_GROUP_METADATA_DESCRIPTOR.copy { clearDescriptorSet() }
}
)
}
val testMessage = testMetadataMessage { publisherId = 5 }

val cmmsEventGroup2 =
CMMS_EVENT_GROUP.copy {
encryptedMetadata =
encryptMetadata(
CmmsEventGroupKt.metadata {
eventGroupMetadataDescriptor = EVENT_GROUP_METADATA_DESCRIPTOR_NAME
metadata = Any.pack(testMessage)
},
ENCRYPTION_PUBLIC_KEY.toEncryptionPublicKey(),
)
}

runBlocking {
whenever(publicKingdomEventGroupsMock.listEventGroups(any()))
.thenReturn(
listEventGroupsResponse { eventGroups += listOf(CMMS_EVENT_GROUP, cmmsEventGroup2) }
)
}

val response =
withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) {
runBlocking {
service.listEventGroups(
listEventGroupsRequest {
parent = MEASUREMENT_CONSUMER_NAME
filter = "metadata.metadata.publisher_id > 5"
}
)
}
}

assertThat(response.eventGroupsList).containsExactly(EVENT_GROUP)

verifyProtoArgument(publicKingdomEventGroupsMock, EventGroupsCoroutineImplBase::listEventGroups)
.ignoringRepeatedFieldOrder()
.isEqualTo(
cmmsListEventGroupsRequest {
parent = MEASUREMENT_CONSUMER_NAME
pageSize = DEFAULT_PAGE_SIZE
}
)
}

@Test
fun `listEventGroups returns only event groups that match filter when filter has no metadata`() {
val cmmsEventGroup2 =
Expand Down Expand Up @@ -593,14 +672,6 @@ class EventGroupsServiceTest {

@Test
fun `listEventGroups throws FAILED_PRECONDITION when store doesn't have private key`() {
val celEnvCacheProvider =
CelEnvCacheProvider(
EventGroupMetadataDescriptorsCoroutineStub(grpcTestServerRule.channel),
EventGroup.getDescriptor(),
Duration.ofSeconds(5),
emptyList(),
)

service =
EventGroupsService(
EventGroupsCoroutineStub(grpcTestServerRule.channel),
Expand Down

0 comments on commit cb75127

Please sign in to comment.