From 7943b207c70b35ef27b14b7d4e4cb7288995e687 Mon Sep 17 00:00:00 2001 From: Sanjay Vasandani Date: Wed, 14 Jun 2023 11:01:32 -0700 Subject: [PATCH] Update cross-media-measurement-api for EventGroup pattern. --- build/repositories.bzl | 5 +- .../measurement/api/v2alpha/EventGroupKey.kt | 5 +- .../MeasurementConsumerEventGroupKey.kt | 54 +++ .../measurement/common/api/ResourceKey.kt | 4 + .../spanner/SpannerEventGroupsService.kt | 63 ++- .../common/KingdomInternalException.kt | 19 +- .../spanner/queries/StreamEventGroups.kt | 35 +- .../spanner/readers/EventGroupReader.kt | 30 +- .../spanner/writers/DeleteEventGroup.kt | 32 +- .../spanner/writers/UpdateEventGroup.kt | 17 +- .../service/api/v2alpha/EventGroupsService.kt | 397 +++++++++--------- .../testing/EventGroupsServiceTest.kt | 169 +++++--- .../loadtest/frontend/FrontendSimulator.kt | 6 +- .../service/api/v1alpha/EventGroupsService.kt | 33 +- .../measurement/api/v2alpha/page_token.proto | 12 +- .../kingdom/event_groups_service.proto | 28 +- .../spanner/add-event-groups-by-mc.sql | 21 + .../resources/kingdom/spanner/changelog.yaml | 3 + .../k8s/EmptyClusterCorrectnessTest.kt | 2 +- .../api/v2alpha/EventGroupsServiceTest.kt | 310 ++++++-------- .../api/v1alpha/EventGroupsServiceTest.kt | 33 +- 21 files changed, 734 insertions(+), 544 deletions(-) create mode 100644 src/main/kotlin/org/wfanet/measurement/api/v2alpha/MeasurementConsumerEventGroupKey.kt create mode 100644 src/main/resources/kingdom/spanner/add-event-groups-by-mc.sql diff --git a/build/repositories.bzl b/build/repositories.bzl index f3017305647..921a6b1cbea 100644 --- a/build/repositories.bzl +++ b/build/repositories.bzl @@ -40,9 +40,10 @@ def wfa_measurement_system_repositories(): wfa_repo_archive( name = "wfa_measurement_proto", + # DO_NOT_SUBMIT(world-federation-of-advertisers/cross-media-measurement-api#148): Use version. + commit = "6fc6ecc98ba92be12e25c37c45f15a44bb407dab", repo = "cross-media-measurement-api", - sha256 = "e1738d74028be874e2ea4a3a7c9c2696f5aea60eb82c473771e8962cad838826", - version = "0.34.0", + sha256 = "ebacd9ced009bd68a924bbbed2d8a4265a1fb4bcbdaa17c951057eb84000f131", ) wfa_repo_archive( diff --git a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/EventGroupKey.kt b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/EventGroupKey.kt index 70fa6049f71..7219c660424 100644 --- a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/EventGroupKey.kt +++ b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/EventGroupKey.kt @@ -15,18 +15,21 @@ package org.wfanet.measurement.api.v2alpha import org.wfanet.measurement.common.ResourceNameParser +import org.wfanet.measurement.common.api.ChildResourceKey import org.wfanet.measurement.common.api.ResourceKey private val parser = ResourceNameParser("dataProviders/{data_provider}/eventGroups/{event_group}") /** [ResourceKey] of an EventGroup. */ -data class EventGroupKey(val dataProviderId: String, val eventGroupId: String) : ResourceKey { +data class EventGroupKey(val dataProviderId: String, val eventGroupId: String) : ChildResourceKey { override fun toName(): String { return parser.assembleName( mapOf(IdVariable.DATA_PROVIDER to dataProviderId, IdVariable.EVENT_GROUP to eventGroupId) ) } + override val parentKey = DataProviderKey(dataProviderId) + companion object FACTORY : ResourceKey.Factory { val defaultValue = EventGroupKey("", "") diff --git a/src/main/kotlin/org/wfanet/measurement/api/v2alpha/MeasurementConsumerEventGroupKey.kt b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/MeasurementConsumerEventGroupKey.kt new file mode 100644 index 00000000000..944e869e0f0 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/api/v2alpha/MeasurementConsumerEventGroupKey.kt @@ -0,0 +1,54 @@ +/* + * Copyright 2023 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.api.v2alpha + +import org.wfanet.measurement.common.ResourceNameParser +import org.wfanet.measurement.common.api.ChildResourceKey +import org.wfanet.measurement.common.api.ResourceKey + +/** [ResourceKey] of an EventGroup with a MeasurementConsumer as the parent. */ +data class MeasurementConsumerEventGroupKey( + val measurementConsumerId: String, + val eventGroupId: String +) : ChildResourceKey { + override fun toName(): String { + return parser.assembleName( + mapOf( + IdVariable.MEASUREMENT_CONSUMER to measurementConsumerId, + IdVariable.EVENT_GROUP to eventGroupId + ) + ) + } + + override val parentKey = MeasurementConsumerKey(measurementConsumerId) + + companion object FACTORY : ResourceKey.Factory { + private val parser = + ResourceNameParser("measurementConsumers/{measurement_consumer}/eventGroups/{event_group}") + + val defaultValue = MeasurementConsumerEventGroupKey("", "") + + override fun fromName(resourceName: String): MeasurementConsumerEventGroupKey? { + return parser.parseIdVars(resourceName)?.let { + MeasurementConsumerEventGroupKey( + it.getValue(IdVariable.MEASUREMENT_CONSUMER), + it.getValue(IdVariable.EVENT_GROUP) + ) + } + } + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/common/api/ResourceKey.kt b/src/main/kotlin/org/wfanet/measurement/common/api/ResourceKey.kt index ff77fe9f878..8e9b72b064d 100644 --- a/src/main/kotlin/org/wfanet/measurement/common/api/ResourceKey.kt +++ b/src/main/kotlin/org/wfanet/measurement/common/api/ResourceKey.kt @@ -32,3 +32,7 @@ interface ResourceKey { const val WILDCARD_ID = "-" } } + +interface ChildResourceKey : ResourceKey { + val parentKey: ResourceKey +} diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/SpannerEventGroupsService.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/SpannerEventGroupsService.kt index c16d3f660c0..825da9e4b4e 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/SpannerEventGroupsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/SpannerEventGroupsService.kt @@ -17,8 +17,8 @@ package org.wfanet.measurement.kingdom.deploy.gcloud.spanner import io.grpc.Status import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.map -import org.wfanet.measurement.common.grpc.failGrpc import org.wfanet.measurement.common.grpc.grpcRequire +import org.wfanet.measurement.common.identity.ExternalId import org.wfanet.measurement.common.identity.IdGenerator import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient import org.wfanet.measurement.internal.kingdom.CreateEventGroupRequest @@ -32,6 +32,7 @@ import org.wfanet.measurement.internal.kingdom.eventGroup import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.CertificateIsInvalidException import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.DataProviderNotFoundException import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.EventGroupInvalidArgsException +import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.EventGroupNotFoundByMeasurementConsumerException import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.EventGroupNotFoundException import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.EventGroupStateIllegalException import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.KingdomInternalException @@ -114,42 +115,60 @@ class SpannerEventGroupsService( } override suspend fun getEventGroup(request: GetEventGroupRequest): EventGroup { - return EventGroupReader() - .readByExternalIds( - client.singleUse(), - request.externalDataProviderId, - request.externalEventGroupId, - ) - ?.eventGroup - ?: failGrpc(Status.NOT_FOUND) { "EventGroup not found" } + grpcRequire(request.externalEventGroupId != 0L) { "external_event_group_id not specified" } + val externalEventGroupId = ExternalId(request.externalEventGroupId) + val reader = EventGroupReader() + + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields cannot be null. + return when (request.externalParentIdCase) { + GetEventGroupRequest.ExternalParentIdCase.EXTERNAL_DATA_PROVIDER_ID -> { + val externalDataProviderId = ExternalId(request.externalDataProviderId) + reader.readByDataProvider(client.singleUse(), externalDataProviderId, externalEventGroupId) + ?: throw EventGroupNotFoundException(externalDataProviderId, externalEventGroupId) + .asStatusRuntimeException(Status.Code.NOT_FOUND) + } + GetEventGroupRequest.ExternalParentIdCase.EXTERNAL_MEASUREMENT_CONSUMER_ID -> { + val externalMeasurementConsumerId = ExternalId(request.externalMeasurementConsumerId) + reader.readByMeasurementConsumer( + client.singleUse(), + externalMeasurementConsumerId, + externalEventGroupId + ) + ?: throw EventGroupNotFoundByMeasurementConsumerException( + externalMeasurementConsumerId, + externalEventGroupId + ) + .asStatusRuntimeException(Status.Code.NOT_FOUND) + } + GetEventGroupRequest.ExternalParentIdCase.EXTERNALPARENTID_NOT_SET -> + throw Status.INVALID_ARGUMENT.withDescription("external_parent_id not specified") + .asRuntimeException() + }.eventGroup } override suspend fun deleteEventGroup(request: DeleteEventGroupRequest): EventGroup { - grpcRequire(request.externalDataProviderId > 0L) { "ExternalDataProviderId unspecified" } - grpcRequire(request.externalEventGroupId > 0L) { "ExternalEventGroupId unspecified" } - - val eventGroup = eventGroup { - externalDataProviderId = request.externalDataProviderId - externalEventGroupId = request.externalEventGroupId - } + grpcRequire(request.externalDataProviderId != 0L) { "external_data_provider_id unspecified" } + grpcRequire(request.externalEventGroupId > 0L) { "external_event_group_id unspecified" } try { - return DeleteEventGroup(eventGroup).execute(client, idGenerator) + return DeleteEventGroup(request).execute(client, idGenerator) } catch (e: EventGroupNotFoundException) { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "EventGroup not found.") + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) + } catch (e: EventGroupNotFoundByMeasurementConsumerException) { + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND) } catch (e: EventGroupStateIllegalException) { - when (e.state) { + throw when (e.state) { EventGroup.State.DELETED -> { - throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "EventGroup state is DELETED.") + e.asStatusRuntimeException(Status.Code.NOT_FOUND) } EventGroup.State.ACTIVE, EventGroup.State.STATE_UNSPECIFIED, EventGroup.State.UNRECOGNIZED -> { - throw e.asStatusRuntimeException(Status.Code.INTERNAL, "Unexpected internal error.") + e.asStatusRuntimeException(Status.Code.INTERNAL) } } } catch (e: KingdomInternalException) { - throw e.asStatusRuntimeException(Status.Code.INTERNAL, "Unexpected internal error.") + throw e.asStatusRuntimeException(Status.Code.INTERNAL) } } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/common/KingdomInternalException.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/common/KingdomInternalException.kt index cfa7de2880e..5ae386ebde7 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/common/KingdomInternalException.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/common/KingdomInternalException.kt @@ -517,8 +517,21 @@ class EventGroupNotFoundException( override val context get() = mapOf( - "external_data_provider_id" to externalDataProviderId.toString(), - "external_event_group_id" to externalEventGroupId.toString() + "external_data_provider_id" to externalDataProviderId.value.toString(), + "external_event_group_id" to externalEventGroupId.value.toString() + ) +} + +class EventGroupNotFoundByMeasurementConsumerException( + val externalMeasurementConsumerId: ExternalId, + val externalEventGroupId: ExternalId, + provideDescription: () -> String = { "EventGroup not found" } +) : KingdomInternalException(ErrorCode.EVENT_GROUP_NOT_FOUND, provideDescription) { + override val context + get() = + mapOf( + "external_measurement_consumer_id" to externalMeasurementConsumerId.value.toString(), + "external_event_group_id" to externalEventGroupId.value.toString() ) } @@ -539,7 +552,7 @@ class EventGroupStateIllegalException( val externalDataProviderId: ExternalId, val externalEventGroupId: ExternalId, val state: EventGroup.State, - provideDescription: () -> String = { "EventGroup state illegal" } + provideDescription: () -> String = { "EventGroup state is $state" } ) : KingdomInternalException(ErrorCode.EVENT_GROUP_STATE_ILLEGAL, provideDescription) { override val context get() = diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries/StreamEventGroups.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries/StreamEventGroups.kt index fd0304edf56..fff883c8419 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries/StreamEventGroups.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries/StreamEventGroups.kt @@ -36,17 +36,31 @@ class StreamEventGroups(requestFilter: StreamEventGroupsRequest.Filter, limit: I private fun Statement.Builder.appendWhereClause(filter: StreamEventGroupsRequest.Filter) { val conjuncts = mutableListOf() + if (filter.externalDataProviderId != 0L) { + conjuncts.add("ExternalDataProviderId = @$EXTERNAL_DATA_PROVIDER_ID") + bind(EXTERNAL_DATA_PROVIDER_ID).to(filter.externalDataProviderId) + } + if (filter.externalMeasurementConsumerId != 0L) { + conjuncts.add("ExternalMeasurementConsumerId = @$EXTERNAL_MEASUREMENT_CONSUMER_ID") + bind(EXTERNAL_MEASUREMENT_CONSUMER_ID).to(filter.externalMeasurementConsumerId) + } if (filter.externalMeasurementConsumerIdsList.isNotEmpty()) { conjuncts.add("ExternalMeasurementConsumerId IN UNNEST(@$EXTERNAL_MEASUREMENT_CONSUMER_IDS)") bind(EXTERNAL_MEASUREMENT_CONSUMER_IDS) .toInt64Array(filter.externalMeasurementConsumerIdsList.map { it.toLong() }) } - if (filter.externalDataProviderId != 0L) { - conjuncts.add("ExternalDataProviderId = @$EXTERNAL_DATA_PROVIDER_ID") - bind(EXTERNAL_DATA_PROVIDER_ID to filter.externalDataProviderId) + if (filter.externalDataProviderIdsList.isNotEmpty()) { + conjuncts.add("ExternalDataProviderId IN UNNEST(@$EXTERNAL_DATA_PROVIDER_IDS)") + bind(EXTERNAL_DATA_PROVIDER_IDS) + .toInt64Array(filter.externalDataProviderIdsList.map { it.toLong() }) + } + + if (!filter.showDeleted) { + conjuncts.add("State != @$DELETED_STATE") + bind(DELETED_STATE).toProtoEnum(EventGroup.State.DELETED) } - if (filter.externalEventGroupIdAfter != 0L && filter.externalDataProviderIdAfter != 0L) { + if (filter.hasAfter()) { conjuncts.add( """ ((ExternalDataProviderId > @$EXTERNAL_DATA_PROVIDER_ID_AFTER) @@ -55,13 +69,8 @@ class StreamEventGroups(requestFilter: StreamEventGroupsRequest.Filter, limit: I """ .trimIndent() ) - bind(EXTERNAL_DATA_PROVIDER_ID_AFTER).to(filter.externalDataProviderIdAfter) - bind(EXTERNAL_EVENT_GROUP_ID_AFTER).to(filter.externalEventGroupIdAfter) - } - - if (!filter.showDeleted) { - conjuncts.add("State != @$DELETED_STATE") - bind(DELETED_STATE).toProtoEnum(EventGroup.State.DELETED) + bind(EXTERNAL_DATA_PROVIDER_ID_AFTER).to(filter.after.externalDataProviderId) + bind(EXTERNAL_EVENT_GROUP_ID_AFTER).to(filter.after.externalEventGroupId) } if (conjuncts.isEmpty()) { @@ -74,8 +83,10 @@ class StreamEventGroups(requestFilter: StreamEventGroupsRequest.Filter, limit: I companion object { const val LIMIT = "limit" - const val EXTERNAL_MEASUREMENT_CONSUMER_IDS = "externalMeasurementConsumerIds" const val EXTERNAL_DATA_PROVIDER_ID = "externalDataProviderId" + const val EXTERNAL_MEASUREMENT_CONSUMER_ID = "externalMeasurementConsumerId" + const val EXTERNAL_MEASUREMENT_CONSUMER_IDS = "externalMeasurementConsumerIds" + const val EXTERNAL_DATA_PROVIDER_IDS = "externalDataProviderIds" const val EXTERNAL_EVENT_GROUP_ID_AFTER = "externalEventGroupIdAfter" const val EXTERNAL_DATA_PROVIDER_ID_AFTER = "externalDataProviderIdAfter" const val DELETED_STATE = "deletedState" diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers/EventGroupReader.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers/EventGroupReader.kt index 6937aa681a4..f6446bca92a 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers/EventGroupReader.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers/EventGroupReader.kt @@ -17,6 +17,7 @@ package org.wfanet.measurement.kingdom.deploy.gcloud.spanner.readers import com.google.cloud.spanner.Statement import com.google.cloud.spanner.Struct import kotlinx.coroutines.flow.singleOrNull +import org.wfanet.measurement.common.identity.ExternalId import org.wfanet.measurement.common.identity.InternalId import org.wfanet.measurement.gcloud.spanner.AsyncDatabaseClient import org.wfanet.measurement.gcloud.spanner.appendClause @@ -62,10 +63,10 @@ class EventGroupReader : BaseSpannerReader() { .singleOrNull() } - suspend fun readByExternalIds( + suspend fun readByDataProvider( readContext: AsyncDatabaseClient.ReadContext, - externalDataProviderId: Long, - externalEventGroupId: Long, + externalDataProviderId: ExternalId, + externalEventGroupId: ExternalId, ): Result? { return fillStatementBuilder { appendClause( @@ -84,6 +85,28 @@ class EventGroupReader : BaseSpannerReader() { .singleOrNull() } + suspend fun readByMeasurementConsumer( + readContext: AsyncDatabaseClient.ReadContext, + externalMeasurementConsumerId: ExternalId, + externalEventGroupId: ExternalId, + ): Result? { + return fillStatementBuilder { + appendClause( + """ + WHERE + ExternalMeasurementConsumerId = @${Params.EXTERNAL_MEASUREMENT_CONSUMER_ID} + AND ExternalEventGroupId = @${Params.EXTERNAL_EVENT_GROUP_ID} + """ + .trimIndent() + ) + bind(Params.EXTERNAL_MEASUREMENT_CONSUMER_ID to externalMeasurementConsumerId) + bind(Params.EXTERNAL_EVENT_GROUP_ID to externalEventGroupId) + appendClause("LIMIT 1") + } + .execute(readContext) + .singleOrNull() + } + override suspend fun translate(struct: Struct): Result = Result( buildEventGroup(struct), @@ -136,6 +159,7 @@ class EventGroupReader : BaseSpannerReader() { private object Params { const val EXTERNAL_DATA_PROVIDER_ID = "externalDataProviderId" + const val EXTERNAL_MEASUREMENT_CONSUMER_ID = "externalMeasurementConsumerId" const val EXTERNAL_EVENT_GROUP_ID = "externalEventGroupId" const val DATA_PROVIDER_ID = "dataProviderId" const val CREATE_REQUEST_ID = "createRequestId" diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/DeleteEventGroup.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/DeleteEventGroup.kt index 8d9052126aa..65a0ff78126 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/DeleteEventGroup.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/DeleteEventGroup.kt @@ -21,6 +21,7 @@ import org.wfanet.measurement.common.identity.ExternalId import org.wfanet.measurement.gcloud.spanner.bufferUpdateMutation import org.wfanet.measurement.gcloud.spanner.set import org.wfanet.measurement.gcloud.spanner.setJson +import org.wfanet.measurement.internal.kingdom.DeleteEventGroupRequest import org.wfanet.measurement.internal.kingdom.EventGroup import org.wfanet.measurement.internal.kingdom.copy import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.EventGroupNotFoundException @@ -35,31 +36,26 @@ import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.readers.EventGroupRe * * [EventGroupNotFoundException] EventGroup not found * * [EventGroupStateIllegalException] EventGroup state is DELETED */ -class DeleteEventGroup(private val eventGroup: EventGroup) : +class DeleteEventGroup(private val request: DeleteEventGroupRequest) : SpannerWriter() { override suspend fun TransactionScope.runTransaction(): EventGroup { - val internalEventGroupResult = + val externalEventGroupId = ExternalId(request.externalEventGroupId) + val externalDataProviderId = ExternalId(request.externalDataProviderId) + val result: EventGroupReader.Result = EventGroupReader() - .readByExternalIds( - transactionContext, - eventGroup.externalDataProviderId, - eventGroup.externalEventGroupId, - ) - ?: throw EventGroupNotFoundException( - ExternalId(eventGroup.externalDataProviderId), - ExternalId(eventGroup.externalEventGroupId), - ) - if (internalEventGroupResult.eventGroup.state == EventGroup.State.DELETED) { + .readByDataProvider(transactionContext, externalDataProviderId, externalEventGroupId) + ?: throw EventGroupNotFoundException(externalDataProviderId, externalEventGroupId) + if (result.eventGroup.state == EventGroup.State.DELETED) { throw EventGroupStateIllegalException( - ExternalId(eventGroup.externalDataProviderId), - ExternalId(eventGroup.externalEventGroupId), - internalEventGroupResult.eventGroup.state + ExternalId(result.eventGroup.externalDataProviderId), + externalEventGroupId, + result.eventGroup.state ) } transactionContext.bufferUpdateMutation("EventGroups") { - set("DataProviderId" to internalEventGroupResult.internalDataProviderId.value) - set("EventGroupId" to internalEventGroupResult.internalEventGroupId.value) + set("DataProviderId" to result.internalDataProviderId.value) + set("EventGroupId" to result.internalEventGroupId.value) set("MeasurementConsumerCertificateId" to null as Long?) set("UpdateTime" to Value.COMMIT_TIMESTAMP) set("EventGroupDetails" to EventGroup.Details.getDefaultInstance()) @@ -67,7 +63,7 @@ class DeleteEventGroup(private val eventGroup: EventGroup) : set("State" to EventGroup.State.DELETED) } - return internalEventGroupResult.eventGroup.copy { + return result.eventGroup.copy { this.externalMeasurementConsumerCertificateId = 0L this.details = EventGroup.Details.getDefaultInstance() this.state = EventGroup.State.DELETED diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/UpdateEventGroup.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/UpdateEventGroup.kt index 3c3fb291f0f..f7a77895555 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/UpdateEventGroup.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/UpdateEventGroup.kt @@ -38,21 +38,16 @@ import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers.checkValidCe class UpdateEventGroup(private val eventGroup: EventGroup) : SpannerWriter() { override suspend fun TransactionScope.runTransaction(): EventGroup { + val externalDataProviderId = ExternalId(eventGroup.externalDataProviderId) + val externalEventGroupId = ExternalId(eventGroup.externalEventGroupId) val internalEventGroupResult = EventGroupReader() - .readByExternalIds( - transactionContext, - eventGroup.externalDataProviderId, - eventGroup.externalEventGroupId - ) - ?: throw EventGroupNotFoundException( - ExternalId(eventGroup.externalDataProviderId), - ExternalId(eventGroup.externalEventGroupId) - ) + .readByDataProvider(transactionContext, externalDataProviderId, externalEventGroupId) + ?: throw EventGroupNotFoundException(externalDataProviderId, externalEventGroupId) if (internalEventGroupResult.eventGroup.state == EventGroup.State.DELETED) { throw EventGroupStateIllegalException( - ExternalId(eventGroup.externalEventGroupId), - ExternalId(eventGroup.externalEventGroupId), + externalEventGroupId, + externalEventGroupId, internalEventGroupResult.eventGroup.state ) } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/EventGroupsService.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/EventGroupsService.kt index 1fc0079940c..fa8f2423395 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/EventGroupsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/EventGroupsService.kt @@ -35,96 +35,126 @@ import org.wfanet.measurement.api.v2alpha.ListEventGroupsPageTokenKt.previousPag import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequest import org.wfanet.measurement.api.v2alpha.ListEventGroupsResponse import org.wfanet.measurement.api.v2alpha.MeasurementConsumerCertificateKey +import org.wfanet.measurement.api.v2alpha.MeasurementConsumerEventGroupKey import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey import org.wfanet.measurement.api.v2alpha.MeasurementConsumerPrincipal import org.wfanet.measurement.api.v2alpha.MeasurementPrincipal import org.wfanet.measurement.api.v2alpha.UpdateEventGroupRequest -import org.wfanet.measurement.api.v2alpha.copy import org.wfanet.measurement.api.v2alpha.eventGroup import org.wfanet.measurement.api.v2alpha.listEventGroupsPageToken import org.wfanet.measurement.api.v2alpha.listEventGroupsResponse import org.wfanet.measurement.api.v2alpha.principalFromCurrentContext import org.wfanet.measurement.api.v2alpha.signedData +import org.wfanet.measurement.common.api.ChildResourceKey import org.wfanet.measurement.common.api.ResourceKey import org.wfanet.measurement.common.base64UrlDecode import org.wfanet.measurement.common.base64UrlEncode import org.wfanet.measurement.common.grpc.failGrpc import org.wfanet.measurement.common.grpc.grpcRequire import org.wfanet.measurement.common.grpc.grpcRequireNotNull +import org.wfanet.measurement.common.identity.ApiId +import org.wfanet.measurement.common.identity.ExternalId import org.wfanet.measurement.common.identity.apiIdToExternalId import org.wfanet.measurement.common.identity.externalIdToApiId import org.wfanet.measurement.internal.kingdom.CreateEventGroupRequest as InternalCreateEventGroupRequest import org.wfanet.measurement.internal.kingdom.EventGroup as InternalEventGroup import org.wfanet.measurement.internal.kingdom.EventGroupKt import org.wfanet.measurement.internal.kingdom.EventGroupKt.details -import org.wfanet.measurement.internal.kingdom.EventGroupsGrpcKt.EventGroupsCoroutineStub +import org.wfanet.measurement.internal.kingdom.EventGroupsGrpcKt.EventGroupsCoroutineStub as InternalEventGroupsCoroutineStub +import org.wfanet.measurement.internal.kingdom.GetEventGroupRequest as InternalGetEventGroupRequest import org.wfanet.measurement.internal.kingdom.StreamEventGroupsRequest -import org.wfanet.measurement.internal.kingdom.StreamEventGroupsRequestKt.filter +import org.wfanet.measurement.internal.kingdom.StreamEventGroupsRequestKt as InternalStreamEventGroupsRequests import org.wfanet.measurement.internal.kingdom.createEventGroupRequest as internalCreateEventGroupRequest import org.wfanet.measurement.internal.kingdom.deleteEventGroupRequest import org.wfanet.measurement.internal.kingdom.eventGroup as internalEventGroup -import org.wfanet.measurement.internal.kingdom.getEventGroupRequest +import org.wfanet.measurement.internal.kingdom.eventGroupKey +import org.wfanet.measurement.internal.kingdom.getEventGroupRequest as internalGetEventGroupRequest import org.wfanet.measurement.internal.kingdom.streamEventGroupsRequest import org.wfanet.measurement.internal.kingdom.updateEventGroupRequest -private const val MIN_PAGE_SIZE = 1 -private const val DEFAULT_PAGE_SIZE = 50 -private const val MAX_PAGE_SIZE = 1000 -private const val WILDCARD = ResourceKey.WILDCARD_ID -private val API_VERSION = Version.V2_ALPHA - -class EventGroupsService(private val internalEventGroupsStub: EventGroupsCoroutineStub) : +class EventGroupsService(private val internalEventGroupsStub: InternalEventGroupsCoroutineStub) : EventGroupsCoroutineImplBase() { override suspend fun getEventGroup(request: GetEventGroupRequest): EventGroup { - val key = - grpcRequireNotNull(EventGroupKey.fromName(request.name)) { + fun permissionDeniedStatus() = + Status.PERMISSION_DENIED.withDescription( + "Permission denied on resource ${request.name} (or it might not exist)" + ) + + val key: ChildResourceKey = + grpcRequireNotNull( + EventGroupKey.fromName(request.name) + ?: MeasurementConsumerEventGroupKey.fromName(request.name) + ) { "Resource name is either unspecified or invalid" } - - val principal: MeasurementPrincipal = principalFromCurrentContext - - when (principal) { - is DataProviderPrincipal -> { - if (principal.resourceKey.dataProviderId != key.dataProviderId) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot get EventGroups belonging to other DataProviders" + val principal = principalFromCurrentContext + + val internalRequest: InternalGetEventGroupRequest = + when (key) { + is EventGroupKey -> { + val denied = + when (principal) { + is DataProviderPrincipal -> principal.resourceKey != key.parentKey + is MeasurementConsumerPrincipal -> false + else -> true + } + if (denied) throw permissionDeniedStatus().asRuntimeException() + internalGetEventGroupRequest { + externalDataProviderId = ApiId(key.dataProviderId).externalId.value + externalEventGroupId = ApiId(key.eventGroupId).externalId.value } } + is MeasurementConsumerEventGroupKey -> { + if (key.parentKey != principal.resourceKey) { + throw permissionDeniedStatus().asRuntimeException() + } + internalGetEventGroupRequest { + externalMeasurementConsumerId = ApiId(key.measurementConsumerId).externalId.value + externalEventGroupId = ApiId(key.eventGroupId).externalId.value + } + } + else -> error("Unexpected resource key $key") } - is MeasurementConsumerPrincipal -> {} - else -> { - failGrpc(Status.PERMISSION_DENIED) { "Caller does not have permission to get EventGroups" } - } - } - - val getRequest = getEventGroupRequest { - externalDataProviderId = apiIdToExternalId(key.dataProviderId) - externalEventGroupId = apiIdToExternalId(key.eventGroupId) - } - - val eventGroup = + val internalResponse: InternalEventGroup = try { - internalEventGroupsStub.getEventGroup(getRequest).toEventGroup() - } catch (ex: StatusException) { - when (ex.status.code) { - Status.Code.NOT_FOUND -> failGrpc(Status.NOT_FOUND, ex) { "EventGroup not found." } - else -> failGrpc(Status.UNKNOWN, ex) { "Unknown exception." } - } + internalEventGroupsStub.getEventGroup(internalRequest) + } catch (e: StatusException) { + throw when (e.status.code) { + Status.Code.NOT_FOUND -> + if (key.parentKey == principal.resourceKey) { + Status.NOT_FOUND + } else { + permissionDeniedStatus() + } + Status.Code.DEADLINE_EXCEEDED -> Status.DEADLINE_EXCEEDED + else -> Status.UNKNOWN + } + .withCause(e) + .asRuntimeException() } when (principal) { + is DataProviderPrincipal -> { + if ( + ExternalId(internalResponse.externalDataProviderId) != + ApiId(principal.resourceKey.dataProviderId).externalId + ) { + throw permissionDeniedStatus().asRuntimeException() + } + } is MeasurementConsumerPrincipal -> { - if (eventGroup.measurementConsumer != principal.resourceKey.toName()) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot get EventGroups belonging to other MeasurementConsumers" - } + if ( + ExternalId(internalResponse.externalMeasurementConsumerId) != + ApiId(principal.resourceKey.measurementConsumerId).externalId + ) { + throw permissionDeniedStatus().asRuntimeException() } } - else -> {} + else -> throw permissionDeniedStatus().asRuntimeException() } - return eventGroup + return internalResponse.toEventGroup() } override suspend fun createEventGroup(request: CreateEventGroupRequest): EventGroup { @@ -233,22 +263,18 @@ class EventGroupsService(private val internalEventGroupsStub: EventGroupsCorouti } override suspend fun deleteEventGroup(request: DeleteEventGroupRequest): EventGroup { + fun permissionDeniedStatus() = + Status.PERMISSION_DENIED.withDescription( + "Permission denied on resource ${request.name} (or it might not exist)" + ) + val eventGroupKey = grpcRequireNotNull(EventGroupKey.fromName(request.name)) { - "EventGroup name is either unspecified or invalid" + "Resource name is either unspecified or invalid" } - when (val principal: MeasurementPrincipal = principalFromCurrentContext) { - is DataProviderPrincipal -> { - if (principal.resourceKey.dataProviderId != eventGroupKey.dataProviderId) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot delete EventGroups for another DataProvider" - } - } - } - else -> { - failGrpc(Status.PERMISSION_DENIED) { "Only a DataProvider can delete an EventGroup" } - } + if (principalFromCurrentContext.resourceKey != eventGroupKey.parentKey) { + throw permissionDeniedStatus().asRuntimeException() } val deleteRequest = deleteEventGroupRequest { @@ -271,46 +297,38 @@ class EventGroupsService(private val internalEventGroupsStub: EventGroupsCorouti } override suspend fun listEventGroups(request: ListEventGroupsRequest): ListEventGroupsResponse { - val listEventGroupsPageToken = request.toListEventGroupPageToken() + fun permissionDeniedStatus() = + Status.PERMISSION_DENIED.withDescription( + "Permission ListEventGroups denied on resource ${request.parent} (or it might not exist)" + ) - when (val principal: MeasurementPrincipal = principalFromCurrentContext) { - is DataProviderPrincipal -> { - if ( - apiIdToExternalId(principal.resourceKey.dataProviderId) != - listEventGroupsPageToken.externalDataProviderId - ) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot list EventGroups belonging to other DataProviders" - } - } - } - is MeasurementConsumerPrincipal -> { - val externalMeasurementConsumerId = - apiIdToExternalId(principal.resourceKey.measurementConsumerId) - if (listEventGroupsPageToken.externalMeasurementConsumerIdsList.isEmpty()) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot list Event Groups belonging to other MeasurementConsumers" - } - } + grpcRequire(request.pageSize >= 0) { "Page size cannot be less than 0" } - listEventGroupsPageToken.externalMeasurementConsumerIdsList.forEach { - if (it != externalMeasurementConsumerId) { - failGrpc(Status.PERMISSION_DENIED) { - "Cannot list Event Groups belonging to other MeasurementConsumers" - } - } - } - } - else -> { - failGrpc(Status.PERMISSION_DENIED) { "Caller does not have permission to list EventGroups" } - } + val parentKey: ResourceKey = + DataProviderKey.fromName(request.parent) + ?: MeasurementConsumerKey.fromName(request.parent) + ?: throw Status.INVALID_ARGUMENT.withDescription("parent unspecified or invalid") + .asRuntimeException() + if (parentKey != principalFromCurrentContext.resourceKey) { + throw permissionDeniedStatus().asRuntimeException() } - val results: List = + val pageToken: ListEventGroupsPageToken? = + if (request.pageToken.isEmpty()) null + else ListEventGroupsPageToken.parseFrom(request.pageToken.base64UrlDecode()) + val pageSize = + if (request.pageSize == 0) DEFAULT_PAGE_SIZE else request.pageSize.coerceAtMost(MAX_PAGE_SIZE) + val internalRequest = + buildInternalStreamEventGroupsRequest( + request.filter, + request.showDeleted, + parentKey, + pageSize, + pageToken + ) + val internalEventGroups: List = try { - internalEventGroupsStub - .streamEventGroups(listEventGroupsPageToken.toStreamEventGroupsRequest()) - .toList() + internalEventGroupsStub.streamEventGroups(internalRequest).toList() } catch (e: StatusException) { throw when (e.status.code) { Status.Code.DEADLINE_EXCEEDED -> Status.DEADLINE_EXCEEDED @@ -320,27 +338,115 @@ class EventGroupsService(private val internalEventGroupsStub: EventGroupsCorouti .asRuntimeException() } - if (results.isEmpty()) { + if (internalEventGroups.isEmpty()) { return ListEventGroupsResponse.getDefaultInstance() } return listEventGroupsResponse { eventGroups += - results - .subList(0, min(results.size, listEventGroupsPageToken.pageSize)) + internalEventGroups + .subList(0, min(internalEventGroups.size, pageSize)) .map(InternalEventGroup::toEventGroup) - if (results.size > listEventGroupsPageToken.pageSize) { - val pageToken = - listEventGroupsPageToken.copy { - lastEventGroup = previousPageEnd { - externalDataProviderId = results[results.lastIndex - 1].externalDataProviderId - externalEventGroupId = results[results.lastIndex - 1].externalEventGroupId + if (internalEventGroups.size > pageSize) { + nextPageToken = + buildNextPageToken(internalRequest.filter, internalEventGroups) + .toByteString() + .base64UrlEncode() + } + } + } + + private fun buildNextPageToken( + internalFilter: StreamEventGroupsRequest.Filter, + results: List, + ): ListEventGroupsPageToken { + return listEventGroupsPageToken { + if (internalFilter.externalDataProviderId != 0L) { + externalDataProviderId = internalFilter.externalDataProviderId + } + if (internalFilter.externalMeasurementConsumerId != 0L) { + externalMeasurementConsumerId = internalFilter.externalMeasurementConsumerId + } + externalDataProviderIds += internalFilter.externalDataProviderIdsList + externalMeasurementConsumerIds += internalFilter.externalMeasurementConsumerIdsList + lastEventGroup = previousPageEnd { + externalDataProviderId = results[results.lastIndex - 1].externalDataProviderId + externalEventGroupId = results[results.lastIndex - 1].externalEventGroupId + } + } + } + + /** + * Builds a [StreamEventGroupsRequest] for [listEventGroups]. + * + * @throws io.grpc.StatusRuntimeException if [request] is found to be invalid + */ + private fun buildInternalStreamEventGroupsRequest( + filter: ListEventGroupsRequest.Filter, + showDeleted: Boolean, + parentKey: ResourceKey, + pageSize: Int, + pageToken: ListEventGroupsPageToken?, + ): StreamEventGroupsRequest { + return streamEventGroupsRequest { + this.filter = + InternalStreamEventGroupsRequests.filter { + if (parentKey is DataProviderKey) { + externalDataProviderId = ApiId(parentKey.dataProviderId).externalId.value + } + if (parentKey is MeasurementConsumerKey) { + externalMeasurementConsumerId = ApiId(parentKey.measurementConsumerId).externalId.value + } + if (filter.measurementConsumersList.isNotEmpty()) { + externalMeasurementConsumerIds += + filter.measurementConsumersList.map { + val measurementConsumerKey = + grpcRequireNotNull(MeasurementConsumerKey.fromName(it)) { + "Invalid resource name in filter.measurement_consumers" + } + ApiId(measurementConsumerKey.measurementConsumerId).externalId.value + } + } + if (filter.dataProvidersList.isNotEmpty()) { + externalDataProviderIds += + filter.dataProvidersList.map { + val dataProviderKey = + grpcRequireNotNull(DataProviderKey.fromName(it)) { + "Invalid resource name in filter.data_providers" + } + ApiId(dataProviderKey.dataProviderId).externalId.value + } + } + if (showDeleted) { + this.showDeleted = showDeleted + } + if (pageToken != null) { + if ( + pageToken.externalDataProviderId != externalDataProviderId || + pageToken.externalMeasurementConsumerId != externalMeasurementConsumerId || + pageToken.showDeleted != showDeleted || + pageToken.externalDataProviderIdsList != externalDataProviderIds || + pageToken.externalMeasurementConsumerIdsList != externalMeasurementConsumerIds + ) { + throw Status.INVALID_ARGUMENT.withDescription( + "Arguments other than page_size must remain the same for subsequent page requests" + ) + .asRuntimeException() + } + after = eventGroupKey { + externalDataProviderId = pageToken.lastEventGroup.externalDataProviderId + externalEventGroupId = pageToken.lastEventGroup.externalEventGroupId } } - nextPageToken = pageToken.toByteArray().base64UrlEncode() - } + } + limit = pageSize + 1 } } + + companion object { + private const val DEFAULT_PAGE_SIZE = 50 + private const val MAX_PAGE_SIZE = 1000 + } } /** Converts an internal [InternalEventGroup] to a public [EventGroup]. */ @@ -403,7 +509,7 @@ private fun EventGroup.toInternal( providedEventGroupId = eventGroupReferenceId details = details { - apiVersion = API_VERSION.string + apiVersion = Version.V2_ALPHA.string measurementConsumerPublicKey = this@toInternal.measurementConsumerPublicKey.data measurementConsumerPublicKeySignature = this@toInternal.measurementConsumerPublicKey.signature vidModelLines += this@toInternal.vidModelLinesList @@ -416,86 +522,3 @@ private fun EventGroup.toInternal( } } } - -/** Converts a public [ListEventGroupsRequest] to an internal [ListEventGroupsPageToken]. */ -private fun ListEventGroupsRequest.toListEventGroupPageToken(): ListEventGroupsPageToken { - val source = this - - grpcRequire(source.pageSize >= 0) { "Page size cannot be less than 0" } - - val parentKey: DataProviderKey = - grpcRequireNotNull(DataProviderKey.fromName(source.parent)) { - "Parent is either unspecified or invalid" - } - - grpcRequire( - (source.filter.measurementConsumersCount > 0 && parentKey.dataProviderId == WILDCARD) || - parentKey.dataProviderId != WILDCARD - ) { - "Either parent data provider or measurement consumers filter must be provided" - } - - var externalDataProviderId = 0L - if (parentKey.dataProviderId != WILDCARD) { - externalDataProviderId = apiIdToExternalId(parentKey.dataProviderId) - } - - val externalMeasurementConsumerIdsList = - source.filter.measurementConsumersList.map { measurementConsumerName -> - grpcRequireNotNull(MeasurementConsumerKey.fromName(measurementConsumerName)) { - "Measurement consumer name in filter invalid" - } - .let { key -> apiIdToExternalId(key.measurementConsumerId) } - } - - return if (source.pageToken.isNotBlank()) { - ListEventGroupsPageToken.parseFrom(source.pageToken.base64UrlDecode()).copy { - grpcRequire(this.externalDataProviderId == externalDataProviderId) { - "Arguments must be kept the same when using a page token" - } - - grpcRequire( - externalMeasurementConsumerIdsList.containsAll(externalMeasurementConsumerIds) && - externalMeasurementConsumerIds.containsAll(externalMeasurementConsumerIdsList) - ) { - "Arguments must be kept the same when using a page token" - } - - if ( - source.pageSize != 0 && source.pageSize >= MIN_PAGE_SIZE && source.pageSize <= MAX_PAGE_SIZE - ) { - pageSize = source.pageSize - } - this.showDeleted = source.showDeleted - } - } else { - listEventGroupsPageToken { - pageSize = - when { - source.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE - source.pageSize > MAX_PAGE_SIZE -> MAX_PAGE_SIZE - else -> source.pageSize - } - - this.externalDataProviderId = externalDataProviderId - externalMeasurementConsumerIds += externalMeasurementConsumerIdsList - this.showDeleted = source.showDeleted - } - } -} - -/** Converts an internal [ListEventGroupsPageToken] to an internal [StreamEventGroupsRequest]. */ -private fun ListEventGroupsPageToken.toStreamEventGroupsRequest(): StreamEventGroupsRequest { - val source = this - return streamEventGroupsRequest { - // get 1 more than the actual page size for deciding whether or not to set page token - limit = source.pageSize + 1 - filter = filter { - externalDataProviderId = source.externalDataProviderId - externalMeasurementConsumerIds += source.externalMeasurementConsumerIdsList - externalDataProviderIdAfter = source.lastEventGroup.externalDataProviderId - externalEventGroupIdAfter = source.lastEventGroup.externalEventGroupId - showDeleted = source.showDeleted - } - } -} diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/EventGroupsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/EventGroupsServiceTest.kt index bc358f6d42e..b4eb6fb9824 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/EventGroupsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/EventGroupsServiceTest.kt @@ -45,6 +45,7 @@ import org.wfanet.measurement.internal.kingdom.copy import org.wfanet.measurement.internal.kingdom.createEventGroupRequest import org.wfanet.measurement.internal.kingdom.deleteEventGroupRequest import org.wfanet.measurement.internal.kingdom.eventGroup +import org.wfanet.measurement.internal.kingdom.eventGroupKey import org.wfanet.measurement.internal.kingdom.getEventGroupRequest import org.wfanet.measurement.internal.kingdom.streamEventGroupsRequest import org.wfanet.measurement.internal.kingdom.updateEventGroupRequest @@ -91,7 +92,7 @@ abstract class EventGroupsServiceTest { } @Test - fun `getEventGroup fails for missing EventGroup`() = runBlocking { + fun `getEventGroup throws INVALID_ARGUMENT when parent ID omitted`() = runBlocking { val exception = assertFailsWith { eventGroupsService.getEventGroup( @@ -99,8 +100,30 @@ abstract class EventGroupsServiceTest { ) } + assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) + assertThat(exception).hasMessageThat().contains("external_parent_id") + } + + @Test + fun `getEventGroup throws NOT_FOUND when EventGroup not found`(): Unit = runBlocking { + val exception = + assertFailsWith { + eventGroupsService.getEventGroup( + getEventGroupRequest { + externalDataProviderId = 404L + externalEventGroupId = EXTERNAL_EVENT_GROUP_ID + } + ) + } + assertThat(exception.status.code).isEqualTo(Status.Code.NOT_FOUND) - assertThat(exception).hasMessageThat().contains("NOT_FOUND: EventGroup not found") + assertThat(exception.errorInfo?.metadataMap) + .containsAtLeast( + "external_data_provider_id", + "404", + "external_event_group_id", + EXTERNAL_EVENT_GROUP_ID.toString() + ) } @Test @@ -224,37 +247,36 @@ abstract class EventGroupsServiceTest { } @Test - fun `createEventGroup succeeds`() = runBlocking { + fun `createEventGroup returns created EventGroup`() = runBlocking { val measurementConsumer = population.createMeasurementConsumer(measurementConsumersService, accountsService) val externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId val externalCertificateId = measurementConsumer.certificate.externalCertificateId - val externalDataProviderId = population.createDataProvider(dataProvidersService).externalDataProviderId - - val eventGroup = eventGroup { - this.externalDataProviderId = externalDataProviderId - this.externalMeasurementConsumerId = externalMeasurementConsumerId - providedEventGroupId = PROVIDED_EVENT_GROUP_ID - externalMeasurementConsumerCertificateId = externalCertificateId - details = DETAILS + val request = createEventGroupRequest { + this.eventGroup = eventGroup { + this.externalDataProviderId = externalDataProviderId + this.externalMeasurementConsumerId = externalMeasurementConsumerId + providedEventGroupId = PROVIDED_EVENT_GROUP_ID + externalMeasurementConsumerCertificateId = externalCertificateId + details = DETAILS + } } - val createdEventGroup = - eventGroupsService.createEventGroup(createEventGroupRequest { this.eventGroup = eventGroup }) + val response: EventGroup = eventGroupsService.createEventGroup(request) - assertThat(createdEventGroup) + assertThat(response) .ignoringFields( EventGroup.EXTERNAL_EVENT_GROUP_ID_FIELD_NUMBER, EventGroup.CREATE_TIME_FIELD_NUMBER, EventGroup.UPDATE_TIME_FIELD_NUMBER, ) - .isEqualTo(eventGroup.copy { this.state = EventGroup.State.ACTIVE }) - assertThat(createdEventGroup.externalEventGroupId).isGreaterThan(0) - assertThat(createdEventGroup.createTime.seconds).isGreaterThan(0) - assertThat(createdEventGroup.updateTime).isEqualTo(createdEventGroup.createTime) - assertThat(createdEventGroup.state).isEqualTo(EventGroup.State.ACTIVE) + .isEqualTo(request.eventGroup.copy { this.state = EventGroup.State.ACTIVE }) + assertThat(response.externalEventGroupId).isNotEqualTo(0) + assertThat(response.createTime.seconds).isGreaterThan(0) + assertThat(response.updateTime).isEqualTo(response.createTime) + assertThat(response.state).isEqualTo(EventGroup.State.ACTIVE) } @Test @@ -531,63 +553,61 @@ abstract class EventGroupsServiceTest { } @Test - fun `getEventGroup succeeds`() = runBlocking { + fun `getEventGroup returns EventGroup by DataProvider`() = runBlocking { val externalMeasurementConsumerId = population .createMeasurementConsumer(measurementConsumersService, accountsService) .externalMeasurementConsumerId - val externalDataProviderId = population.createDataProvider(dataProvidersService).externalDataProviderId + val eventGroup = + eventGroupsService.createEventGroup( + createEventGroupRequest { + eventGroup = eventGroup { + this.externalDataProviderId = externalDataProviderId + this.externalMeasurementConsumerId = externalMeasurementConsumerId + } + } + ) - val eventGroup = eventGroup { - this.externalDataProviderId = externalDataProviderId - this.externalMeasurementConsumerId = externalMeasurementConsumerId - } - - val createdEventGroup = - eventGroupsService.createEventGroup(createEventGroupRequest { this.eventGroup = eventGroup }) - - val eventGroupRead = + val response = eventGroupsService.getEventGroup( - GetEventGroupRequest.newBuilder() - .also { - it.externalDataProviderId = externalDataProviderId - it.externalEventGroupId = createdEventGroup.externalEventGroupId - } - .build() + getEventGroupRequest { + this.externalDataProviderId = externalDataProviderId + externalEventGroupId = eventGroup.externalEventGroupId + } ) - assertThat(eventGroupRead).isEqualTo(createdEventGroup) + assertThat(response).isEqualTo(eventGroup) } @Test - fun `getEventGroup succeeds when certificate id is set`() = runBlocking { - val measurementConsumer = - population.createMeasurementConsumer(measurementConsumersService, accountsService) - + fun `getEventGroup returns EventGroup by MeasurementConsumer`() = runBlocking { + val externalMeasurementConsumerId = + population + .createMeasurementConsumer(measurementConsumersService, accountsService) + .externalMeasurementConsumerId val externalDataProviderId = population.createDataProvider(dataProvidersService).externalDataProviderId + val eventGroup = + eventGroupsService.createEventGroup( + createEventGroupRequest { + eventGroup = eventGroup { + this.externalDataProviderId = externalDataProviderId + this.externalMeasurementConsumerId = externalMeasurementConsumerId + } + } + ) - val eventGroup = eventGroup { - this.externalDataProviderId = externalDataProviderId - this.externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId - this.externalMeasurementConsumerCertificateId = - measurementConsumer.certificate.externalCertificateId - } - - val createdEventGroup = - eventGroupsService.createEventGroup(createEventGroupRequest { this.eventGroup = eventGroup }) - - val retrievedEventGroup = + val response: EventGroup = eventGroupsService.getEventGroup( getEventGroupRequest { - this.externalDataProviderId = externalDataProviderId - this.externalEventGroupId = createdEventGroup.externalEventGroupId + this.externalMeasurementConsumerId = externalMeasurementConsumerId + externalEventGroupId = eventGroup.externalEventGroupId } ) - assertThat(retrievedEventGroup).isEqualTo(createdEventGroup) + assertThat(response).isEqualTo(eventGroup) } @Test @@ -627,7 +647,7 @@ abstract class EventGroupsServiceTest { eventGroupsService .streamEventGroups( streamEventGroupsRequest { - filter = filter { this.externalDataProviderId = externalDataProviderId } + filter = filter { externalDataProviderIds += externalDataProviderId } } ) .toList() @@ -682,7 +702,7 @@ abstract class EventGroupsServiceTest { eventGroupsService .streamEventGroups( streamEventGroupsRequest { - filter = filter { this.externalDataProviderId = externalDataProviderId } + filter = filter { externalDataProviderIds += externalDataProviderId } limit = 1 } ) @@ -697,8 +717,10 @@ abstract class EventGroupsServiceTest { streamEventGroupsRequest { filter = filter { this.externalDataProviderId = externalDataProviderId - externalEventGroupIdAfter = eventGroups[0].externalEventGroupId - externalDataProviderIdAfter = eventGroups[0].externalDataProviderId + after = eventGroupKey { + this.externalDataProviderId = eventGroups[0].externalDataProviderId + externalEventGroupId = eventGroups[0].externalEventGroupId + } } limit = 1 } @@ -803,7 +825,7 @@ abstract class EventGroupsServiceTest { } @Test - fun `deleteEventGroup fails for missing data provider`() = runBlocking { + fun `deleteEventGroup throws NOT_FOUND for missing DataProvider`() = runBlocking { val externalMeasurementConsumerId = population .createMeasurementConsumer(measurementConsumersService, accountsService) @@ -833,18 +855,16 @@ abstract class EventGroupsServiceTest { } assertThat(exception.status.code).isEqualTo(Status.Code.NOT_FOUND) - assertThat(exception).hasMessageThat().contains("EventGroup not found") + assertThat(exception).hasMessageThat().contains("EventGroup") } @Test - fun `deleteEventGroup succeeds`() = runBlocking { + fun `deleteEventGroup transitions EventGroup to DELETED state`() = runBlocking { val measurementConsumer = population.createMeasurementConsumer(measurementConsumersService, accountsService) - val externalDataProviderId = population.createDataProvider(dataProvidersService).externalDataProviderId - - val createdEventGroup = + val eventGroup = eventGroupsService.createEventGroup( createEventGroupRequest { eventGroup = eventGroup { @@ -856,23 +876,32 @@ abstract class EventGroupsServiceTest { } ) - val deletedEventGroup = + val response: EventGroup = eventGroupsService.deleteEventGroup( deleteEventGroupRequest { this.externalDataProviderId = externalDataProviderId - this.externalEventGroupId = createdEventGroup.externalEventGroupId + this.externalEventGroupId = eventGroup.externalEventGroupId } ) - assertThat(deletedEventGroup) + assertThat(response) .isEqualTo( - createdEventGroup.copy { + eventGroup.copy { this.externalMeasurementConsumerCertificateId = 0L - this.updateTime = deletedEventGroup.updateTime + this.updateTime = response.updateTime this.details = EventGroup.Details.getDefaultInstance() this.state = EventGroup.State.DELETED } ) + assertThat(response) + .isEqualTo( + eventGroupsService.getEventGroup( + getEventGroupRequest { + this.externalDataProviderId = externalDataProviderId + externalEventGroupId = eventGroup.externalEventGroupId + } + ) + ) } @Test diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/frontend/FrontendSimulator.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/frontend/FrontendSimulator.kt index 0788287e7cd..78234fb50bb 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/frontend/FrontendSimulator.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/frontend/FrontendSimulator.kt @@ -43,7 +43,6 @@ import org.wfanet.measurement.api.v2alpha.EventGroup import org.wfanet.measurement.api.v2alpha.EventGroupKey import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub import org.wfanet.measurement.api.v2alpha.GetDataProviderRequest -import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequestKt import org.wfanet.measurement.api.v2alpha.ListRequisitionsRequestKt import org.wfanet.measurement.api.v2alpha.Measurement import org.wfanet.measurement.api.v2alpha.Measurement.DataProviderEntry @@ -616,10 +615,7 @@ class FrontendSimulator( } private suspend fun listEventGroups(measurementConsumer: String): List { - val request = listEventGroupsRequest { - parent = DATA_PROVIDER_WILDCARD - filter = ListEventGroupsRequestKt.filter { measurementConsumers += measurementConsumer } - } + val request = listEventGroupsRequest { parent = measurementConsumer } try { return eventGroupsClient .withAuthenticationKey(measurementConsumerData.apiAuthenticationKey) diff --git a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsService.kt b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsService.kt index 433bc8bfe9f..b86ba46a1a2 100644 --- a/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsService.kt @@ -26,9 +26,10 @@ import org.wfanet.measurement.api.v2alpha.EventGroup as CmmsEventGroup import org.wfanet.measurement.api.v2alpha.EventGroupKey as CmmsEventGroupKey import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequestKt.filter -import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey +import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey as CmmsMeasurementConsumerKey import org.wfanet.measurement.api.v2alpha.listEventGroupsRequest as cmmsListEventGroupsRequest import org.wfanet.measurement.api.withAuthenticationKey +import org.wfanet.measurement.common.api.ResourceKey import org.wfanet.measurement.common.crypto.PrivateKeyHandle import org.wfanet.measurement.common.grpc.failGrpc import org.wfanet.measurement.consent.client.measurementconsumer.decryptMetadata @@ -68,8 +69,6 @@ class EventGroupsService( val parentKey = EventGroupParentKey.fromName(request.parent) ?: failGrpc(Status.INVALID_ARGUMENT) { "parent malformed or unspecified" } - val dataProviderKey = CmmsDataProviderKey(parentKey.dataProviderReferenceId) - val dataProviderName = dataProviderKey.toName() val pageSize = when { request.pageSize < MIN_PAGE_SIZE -> DEFAULT_PAGE_SIZE @@ -83,10 +82,14 @@ class EventGroupsService( .withAuthenticationKey(apiAuthenticationKey) .listEventGroups( cmmsListEventGroupsRequest { - parent = dataProviderName + parent = CmmsMeasurementConsumerKey(parentKey.measurementConsumerId).toName() this.pageSize = pageSize pageToken = request.pageToken - filter = filter { measurementConsumers += principalName } + filter = filter { + if (parentKey.dataProviderReferenceId != ResourceKey.WILDCARD_ID) { + dataProviders += CmmsDataProviderKey(parentKey.dataProviderReferenceId).toName() + } + } } ) } catch (e: StatusException) { @@ -112,25 +115,20 @@ class EventGroupsService( it.toEventGroup(cmmsMetadata) } - val filter: String = request.filter - if (filter.isEmpty()) { - return listEventGroupsResponse { - this.eventGroups += eventGroups - nextPageToken = cmmsListEventGroupResponse.nextPageToken - } - } - - val filteredEventGroups = filterEventGroups(eventGroups, filter) return listEventGroupsResponse { - this.eventGroups += filteredEventGroups + this.eventGroups += filterEventGroups(eventGroups, request.filter) nextPageToken = cmmsListEventGroupResponse.nextPageToken } } private suspend fun filterEventGroups( - eventGroups: Iterable, + eventGroups: List, filter: String, ): List { + if (filter.isEmpty()) { + return eventGroups + } + val typeRegistryAndEnv = celEnvProvider.getTypeRegistryAndEnv() val env = typeRegistryAndEnv.env val typeRegistry = typeRegistryAndEnv.typeRegistry @@ -215,7 +213,8 @@ class EventGroupsService( private fun CmmsEventGroup.toEventGroup(cmmsMetadata: CmmsEventGroup.Metadata?): EventGroup { val source = this val cmmsEventGroupKey = requireNotNull(CmmsEventGroupKey.fromName(name)) - val measurementConsumerKey = requireNotNull(MeasurementConsumerKey.fromName(measurementConsumer)) + val measurementConsumerKey = + requireNotNull(CmmsMeasurementConsumerKey.fromName(measurementConsumer)) return eventGroup { name = EventGroupKey( diff --git a/src/main/proto/wfa/measurement/api/v2alpha/page_token.proto b/src/main/proto/wfa/measurement/api/v2alpha/page_token.proto index e11aabd6664..395e81d6cc7 100644 --- a/src/main/proto/wfa/measurement/api/v2alpha/page_token.proto +++ b/src/main/proto/wfa/measurement/api/v2alpha/page_token.proto @@ -26,15 +26,21 @@ option java_package = "org.wfanet.measurement.api.v2alpha"; option java_multiple_files = true; message ListEventGroupsPageToken { - int32 page_size = 1; - fixed64 external_data_provider_id = 2; + reserved 1; + + oneof external_parent_id { + fixed64 external_data_provider_id = 2; + fixed64 external_measurement_consumer_id = 6; + } + repeated fixed64 external_data_provider_ids = 7; repeated fixed64 external_measurement_consumer_ids = 3; + bool show_deleted = 5; + message PreviousPageEnd { fixed64 external_data_provider_id = 1; fixed64 external_event_group_id = 2; } PreviousPageEnd last_event_group = 4; - bool show_deleted = 5; } message ListEventGroupMetadataDescriptorsPageToken { diff --git a/src/main/proto/wfa/measurement/internal/kingdom/event_groups_service.proto b/src/main/proto/wfa/measurement/internal/kingdom/event_groups_service.proto index 748b9eebf86..10170502442 100644 --- a/src/main/proto/wfa/measurement/internal/kingdom/event_groups_service.proto +++ b/src/main/proto/wfa/measurement/internal/kingdom/event_groups_service.proto @@ -45,12 +45,17 @@ message UpdateEventGroupRequest { } message GetEventGroupRequest { - fixed64 external_data_provider_id = 1; + // External ID of the parent. Required. + oneof external_parent_id { + fixed64 external_data_provider_id = 1; + fixed64 external_measurement_consumer_id = 3; + } + // External ID of the EventGroup. Required. fixed64 external_event_group_id = 2; } message DeleteEventGroupRequest { - // The external id of `EventGroup Data Provider`. Required. + // External ID of the parent `DataProvider`. Required. fixed64 external_data_provider_id = 1; // The external id of `EventGroup`. Required. fixed64 external_event_group_id = 2; @@ -60,12 +65,16 @@ message StreamEventGroupsRequest { // Filter criteria as a conjunction of specified fields. Repeated fields are // disjunctions of their items. message Filter { - int64 external_data_provider_id = 1; - repeated int64 external_measurement_consumer_ids = 2; - // for next page token, both after fields need to be set - fixed64 external_data_provider_id_after = 3; - fixed64 external_event_group_id_after = 4; + reserved 3, 4; + + fixed64 external_data_provider_id = 1; + fixed64 external_measurement_consumer_id = 6; + + repeated fixed64 external_data_provider_ids = 7; + repeated fixed64 external_measurement_consumer_ids = 2; bool show_deleted = 5; + + EventGroupKey after = 8; } Filter filter = 1; @@ -73,3 +82,8 @@ message StreamEventGroupsRequest { // unlimited. int32 limit = 2; } + +message EventGroupKey { + fixed64 external_data_provider_id = 1; + fixed64 external_event_group_id = 2; +} diff --git a/src/main/resources/kingdom/spanner/add-event-groups-by-mc.sql b/src/main/resources/kingdom/spanner/add-event-groups-by-mc.sql new file mode 100644 index 00000000000..dd6b9770655 --- /dev/null +++ b/src/main/resources/kingdom/spanner/add-event-groups-by-mc.sql @@ -0,0 +1,21 @@ +-- liquibase formatted sql + +-- Copyright 2023 The Cross-Media Measurement Authors +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- changeset sanjayvas:11 dbms:cloudspanner +-- comment: Add index to enforce uniqueness of ExternalEventGroupId by MeasurementConsumer. + +CREATE UNIQUE INDEX EventGroupsByMeasurementConsumer + ON EventGroups (MeasurementConsumerId, ExternalEventGroupId); diff --git a/src/main/resources/kingdom/spanner/changelog.yaml b/src/main/resources/kingdom/spanner/changelog.yaml index 6d53a332a85..003dfcac09b 100644 --- a/src/main/resources/kingdom/spanner/changelog.yaml +++ b/src/main/resources/kingdom/spanner/changelog.yaml @@ -51,3 +51,6 @@ databaseChangeLog: - include: file: update-vid-model-foreign-keys-and-indexes.sql relativeToChangeLogFile: true +- include: + file: add-event-groups-by-mc.sql + relativeToChangeLogFile: true diff --git a/src/test/kotlin/org/wfanet/measurement/integration/k8s/EmptyClusterCorrectnessTest.kt b/src/test/kotlin/org/wfanet/measurement/integration/k8s/EmptyClusterCorrectnessTest.kt index 2237ec82525..4c5ac2c24e9 100644 --- a/src/test/kotlin/org/wfanet/measurement/integration/k8s/EmptyClusterCorrectnessTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/integration/k8s/EmptyClusterCorrectnessTest.kt @@ -506,7 +506,7 @@ class EmptyClusterCorrectnessTest : AbstractCorrectnessTest(measurementSystem) { withAuthenticationKey(apiKey) .listEventGroups( listEventGroupsRequest { - parent = "dataProviders/-" + parent = measurementConsumer filter = ListEventGroupsRequestKt.filter { measurementConsumers += measurementConsumer } } diff --git a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/EventGroupsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/EventGroupsServiceTest.kt index dbe652cd03e..26a844ec689 100644 --- a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/EventGroupsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/EventGroupsServiceTest.kt @@ -38,9 +38,11 @@ import org.wfanet.measurement.api.v2alpha.DataProviderKey import org.wfanet.measurement.api.v2alpha.EventGroup import org.wfanet.measurement.api.v2alpha.EventGroupKey import org.wfanet.measurement.api.v2alpha.EventGroupKt +import org.wfanet.measurement.api.v2alpha.ListEventGroupsPageToken import org.wfanet.measurement.api.v2alpha.ListEventGroupsPageTokenKt.previousPageEnd import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequest import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequestKt.filter +import org.wfanet.measurement.api.v2alpha.ListEventGroupsResponse import org.wfanet.measurement.api.v2alpha.MeasurementConsumerCertificateKey import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey import org.wfanet.measurement.api.v2alpha.copy @@ -57,6 +59,7 @@ import org.wfanet.measurement.api.v2alpha.updateEventGroupRequest import org.wfanet.measurement.api.v2alpha.withDataProviderPrincipal import org.wfanet.measurement.api.v2alpha.withMeasurementConsumerPrincipal import org.wfanet.measurement.api.v2alpha.withModelProviderPrincipal +import org.wfanet.measurement.common.base64UrlDecode import org.wfanet.measurement.common.base64UrlEncode import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule import org.wfanet.measurement.common.grpc.testing.mockService @@ -75,6 +78,7 @@ import org.wfanet.measurement.internal.kingdom.copy import org.wfanet.measurement.internal.kingdom.createEventGroupRequest as internalCreateEventGroupRequest import org.wfanet.measurement.internal.kingdom.deleteEventGroupRequest as internalDeleteEventGroupRequest import org.wfanet.measurement.internal.kingdom.eventGroup as internalEventGroup +import org.wfanet.measurement.internal.kingdom.eventGroupKey import org.wfanet.measurement.internal.kingdom.getEventGroupRequest as internalGetEventGroupRequest import org.wfanet.measurement.internal.kingdom.streamEventGroupsRequest import org.wfanet.measurement.internal.kingdom.updateEventGroupRequest as internalUpdateEventGroupRequest @@ -83,8 +87,6 @@ private val CREATE_TIME: Timestamp = Instant.ofEpochSecond(123).toProtoTime() private const val DEFAULT_LIMIT = 50 -private const val WILDCARD_NAME = "dataProviders/-" - private val DATA_PROVIDER_NAME = makeDataProvider(123L) private val DATA_PROVIDER_NAME_2 = makeDataProvider(124L) private val DATA_PROVIDER_EXTERNAL_ID = @@ -651,42 +653,123 @@ class EventGroupsServiceTest { } @Test - fun `listEventGroups with parent uses filter with parent`() { + fun `listEventGroups requests EventGroups by DataProvider`() { val request = listEventGroupsRequest { parent = DATA_PROVIDER_NAME - filter = filter { measurementConsumers += MEASUREMENT_CONSUMER_NAME } + pageSize = 100 } - val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) { + val response: ListEventGroupsResponse = + withDataProviderPrincipal(DATA_PROVIDER_NAME) { runBlocking { service.listEventGroups(request) } } - val expected = listEventGroupsResponse { - eventGroups += EVENT_GROUP - eventGroups += EVENT_GROUP.copy { name = EVENT_GROUP_NAME_2 } - eventGroups += EVENT_GROUP.copy { name = EVENT_GROUP_NAME_3 } + assertThat(response) + .isEqualTo( + listEventGroupsResponse { + eventGroups += EVENT_GROUP + eventGroups += EVENT_GROUP.copy { name = EVENT_GROUP_NAME_2 } + eventGroups += EVENT_GROUP.copy { name = EVENT_GROUP_NAME_3 } + } + ) + val internalRequest = + captureFirst { + verify(internalEventGroupsMock).streamEventGroups(capture()) + } + assertThat(internalRequest) + .isEqualTo( + streamEventGroupsRequest { + filter = + StreamEventGroupsRequestKt.filter { externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID } + limit = request.pageSize + 1 + } + ) + } + + @Test + fun `listEventGroups requests EventGroups by MeasurementConsumer`() { + val request = listEventGroupsRequest { + parent = MEASUREMENT_CONSUMER_NAME + pageSize = 100 } - val streamEventGroupsRequest = + val response: ListEventGroupsResponse = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) { + runBlocking { service.listEventGroups(request) } + } + + assertThat(response) + .isEqualTo( + listEventGroupsResponse { + eventGroups += EVENT_GROUP + eventGroups += EVENT_GROUP.copy { name = EVENT_GROUP_NAME_2 } + eventGroups += EVENT_GROUP.copy { name = EVENT_GROUP_NAME_3 } + } + ) + val internalRequest = captureFirst { verify(internalEventGroupsMock).streamEventGroups(capture()) } + assertThat(internalRequest) + .isEqualTo( + streamEventGroupsRequest { + filter = + StreamEventGroupsRequestKt.filter { + externalMeasurementConsumerId = MEASUREMENT_CONSUMER_EXTERNAL_ID + } + limit = request.pageSize + 1 + } + ) + } - assertThat(streamEventGroupsRequest) - .ignoringRepeatedFieldOrder() + @Test + fun `listEventGroups response includes next page token when there are more items`() { + val request = listEventGroupsRequest { + parent = DATA_PROVIDER_NAME + filter = filter { measurementConsumers += MEASUREMENT_CONSUMER_NAME } + pageSize = 2 + } + + val response: ListEventGroupsResponse = + withDataProviderPrincipal(DATA_PROVIDER_NAME) { + runBlocking { service.listEventGroups(request) } + } + + assertThat(response) + .ignoringFields(ListEventGroupsResponse.NEXT_PAGE_TOKEN_FIELD_NUMBER) + .isEqualTo( + listEventGroupsResponse { + eventGroups += EVENT_GROUP + eventGroups += EVENT_GROUP.copy { name = EVENT_GROUP_NAME_2 } + } + ) + val internalRequest = + captureFirst { + verify(internalEventGroupsMock).streamEventGroups(capture()) + } + assertThat(internalRequest) .isEqualTo( streamEventGroupsRequest { - limit = DEFAULT_LIMIT + 1 filter = StreamEventGroupsRequestKt.filter { externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID externalMeasurementConsumerIds += MEASUREMENT_CONSUMER_EXTERNAL_ID } + limit = request.pageSize + 1 + } + ) + val nextPageToken = ListEventGroupsPageToken.parseFrom(response.nextPageToken.base64UrlDecode()) + assertThat(nextPageToken) + .isEqualTo( + listEventGroupsPageToken { + externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID + externalMeasurementConsumerIds += MEASUREMENT_CONSUMER_EXTERNAL_ID + lastEventGroup = previousPageEnd { + externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID + externalEventGroupId = EVENT_GROUP_EXTERNAL_ID_2 + } } ) - - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) } @Test @@ -695,7 +778,6 @@ class EventGroupsServiceTest { parent = DATA_PROVIDER_NAME pageSize = 2 val listEventGroupsPageToken = listEventGroupsPageToken { - pageSize = 2 externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID externalMeasurementConsumerIds += MEASUREMENT_CONSUMER_EXTERNAL_ID lastEventGroup = previousPageEnd { @@ -708,7 +790,7 @@ class EventGroupsServiceTest { } val result = - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) { + withDataProviderPrincipal(DATA_PROVIDER_NAME) { runBlocking { service.listEventGroups(request) } } @@ -716,7 +798,6 @@ class EventGroupsServiceTest { eventGroups += EVENT_GROUP eventGroups += EVENT_GROUP.copy { name = EVENT_GROUP_NAME_2 } val listEventGroupsPageToken = listEventGroupsPageToken { - pageSize = request.pageSize externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID externalMeasurementConsumerIds += MEASUREMENT_CONSUMER_EXTERNAL_ID lastEventGroup = previousPageEnd { @@ -740,9 +821,11 @@ class EventGroupsServiceTest { filter = StreamEventGroupsRequestKt.filter { externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID - externalDataProviderIdAfter = DATA_PROVIDER_EXTERNAL_ID - externalEventGroupIdAfter = EVENT_GROUP_EXTERNAL_ID externalMeasurementConsumerIds += MEASUREMENT_CONSUMER_EXTERNAL_ID + after = eventGroupKey { + externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID + externalEventGroupId = EVENT_GROUP_EXTERNAL_ID + } } } ) @@ -751,51 +834,8 @@ class EventGroupsServiceTest { } @Test - fun `listEventGroups with new page size replaces page size in page token`() { - val request = listEventGroupsRequest { - parent = DATA_PROVIDER_NAME - pageSize = 4 - val listEventGroupsPageToken = listEventGroupsPageToken { - pageSize = 2 - externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID - lastEventGroup = previousPageEnd { - externalEventGroupId = EVENT_GROUP_EXTERNAL_ID - externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID - } - externalMeasurementConsumerIds += MEASUREMENT_CONSUMER_EXTERNAL_ID - } - filter = filter { measurementConsumers += MEASUREMENT_CONSUMER_NAME } - pageToken = listEventGroupsPageToken.toByteArray().base64UrlEncode() - } - - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) { - runBlocking { service.listEventGroups(request) } - } - - val streamEventGroupsRequest = - captureFirst { - verify(internalEventGroupsMock).streamEventGroups(capture()) - } - - assertThat(streamEventGroupsRequest) - .comparingExpectedFieldsOnly() - .isEqualTo(streamEventGroupsRequest { limit = request.pageSize + 1 }) - } - - @Test - fun `listEventGroups with no page size uses page size in page token`() { - val request = listEventGroupsRequest { - parent = DATA_PROVIDER_NAME - val listEventGroupsPageToken = listEventGroupsPageToken { - pageSize = 2 - externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID - lastEventGroup = previousPageEnd { - externalEventGroupId = EVENT_GROUP_EXTERNAL_ID - externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID - } - } - pageToken = listEventGroupsPageToken.toByteArray().base64UrlEncode() - } + fun `listEventGroups uses default page size when unspecified`() { + val request = listEventGroupsRequest { parent = DATA_PROVIDER_NAME } withDataProviderPrincipal(DATA_PROVIDER_NAME) { runBlocking { service.listEventGroups(request) } @@ -805,53 +845,39 @@ class EventGroupsServiceTest { captureFirst { verify(internalEventGroupsMock).streamEventGroups(capture()) } - assertThat(streamEventGroupsRequest) .comparingExpectedFieldsOnly() - .isEqualTo(streamEventGroupsRequest { limit = 3 }) + .isEqualTo(streamEventGroupsRequest { limit = 51 }) } @Test - fun `listEventGroups with parent and filter with measurement consumers uses filter with both`() { - val request = listEventGroupsRequest { - parent = DATA_PROVIDER_NAME - filter = filter { - measurementConsumers += MEASUREMENT_CONSUMER_NAME - measurementConsumers += MEASUREMENT_CONSUMER_NAME - } + fun `listEventGroups throws INVALID_ARGUMENT when subsequent request params mismatch page token`() { + val initialRequest = listEventGroupsRequest { + parent = MEASUREMENT_CONSUMER_NAME + filter = filter { dataProviders += DATA_PROVIDER_NAME } + pageSize = 2 } - - val result = + val initialResponse: ListEventGroupsResponse = withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) { - runBlocking { service.listEventGroups(request) } + runBlocking { service.listEventGroups(initialRequest) } } - val expected = listEventGroupsResponse { - eventGroups += EVENT_GROUP - eventGroups += EVENT_GROUP.copy { name = EVENT_GROUP_NAME_2 } - eventGroups += EVENT_GROUP.copy { name = EVENT_GROUP_NAME_3 } - } - - val streamEventGroupsRequest = - captureFirst { - verify(internalEventGroupsMock).streamEventGroups(capture()) - } - - assertThat(streamEventGroupsRequest) - .ignoringRepeatedFieldOrder() - .isEqualTo( - streamEventGroupsRequest { - limit = DEFAULT_LIMIT + 1 - filter = - StreamEventGroupsRequestKt.filter { - externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID - externalMeasurementConsumerIds += MEASUREMENT_CONSUMER_EXTERNAL_ID - externalMeasurementConsumerIds += MEASUREMENT_CONSUMER_EXTERNAL_ID - } + val exception = + assertFailsWith { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) { + runBlocking { + service.listEventGroups( + initialRequest.copy { + showDeleted = true + pageToken = initialResponse.nextPageToken + } + ) + } } - ) + } - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) + assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) + assertThat(exception).hasMessageThat().contains("page") } @Test @@ -867,7 +893,7 @@ class EventGroupsServiceTest { } @Test - fun `listEventGroups throws PERMISSION_DENIED when edp caller doesn't match`() { + fun `listEventGroups throws PERMISSION_DENIED when DataProvider principal mismatches`() { val request = listEventGroupsRequest { parent = DATA_PROVIDER_NAME } val exception = @@ -880,18 +906,12 @@ class EventGroupsServiceTest { } @Test - fun `listEventGroups throws PERMISSION_DENIED when mc caller doesn't match filter MC`() { - val request = listEventGroupsRequest { - parent = DATA_PROVIDER_NAME - filter = filter { - measurementConsumers += MEASUREMENT_CONSUMER_NAME - measurementConsumers += "measurementConsumers/BBBAAAAAAHt" - } - } + fun `listEventGroups throws PERMISSION_DENIED when MeasurementConsumer principal mismatches`() { + val request = listEventGroupsRequest { parent = MEASUREMENT_CONSUMER_NAME } val exception = assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) { + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME_2) { runBlocking { service.listEventGroups(request) } } } @@ -899,7 +919,7 @@ class EventGroupsServiceTest { } @Test - fun `listEventGroups throws PERMISSION_DENIED when mc caller and missing mc filter`() { + fun `listEventGroups throws PERMISSION_DENIED parent type mismatches`() { val request = listEventGroupsRequest { parent = DATA_PROVIDER_NAME } val exception = @@ -911,17 +931,6 @@ class EventGroupsServiceTest { assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) } - @Test - fun `listEventGroups throws INVALID_ARGUMENT when only wildcard parent`() { - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) { - runBlocking { service.listEventGroups(listEventGroupsRequest { parent = WILDCARD_NAME }) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - @Test fun `listRequisitions throws INVALID_ARGUMENT when parent is missing`() { val exception = @@ -937,7 +946,7 @@ class EventGroupsServiceTest { fun `listEventGroups throws INVALID_ARGUMENT when measurement consumer in filter is invalid`() { val exception = assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) { + withDataProviderPrincipal(DATA_PROVIDER_NAME) { runBlocking { service.listEventGroups( listEventGroupsRequest { @@ -968,55 +977,4 @@ class EventGroupsServiceTest { } assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) } - - @Test - fun `listEventGroups throws invalid argument when parent doesn't match parent in page token`() { - val request = listEventGroupsRequest { - parent = DATA_PROVIDER_NAME - pageSize = 2 - val listEventGroupsPageToken = listEventGroupsPageToken { - pageSize = 2 - externalDataProviderId = 654 - lastEventGroup = previousPageEnd { - externalEventGroupId = EVENT_GROUP_EXTERNAL_ID - externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID - } - } - pageToken = listEventGroupsPageToken.toByteArray().base64UrlEncode() - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) { - runBlocking { service.listEventGroups(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } - - @Test - fun `listEventGroups throws invalid argument when mc ids don't match ids in page token`() { - val request = listEventGroupsRequest { - parent = DATA_PROVIDER_NAME - pageSize = 2 - val listEventGroupsPageToken = listEventGroupsPageToken { - pageSize = 2 - externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID - externalMeasurementConsumerIds += 123 - lastEventGroup = previousPageEnd { - externalEventGroupId = EVENT_GROUP_EXTERNAL_ID - externalDataProviderId = DATA_PROVIDER_EXTERNAL_ID - } - } - pageToken = listEventGroupsPageToken.toByteArray().base64UrlEncode() - } - - val exception = - assertFailsWith { - withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) { - runBlocking { service.listEventGroups(request) } - } - } - assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) - } } diff --git a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsServiceTest.kt index b22a20f67bd..7fb5dfccb4c 100644 --- a/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/reporting/service/api/v1alpha/EventGroupsServiceTest.kt @@ -39,6 +39,7 @@ import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt.Ev import org.wfanet.measurement.api.v2alpha.EventGroupMetadataDescriptorsGrpcKt.EventGroupMetadataDescriptorsCoroutineStub import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineImplBase import org.wfanet.measurement.api.v2alpha.EventGroupsGrpcKt.EventGroupsCoroutineStub +import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequest as CmmsListEventGroupsRequest import org.wfanet.measurement.api.v2alpha.ListEventGroupsRequestKt import org.wfanet.measurement.api.v2alpha.MeasurementConsumerKey import org.wfanet.measurement.api.v2alpha.copy @@ -268,10 +269,10 @@ class EventGroupsServiceTest { ) val expectedCmmsEventGroupsRequest = cmmsListEventGroupsRequest { - parent = DATA_PROVIDER_NAME + parent = MEASUREMENT_CONSUMER_NAME pageSize = 10 pageToken = PAGE_TOKEN - filter = ListEventGroupsRequestKt.filter { measurementConsumers += MEASUREMENT_CONSUMER_NAME } + filter = ListEventGroupsRequestKt.filter { dataProviders += DATA_PROVIDER_NAME } } verifyProtoArgument(cmmsEventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups) @@ -302,16 +303,36 @@ class EventGroupsServiceTest { ) val expectedCmmsEventGroupsRequest = cmmsListEventGroupsRequest { - parent = DATA_PROVIDER_NAME + parent = MEASUREMENT_CONSUMER_NAME pageSize = DEFAULT_PAGE_SIZE pageToken = PAGE_TOKEN - filter = ListEventGroupsRequestKt.filter { measurementConsumers += MEASUREMENT_CONSUMER_NAME } + filter = ListEventGroupsRequestKt.filter { dataProviders += DATA_PROVIDER_NAME } } verifyProtoArgument(cmmsEventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups) .isEqualTo(expectedCmmsEventGroupsRequest) } + @Test + fun `listEventGroups omits DataProvider filter in CMMS request when ID is wildcard`() { + val request = listEventGroupsRequest { + parent = "measurementConsumers/$MEASUREMENT_CONSUMER_REFERENCE_ID/dataProviders/-" + } + + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME, CONFIG) { + runBlocking { service.listEventGroups(request) } + } + + verifyProtoArgument(cmmsEventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups) + .isEqualTo( + cmmsListEventGroupsRequest { + parent = MEASUREMENT_CONSUMER_NAME + pageSize = DEFAULT_PAGE_SIZE + filter = CmmsListEventGroupsRequest.Filter.getDefaultInstance() + } + ) + } + @Test fun `listEventGroups returns list with filter when event group with metadata and one without`() { runBlocking { @@ -340,10 +361,10 @@ class EventGroupsServiceTest { assertThat(result).isEqualTo(listEventGroupsResponse { eventGroups += EVENT_GROUP }) val expectedCmmsEventGroupsRequest = cmmsListEventGroupsRequest { - parent = DATA_PROVIDER_NAME + parent = MEASUREMENT_CONSUMER_NAME pageSize = DEFAULT_PAGE_SIZE pageToken = PAGE_TOKEN - filter = ListEventGroupsRequestKt.filter { measurementConsumers += MEASUREMENT_CONSUMER_NAME } + filter = ListEventGroupsRequestKt.filter { dataProviders += DATA_PROVIDER_NAME } } verifyProtoArgument(cmmsEventGroupsServiceMock, EventGroupsCoroutineImplBase::listEventGroups)