From 4bf41f4325f8f3185b8591b2f56ff5751e6888f7 Mon Sep 17 00:00:00 2001 From: Sanjay Vasandani Date: Fri, 15 Mar 2024 16:43:56 -0700 Subject: [PATCH] Update cross-media-measurement-api dep for DataProvider capabilities. --- MODULE.bazel | 10 +- MODULE.bazel.lock | 70 ++--- .../spanner/SpannerDataProvidersService.kt | 14 + .../ReplaceDataProviderCapabilities.kt | 52 ++++ .../kingdom/service/api/v2alpha/BUILD.bazel | 4 + .../api/v2alpha/DataProvidersService.kt | 56 ++++ .../api/v2alpha/MeasurementsService.kt | 256 ++++++++++++++++-- .../service/api/v2alpha/ProtoConversions.kt | 113 +------- .../testing/DataProvidersServiceTest.kt | 30 +- .../internal/kingdom/data_provider.proto | 5 + .../kingdom/data_providers_service.proto | 10 + .../api/v2alpha/DataProvidersServiceTest.kt | 82 +++++- .../api/v2alpha/MeasurementsServiceTest.kt | 121 ++++++++- 13 files changed, 636 insertions(+), 187 deletions(-) create mode 100644 src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/ReplaceDataProviderCapabilities.kt diff --git a/MODULE.bazel b/MODULE.bazel index b938ad057dd..d0bcc5cca27 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -136,7 +136,6 @@ bazel_dep( ) bazel_dep( name = "cross-media-measurement-api", - version = "0.60.0", repo_name = "wfa_measurement_proto", ) bazel_dep( @@ -330,3 +329,12 @@ single_version_override( module_name = "boringssl", version = BORINGSSL_VERSION, ) + +# TODO(world-federation-of-advertisers/cross-media-measurement-api#199): Use version. +archive_override( + module_name = "cross-media-measurement-api", + strip_prefix = "cross-media-measurement-api-8a1e491350d89333343dcca9f534cc68fbaf9736", + urls = [ + "https://github.com/world-federation-of-advertisers/cross-media-measurement-api/archive/8a1e491350d89333343dcca9f534cc68fbaf9736.tar.gz", + ], +) diff --git a/MODULE.bazel.lock b/MODULE.bazel.lock index 2d6913ff6db..dfd207d0e7f 100644 --- a/MODULE.bazel.lock +++ b/MODULE.bazel.lock @@ -1,6 +1,6 @@ { "lockFileVersion": 3, - "moduleFileHash": "8708b0e5573f4f48c1b66c6194d5342e976a77caa74e2c763582c858911e0e1f", + "moduleFileHash": "f114bb54e27cc01e7b97201238a100a34e080d6f10d0f9ea8c921fdf5bf88ee2", "flags": { "cmdRegistries": [ "https://raw.githubusercontent.com/world-federation-of-advertisers/bazel-registry/main", @@ -31,7 +31,7 @@ "usingModule": "", "location": { "file": "@@//:MODULE.bazel", - "line": 188, + "line": 187, "column": 22 }, "imports": { @@ -53,7 +53,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 189, + "line": 188, "column": 15 } }, @@ -70,7 +70,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 195, + "line": 194, "column": 15 } }, @@ -107,7 +107,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 204, + "line": 203, "column": 14 } } @@ -121,7 +121,7 @@ "usingModule": "", "location": { "file": "@@//:MODULE.bazel", - "line": 234, + "line": 233, "column": 20 }, "imports": { @@ -140,7 +140,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 235, + "line": 234, "column": 23 } } @@ -154,7 +154,7 @@ "usingModule": "", "location": { "file": "@@//:MODULE.bazel", - "line": 243, + "line": 242, "column": 24 }, "imports": { @@ -176,7 +176,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 244, + "line": 243, "column": 15 } }, @@ -190,7 +190,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 249, + "line": 248, "column": 15 } }, @@ -204,7 +204,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 254, + "line": 253, "column": 15 } }, @@ -218,7 +218,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 259, + "line": 258, "column": 15 } }, @@ -232,7 +232,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 264, + "line": 263, "column": 15 } } @@ -246,7 +246,7 @@ "usingModule": "", "location": { "file": "@@//:MODULE.bazel", - "line": 278, + "line": 277, "column": 23 }, "imports": {}, @@ -261,7 +261,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 279, + "line": 278, "column": 17 } } @@ -297,7 +297,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 288, + "line": 287, "column": 13 } }, @@ -312,7 +312,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 295, + "line": 294, "column": 13 } }, @@ -327,7 +327,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 302, + "line": 301, "column": 13 } }, @@ -342,7 +342,7 @@ "devDependency": false, "location": { "file": "@@//:MODULE.bazel", - "line": 309, + "line": 308, "column": 13 } } @@ -377,7 +377,7 @@ "wfa_rules_cue": "rules_cue@0.4.0", "wfa_common_jvm": "common-jvm@0.78.0", "wfa_common_cpp": "common-cpp@0.12.0", - "wfa_measurement_proto": "cross-media-measurement-api@0.60.0", + "wfa_measurement_proto": "cross-media-measurement-api@_", "wfa_consent_signaling_client": "consent-signaling-client@0.20.0", "any_sketch": "any-sketch@0.6.0", "any_sketch_java": "any-sketch-java@0.5.0", @@ -2566,10 +2566,10 @@ } } }, - "cross-media-measurement-api@0.60.0": { + "cross-media-measurement-api@_": { "name": "cross-media-measurement-api", - "version": "0.60.0", - "key": "cross-media-measurement-api@0.60.0", + "version": "", + "key": "cross-media-measurement-api@_", "repoName": "wfa_measurement_proto", "executionPlatformsToRegister": [], "toolchainsToRegister": [], @@ -2577,10 +2577,10 @@ { "extensionBzlFile": "@wfa_measurement_proto//build:non_module_deps.bzl", "extensionName": "non_module_deps", - "usingModule": "cross-media-measurement-api@0.60.0", + "usingModule": "cross-media-measurement-api@_", "location": { - "file": "https://raw.githubusercontent.com/world-federation-of-advertisers/bazel-registry/main/modules/cross-media-measurement-api/0.60.0/MODULE.bazel", - "line": 25, + "file": "@@cross-media-measurement-api~override//:MODULE.bazel", + "line": 24, "column": 32 }, "imports": { @@ -2598,22 +2598,6 @@ "com_google_googleapis": "googleapis@0.0.0-bzlmod.1", "bazel_tools": "bazel_tools@_", "local_config_platform": "local_config_platform@_" - }, - "repoSpec": { - "bzlFile": "@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "name": "cross-media-measurement-api~0.60.0", - "urls": [ - "https://github.com/world-federation-of-advertisers/cross-media-measurement-api/archive/refs/tags/v0.60.0.tar.gz" - ], - "integrity": "sha256-m9CFeMyEBWFe5lENTk7tT2e3trguJvfteLU+QcWjZQ4=", - "strip_prefix": "cross-media-measurement-api-0.60.0", - "remote_patches": { - "https://raw.githubusercontent.com/world-federation-of-advertisers/bazel-registry/main/modules/cross-media-measurement-api/0.60.0/patches/module_dot_bazel.patch": "sha256-/0FKi8ly/uVbEhju1Jt/h+4baPJNSCeyPvM+N2IoNeU=" - }, - "remote_patch_strip": 0 - } } }, "consent-signaling-client@0.20.0": { @@ -2626,7 +2610,7 @@ "extensionUsages": [], "deps": { "wfa_rules_kotlin_jvm": "rules_kotlin_jvm@0.2.0", - "wfa_measurement_proto": "cross-media-measurement-api@0.60.0", + "wfa_measurement_proto": "cross-media-measurement-api@_", "wfa_common_jvm": "common-jvm@0.78.0", "bazel_tools": "bazel_tools@_", "local_config_platform": "local_config_platform@_" diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/SpannerDataProvidersService.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/SpannerDataProvidersService.kt index 4dbfa949cfe..c5d7e7e8515 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/SpannerDataProvidersService.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/SpannerDataProvidersService.kt @@ -28,6 +28,7 @@ import org.wfanet.measurement.internal.kingdom.DataProvider import org.wfanet.measurement.internal.kingdom.DataProvidersGrpcKt.DataProvidersCoroutineImplBase import org.wfanet.measurement.internal.kingdom.GetDataProviderRequest import org.wfanet.measurement.internal.kingdom.ReplaceDataAvailabilityIntervalRequest +import org.wfanet.measurement.internal.kingdom.ReplaceDataProviderCapabilitiesRequest import org.wfanet.measurement.internal.kingdom.ReplaceDataProviderRequiredDuchiesRequest import org.wfanet.measurement.internal.kingdom.batchGetDataProvidersResponse import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.DataProviderNotFoundException @@ -35,6 +36,7 @@ import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.KingdomIntern import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.readers.DataProviderReader import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers.CreateDataProvider import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers.ReplaceDataAvailabilityInterval +import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers.ReplaceDataProviderCapabilities import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers.ReplaceDataProviderRequiredDuchies class SpannerDataProvidersService( @@ -99,4 +101,16 @@ class SpannerDataProvidersService( throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "DataProvider not found.") } } + + override suspend fun replaceDataProviderCapabilities( + request: ReplaceDataProviderCapabilitiesRequest + ): DataProvider { + grpcRequire(request.externalDataProviderId != 0L) { "external_data_provider_id is missing." } + + try { + return ReplaceDataProviderCapabilities(request).execute(client, idGenerator) + } catch (e: DataProviderNotFoundException) { + throw e.asStatusRuntimeException(Status.Code.NOT_FOUND, "DataProvider not found.") + } + } } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/ReplaceDataProviderCapabilities.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/ReplaceDataProviderCapabilities.kt new file mode 100644 index 00000000000..b3090e13725 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/ReplaceDataProviderCapabilities.kt @@ -0,0 +1,52 @@ +/* + * Copyright 2023 The Cross-Media Measurement Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.wfanet.measurement.kingdom.deploy.gcloud.spanner.writers + +import org.wfanet.measurement.common.identity.ExternalId +import org.wfanet.measurement.gcloud.spanner.bufferUpdateMutation +import org.wfanet.measurement.gcloud.spanner.set +import org.wfanet.measurement.gcloud.spanner.setJson +import org.wfanet.measurement.internal.kingdom.DataProvider +import org.wfanet.measurement.internal.kingdom.ReplaceDataProviderCapabilitiesRequest +import org.wfanet.measurement.internal.kingdom.copy +import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.DataProviderNotFoundException +import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.readers.DataProviderReader + +class ReplaceDataProviderCapabilities(private val request: ReplaceDataProviderCapabilitiesRequest) : + SpannerWriter() { + override suspend fun TransactionScope.runTransaction(): DataProvider { + val externalDataProviderId = ExternalId(request.externalDataProviderId) + val dataProviderResult: DataProviderReader.Result = + DataProviderReader().readByExternalDataProviderId(transactionContext, externalDataProviderId) + ?: throw DataProviderNotFoundException(externalDataProviderId) + + val updatedDetails: DataProvider.Details = + dataProviderResult.dataProvider.details.copy { capabilities = request.capabilities } + + transactionContext.bufferUpdateMutation("DataProviders") { + set("DataProviderId" to dataProviderResult.dataProviderId) + set("DataProviderDetails" to updatedDetails) + setJson("DataProviderDetailsJson" to updatedDetails) + } + + return dataProviderResult.dataProvider.copy { details = updatedDetails } + } + + override fun ResultScope.buildResult(): DataProvider { + return checkNotNull(transactionResult) + } +} diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/BUILD.bazel index 0465f91c73c..1b1f81b9f5c 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/BUILD.bazel @@ -247,6 +247,9 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/api:public_api_version", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:principal_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/api/v2alpha:resource_key", + "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:hmss_protocol_config", + "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:llv2_protocol_config", + "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:ro_llv2_protocol_config", "//src/main/proto/wfa/measurement/api/v2alpha:crypto_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:differential_privacy_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:measurement_kt_jvm_proto", @@ -254,6 +257,7 @@ kt_jvm_library( "//src/main/proto/wfa/measurement/api/v2alpha:measurements_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/api/v2alpha:page_token_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:protocol_config_kt_jvm_proto", + "//src/main/proto/wfa/measurement/internal/kingdom:data_providers_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/internal/kingdom:measurement_kt_jvm_proto", "//src/main/proto/wfa/measurement/internal/kingdom:measurements_service_kt_jvm_grpc_proto", "@wfa_common_jvm//imports/kotlin/kotlinx/coroutines:core", diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/DataProvidersService.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/DataProvidersService.kt index f75b2516bba..7a30045e1bf 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/DataProvidersService.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/DataProvidersService.kt @@ -22,6 +22,7 @@ import org.wfanet.measurement.api.Version import org.wfanet.measurement.api.v2alpha.DataProvider import org.wfanet.measurement.api.v2alpha.DataProviderCertificateKey import org.wfanet.measurement.api.v2alpha.DataProviderKey +import org.wfanet.measurement.api.v2alpha.DataProviderKt import org.wfanet.measurement.api.v2alpha.DataProviderPrincipal import org.wfanet.measurement.api.v2alpha.DataProvidersGrpcKt.DataProvidersCoroutineImplBase as DataProvidersCoroutineService import org.wfanet.measurement.api.v2alpha.DuchyKey @@ -31,6 +32,7 @@ import org.wfanet.measurement.api.v2alpha.MeasurementConsumerPrincipal import org.wfanet.measurement.api.v2alpha.MeasurementPrincipal import org.wfanet.measurement.api.v2alpha.ModelProviderPrincipal import org.wfanet.measurement.api.v2alpha.ReplaceDataAvailabilityIntervalRequest +import org.wfanet.measurement.api.v2alpha.ReplaceDataProviderCapabilitiesRequest import org.wfanet.measurement.api.v2alpha.ReplaceDataProviderRequiredDuchiesRequest import org.wfanet.measurement.api.v2alpha.dataProvider import org.wfanet.measurement.api.v2alpha.principalFromCurrentContext @@ -40,12 +42,15 @@ import org.wfanet.measurement.common.ProtoReflection import org.wfanet.measurement.common.grpc.failGrpc import org.wfanet.measurement.common.grpc.grpcRequire import org.wfanet.measurement.common.grpc.grpcRequireNotNull +import org.wfanet.measurement.common.identity.ApiId import org.wfanet.measurement.common.identity.apiIdToExternalId import org.wfanet.measurement.common.identity.externalIdToApiId import org.wfanet.measurement.internal.kingdom.DataProvider as InternalDataProvider +import org.wfanet.measurement.internal.kingdom.DataProviderKt as InternalDataProviderKt import org.wfanet.measurement.internal.kingdom.DataProvidersGrpcKt.DataProvidersCoroutineStub import org.wfanet.measurement.internal.kingdom.getDataProviderRequest import org.wfanet.measurement.internal.kingdom.replaceDataAvailabilityIntervalRequest +import org.wfanet.measurement.internal.kingdom.replaceDataProviderCapabilitiesRequest import org.wfanet.measurement.internal.kingdom.replaceDataProviderRequiredDuchiesRequest class DataProvidersService(private val internalClient: DataProvidersCoroutineStub) : @@ -177,6 +182,42 @@ class DataProvidersService(private val internalClient: DataProvidersCoroutineStu } return internalDataProvider.toDataProvider() } + + override suspend fun replaceDataProviderCapabilities( + request: ReplaceDataProviderCapabilitiesRequest + ): DataProvider { + val key: DataProviderKey = + grpcRequireNotNull(DataProviderKey.fromName(request.name)) { + "Resource name unspecified or invalid" + } + + val principal: MeasurementPrincipal = principalFromCurrentContext + if (principal.resourceKey != key) { + failGrpc(Status.PERMISSION_DENIED) { + "Permission for method replaceDataProviderCapabilities denied on resource $request.name" + } + } + + val response: InternalDataProvider = + try { + internalClient.replaceDataProviderCapabilities( + replaceDataProviderCapabilitiesRequest { + externalDataProviderId = ApiId(key.dataProviderId).externalId.value + capabilities = request.capabilities.toInternal() + } + ) + } catch (e: StatusException) { + throw when (e.status.code) { + Status.Code.DEADLINE_EXCEEDED -> Status.DEADLINE_EXCEEDED + Status.Code.CANCELLED -> Status.CANCELLED + Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription("DataProvider not found") + else -> Status.UNKNOWN + } + .withCause(e) + .asRuntimeException() + } + return response.toDataProvider() + } } private fun InternalDataProvider.toDataProvider(): DataProvider { @@ -204,5 +245,20 @@ private fun InternalDataProvider.toDataProvider(): DataProvider { } requiredDuchies += source.requiredExternalDuchyIdsList.map { DuchyKey(it).toName() } dataAvailabilityInterval = source.details.dataAvailabilityInterval + capabilities = source.details.capabilities.toCapabilities() + } +} + +private fun InternalDataProvider.Capabilities.toCapabilities(): DataProvider.Capabilities { + val source = this + return DataProviderKt.capabilities { + honestMajorityShareShuffleSupported = source.honestMajorityShareShuffleSupported + } +} + +private fun DataProvider.Capabilities.toInternal(): InternalDataProvider.Capabilities { + val source = this + return InternalDataProviderKt.capabilities { + honestMajorityShareShuffleSupported = source.honestMajorityShareShuffleSupported } } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsService.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsService.kt index 951994992c3..0364f9ff1f6 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsService.kt @@ -63,23 +63,34 @@ import org.wfanet.measurement.common.base64UrlEncode import org.wfanet.measurement.common.grpc.failGrpc import org.wfanet.measurement.common.grpc.grpcRequire import org.wfanet.measurement.common.grpc.grpcRequireNotNull +import org.wfanet.measurement.common.identity.ApiId +import org.wfanet.measurement.common.identity.ExternalId import org.wfanet.measurement.common.identity.apiIdToExternalId import org.wfanet.measurement.internal.kingdom.CreateMeasurementRequest as InternalCreateMeasurementRequest +import org.wfanet.measurement.internal.kingdom.DataProvider as InternalDataProvider +import org.wfanet.measurement.internal.kingdom.DataProvidersGrpcKt.DataProvidersCoroutineStub as InternalDataProvidersCoroutineStub import org.wfanet.measurement.internal.kingdom.Measurement as InternalMeasurement import org.wfanet.measurement.internal.kingdom.Measurement.DataProviderValue import org.wfanet.measurement.internal.kingdom.Measurement.View as InternalMeasurementView import org.wfanet.measurement.internal.kingdom.MeasurementKt.dataProviderValue -import org.wfanet.measurement.internal.kingdom.MeasurementsGrpcKt.MeasurementsCoroutineStub +import org.wfanet.measurement.internal.kingdom.MeasurementsGrpcKt.MeasurementsCoroutineStub as InternalMeasurementsCoroutineStub +import org.wfanet.measurement.internal.kingdom.ProtocolConfig as InternalProtocolConfig +import org.wfanet.measurement.internal.kingdom.ProtocolConfigKt import org.wfanet.measurement.internal.kingdom.StreamMeasurementsRequest import org.wfanet.measurement.internal.kingdom.StreamMeasurementsRequestKt import org.wfanet.measurement.internal.kingdom.StreamMeasurementsRequestKt.filter import org.wfanet.measurement.internal.kingdom.batchCreateMeasurementsRequest +import org.wfanet.measurement.internal.kingdom.batchGetDataProvidersRequest import org.wfanet.measurement.internal.kingdom.batchGetMeasurementsRequest import org.wfanet.measurement.internal.kingdom.cancelMeasurementRequest import org.wfanet.measurement.internal.kingdom.createMeasurementRequest as internalCreateMeasurementRequest import org.wfanet.measurement.internal.kingdom.getMeasurementRequest import org.wfanet.measurement.internal.kingdom.measurementKey +import org.wfanet.measurement.internal.kingdom.protocolConfig import org.wfanet.measurement.internal.kingdom.streamMeasurementsRequest +import org.wfanet.measurement.kingdom.deploy.common.HmssProtocolConfig +import org.wfanet.measurement.kingdom.deploy.common.Llv2ProtocolConfig +import org.wfanet.measurement.kingdom.deploy.common.RoLlv2ProtocolConfig private const val DEFAULT_PAGE_SIZE = 50 private const val MAX_PAGE_SIZE = 1000 @@ -88,9 +99,16 @@ private const val MAX_BATCH_SIZE = 50 private const val MISSING_RESOURCE_NAME_ERROR = "Resource name is either unspecified or invalid" class MeasurementsService( - private val internalMeasurementsStub: MeasurementsCoroutineStub, + private val internalMeasurementsStub: InternalMeasurementsCoroutineStub, + private val internalDataProvidersStub: InternalDataProvidersCoroutineStub, private val noiseMechanisms: List, private val reachOnlyLlV2Enabled: Boolean, + /** + * Whether Honest Majority Share Shuffle (HMSS) is enabled. + * + * TODO(@renjiezh): Set this based on feature flag. + */ + private val hmssEnabled: Boolean = false, ) : MeasurementsCoroutineImplBase() { override suspend fun getMeasurement(request: GetMeasurementRequest): Measurement { @@ -124,20 +142,53 @@ class MeasurementsService( } override suspend fun createMeasurement(request: CreateMeasurementRequest): Measurement { - val authenticatedMeasurementConsumerKey = getAuthenticatedMeasurementConsumerKey() - + val authenticatedPrincipal: MeasurementPrincipal = principalFromCurrentContext val parentKey = grpcRequireNotNull(MeasurementConsumerKey.fromName(request.parent)) { "parent is either unspecified or invalid" } - - if (parentKey != authenticatedMeasurementConsumerKey) { + if (parentKey != authenticatedPrincipal.resourceKey) { failGrpc(Status.PERMISSION_DENIED) { "Cannot create a Measurement for another MeasurementConsumer" } } - val internalRequest = request.buildInternalCreateMeasurementRequest(parentKey) + grpcRequire(request.measurement.dataProvidersList.isNotEmpty()) { + "measurement.data_providers is empty" + } + + val externalDataProviderIds: List = + request.measurement.dataProvidersList.map { + val key = + grpcRequireNotNull(DataProviderKey.fromName(it.key)) { + "DataProvider resource name unspecified or invalid" + } + ApiId(key.dataProviderId).externalId + } + val dataProviderCapabilities: List = + try { + internalDataProvidersStub.batchGetDataProviders( + batchGetDataProvidersRequest { + this.externalDataProviderIds += externalDataProviderIds.map { it.value } + } + ) + } catch (e: StatusException) { + throw when (e.status.code) { + Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription("DataProvider not found") + Status.Code.DEADLINE_EXCEEDED -> Status.DEADLINE_EXCEEDED + Status.Code.INTERNAL -> Status.INTERNAL + else -> Status.UNKNOWN + } + .withCause(e) + .asRuntimeException() + } + .dataProvidersList + .map { it.details.capabilities } + + // TODO(@SanjayVas): Check required capabilities once we have any. + + val internalRequest = + request.buildInternalCreateMeasurementRequest(dataProviderCapabilities, parentKey) val internalMeasurement = try { @@ -258,6 +309,41 @@ class MeasurementsService( failGrpc { "Number of elements in requests exceeds the maximum batch size." } } + val allExternalDataProviderIds: List = + request.requestsList + .flatMap { it.measurement.dataProvidersList } + .map { it.key } + .distinct() + .map { dataProviderName -> + val key = + grpcRequireNotNull(DataProviderKey.fromName(dataProviderName)) { + "DataProvider resource name unspecified or invalid" + } + ApiId(key.dataProviderId).externalId + } + val allDataProviderCapabilities: Map = + try { + internalDataProvidersStub.batchGetDataProviders( + batchGetDataProvidersRequest { + this.externalDataProviderIds += allExternalDataProviderIds.map { it.value } + } + ) + } catch (e: StatusException) { + throw when (e.status.code) { + Status.Code.NOT_FOUND -> Status.NOT_FOUND.withDescription("DataProvider not found") + Status.Code.DEADLINE_EXCEEDED -> Status.DEADLINE_EXCEEDED + Status.Code.INTERNAL -> Status.INTERNAL + else -> Status.UNKNOWN + } + .withCause(e) + .asRuntimeException() + } + .dataProvidersList + .associateBy { ExternalId(it.externalDataProviderId) } + .mapValues { it.value.details.capabilities } + + // TODO(@SanjayVas): Check required capabilities once we have any. + val internalCreateMeasurementRequests = mutableListOf() var isParentEmpty = false var isParentNotEmpty = false @@ -289,8 +375,15 @@ class MeasurementsService( } } + val externalDataProviderIds: Set = + createMeasurementRequest.measurement.dataProvidersList + .map { ApiId(DataProviderKey.fromName(it.key)!!.dataProviderId).externalId } + .toSet() val internalCreateMeasurementRequest = - createMeasurementRequest.buildInternalCreateMeasurementRequest(parentKey) + createMeasurementRequest.buildInternalCreateMeasurementRequest( + allDataProviderCapabilities.filterKeys { it in externalDataProviderIds }.values, + parentKey, + ) internalCreateMeasurementRequests.add(internalCreateMeasurementRequest) } @@ -381,8 +474,124 @@ class MeasurementsService( } } + private fun buildInternalProtocolConfig( + measurementSpec: MeasurementSpec, + dataProviderCapabilities: Collection, + ): InternalProtocolConfig { + val dataProvidersCount = dataProviderCapabilities.size + val internalNoiseMechanisms = noiseMechanisms.map { it.toInternal() } + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. + return when (measurementSpec.measurementTypeCase) { + MeasurementSpec.MeasurementTypeCase.REACH -> { + if (dataProvidersCount == 1) { + protocolConfig { + direct = + ProtocolConfigKt.direct { + this.noiseMechanisms += internalNoiseMechanisms + customDirectMethodology = + InternalProtocolConfig.Direct.CustomDirectMethodology.getDefaultInstance() + deterministicCountDistinct = + InternalProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + liquidLegionsCountDistinct = + InternalProtocolConfig.Direct.LiquidLegionsCountDistinct.getDefaultInstance() + } + } + } else { + if ( + hmssEnabled && dataProviderCapabilities.all { it.honestMajorityShareShuffleSupported } + ) { + protocolConfig { + externalProtocolConfigId = HmssProtocolConfig.name + honestMajorityShareShuffle = HmssProtocolConfig.protocolConfig + } + } else if (reachOnlyLlV2Enabled) { + protocolConfig { + externalProtocolConfigId = RoLlv2ProtocolConfig.name + reachOnlyLiquidLegionsV2 = RoLlv2ProtocolConfig.protocolConfig + } + } else { + protocolConfig { + externalProtocolConfigId = Llv2ProtocolConfig.name + liquidLegionsV2 = Llv2ProtocolConfig.protocolConfig + } + } + } + } + MeasurementSpec.MeasurementTypeCase.REACH_AND_FREQUENCY -> { + if (dataProvidersCount == 1) { + protocolConfig { + direct = + ProtocolConfigKt.direct { + this.noiseMechanisms += internalNoiseMechanisms + customDirectMethodology = + InternalProtocolConfig.Direct.CustomDirectMethodology.getDefaultInstance() + deterministicCountDistinct = + InternalProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + liquidLegionsCountDistinct = + InternalProtocolConfig.Direct.LiquidLegionsCountDistinct.getDefaultInstance() + deterministicDistribution = + InternalProtocolConfig.Direct.DeterministicDistribution.getDefaultInstance() + liquidLegionsDistribution = + InternalProtocolConfig.Direct.LiquidLegionsDistribution.getDefaultInstance() + } + } + } else { + if ( + hmssEnabled && dataProviderCapabilities.all { it.honestMajorityShareShuffleSupported } + ) { + protocolConfig { + externalProtocolConfigId = HmssProtocolConfig.name + honestMajorityShareShuffle = HmssProtocolConfig.protocolConfig + } + } else { + protocolConfig { + externalProtocolConfigId = Llv2ProtocolConfig.name + liquidLegionsV2 = Llv2ProtocolConfig.protocolConfig + } + } + } + } + MeasurementSpec.MeasurementTypeCase.IMPRESSION -> { + protocolConfig { + direct = + ProtocolConfigKt.direct { + this.noiseMechanisms += internalNoiseMechanisms + customDirectMethodology = + InternalProtocolConfig.Direct.CustomDirectMethodology.getDefaultInstance() + deterministicCount = + InternalProtocolConfig.Direct.DeterministicCount.getDefaultInstance() + } + } + } + MeasurementSpec.MeasurementTypeCase.DURATION -> { + protocolConfig { + direct = + ProtocolConfigKt.direct { + this.noiseMechanisms += internalNoiseMechanisms + customDirectMethodology = + InternalProtocolConfig.Direct.CustomDirectMethodology.getDefaultInstance() + deterministicSum = InternalProtocolConfig.Direct.DeterministicSum.getDefaultInstance() + } + } + } + MeasurementSpec.MeasurementTypeCase.POPULATION -> { + protocolConfig { + direct = + ProtocolConfigKt.direct { + this.noiseMechanisms += internalNoiseMechanisms + deterministicCount = + InternalProtocolConfig.Direct.DeterministicCount.getDefaultInstance() + } + } + } + MeasurementSpec.MeasurementTypeCase.MEASUREMENTTYPE_NOT_SET -> + error("MeasurementType not set.") + } + } + private fun CreateMeasurementRequest.buildInternalCreateMeasurementRequest( - parentKey: MeasurementConsumerKey + dataProviderCapabilities: Collection, + parentKey: MeasurementConsumerKey, ): InternalCreateMeasurementRequest { val measurementConsumerCertificateKey = grpcRequireNotNull( @@ -409,13 +618,14 @@ class MeasurementsService( measurementSpec.validate() grpcRequire(measurement.dataProvidersList.isNotEmpty()) { "Data Providers list is empty" } - val dataProvidersMap = mutableMapOf() - measurement.dataProvidersList.forEach { - with(it.validateAndMap()) { - grpcRequire(!dataProvidersMap.containsKey(key)) { - "Duplicated keys found in the data_providers." + val dataProviderValues: Map = buildMap { + for (dataProviderEntry in measurement.dataProvidersList) { + val mapEntry: Map.Entry = + dataProviderEntry.toValidatedInternalMapEntry() + grpcRequire(!containsKey(mapEntry.key)) { + "Duplicated keys found in measurement.data_providers." } - dataProvidersMap[key] = value + put(mapEntry.key, mapEntry.value) } } @@ -426,10 +636,8 @@ class MeasurementsService( val internalMeasurement = measurement.toInternal( measurementConsumerCertificateKey, - dataProvidersMap, - measurementSpec, - noiseMechanisms.map { it.toInternal() }, - reachOnlyLlV2Enabled, + dataProviderValues, + buildInternalProtocolConfig(measurementSpec, dataProviderCapabilities), ) val requestId = this.requestId @@ -510,11 +718,9 @@ private fun MeasurementSpec.validate() { } /** Validates a [DataProviderEntry] for a request and then creates a map entry from it. */ -private fun DataProviderEntry.validateAndMap(): Map.Entry { - val dataProviderKey = - grpcRequireNotNull(DataProviderKey.fromName(key)) { - "Data Provider resource name is either unspecified or invalid" - } +private fun DataProviderEntry.toValidatedInternalMapEntry(): + Map.Entry { + val dataProviderKey = checkNotNull(DataProviderKey.fromName(key)) val dataProviderCertificateKey = grpcRequireNotNull(DataProviderCertificateKey.fromName(value.dataProviderCertificate)) { @@ -546,7 +752,7 @@ private fun DataProviderEntry.validateAndMap(): Map.Entry.toDataProviderEntry(apiVersion: Version): * Converts a public [Measurement] to an internal [InternalMeasurement] for creation. * * @throws [IllegalStateException] if MeasurementType not specified - * - * TODO(@renjie): Enable HMSS protocol based on feature flag. */ fun Measurement.toInternal( measurementConsumerCertificateKey: MeasurementConsumerCertificateKey, - dataProvidersMap: Map, - measurementSpecProto: MeasurementSpec, - internalNoiseMechanisms: List, - reachOnlyLlV2Enabled: Boolean, + dataProviderValues: Map, + internalProtocolConfig: InternalProtocolConfig, ): InternalMeasurement { val source = this return internalMeasurement { @@ -1006,108 +1000,29 @@ fun Measurement.toInternal( apiIdToExternalId(measurementConsumerCertificateKey.measurementConsumerId) externalMeasurementConsumerCertificateId = apiIdToExternalId(measurementConsumerCertificateKey.certificateId) - dataProviders.putAll(dataProvidersMap) + dataProviders.putAll(dataProviderValues.mapKeys { it.key.value }) details = details { apiVersion = Version.V2_ALPHA.string measurementSpec = source.measurementSpec.message.value measurementSpecSignature = source.measurementSpec.signature measurementSpecSignatureAlgorithmOid = source.measurementSpec.signatureAlgorithmOid + protocolConfig = internalProtocolConfig @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Proto enum fields are never null. - when (measurementSpecProto.measurementTypeCase) { - MeasurementSpec.MeasurementTypeCase.REACH -> { - if (dataProvidersCount > 1) { - if (reachOnlyLlV2Enabled) { - protocolConfig = internalProtocolConfig { - externalProtocolConfigId = RoLlv2ProtocolConfig.name - reachOnlyLiquidLegionsV2 = RoLlv2ProtocolConfig.protocolConfig - } - duchyProtocolConfig = duchyProtocolConfig { - reachOnlyLiquidLegionsV2 = RoLlv2ProtocolConfig.duchyProtocolConfig - } - } else { - protocolConfig = internalProtocolConfig { - externalProtocolConfigId = Llv2ProtocolConfig.name - liquidLegionsV2 = Llv2ProtocolConfig.protocolConfig - } - duchyProtocolConfig = duchyProtocolConfig { - liquidLegionsV2 = Llv2ProtocolConfig.duchyProtocolConfig - } - } - } else if (dataProvidersCount == 1) { - protocolConfig = internalProtocolConfig { - direct = - InternalProtocolConfigKt.direct { - noiseMechanisms += internalNoiseMechanisms - customDirectMethodology = - InternalProtocolConfig.Direct.CustomDirectMethodology.getDefaultInstance() - deterministicCountDistinct = - InternalProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() - liquidLegionsCountDistinct = - InternalProtocolConfig.Direct.LiquidLegionsCountDistinct.getDefaultInstance() - } - } - } - } - MeasurementSpec.MeasurementTypeCase.REACH_AND_FREQUENCY -> { - if (dataProvidersCount > 1) { - protocolConfig = internalProtocolConfig { - externalProtocolConfigId = Llv2ProtocolConfig.name - liquidLegionsV2 = Llv2ProtocolConfig.protocolConfig - } - duchyProtocolConfig = duchyProtocolConfig { - liquidLegionsV2 = Llv2ProtocolConfig.duchyProtocolConfig - } - } else if (dataProvidersCount == 1) { - protocolConfig = internalProtocolConfig { - direct = - InternalProtocolConfigKt.direct { - noiseMechanisms += internalNoiseMechanisms - customDirectMethodology = - InternalProtocolConfig.Direct.CustomDirectMethodology.getDefaultInstance() - deterministicCountDistinct = - InternalProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() - liquidLegionsCountDistinct = - InternalProtocolConfig.Direct.LiquidLegionsCountDistinct.getDefaultInstance() - deterministicDistribution = - InternalProtocolConfig.Direct.DeterministicDistribution.getDefaultInstance() - liquidLegionsDistribution = - InternalProtocolConfig.Direct.LiquidLegionsDistribution.getDefaultInstance() - } - } - } - } - MeasurementSpec.MeasurementTypeCase.IMPRESSION -> { - protocolConfig = internalProtocolConfig { - direct = - InternalProtocolConfigKt.direct { - noiseMechanisms += internalNoiseMechanisms - customDirectMethodology = - InternalProtocolConfig.Direct.CustomDirectMethodology.getDefaultInstance() - deterministicCount = - InternalProtocolConfig.Direct.DeterministicCount.getDefaultInstance() - } - } - } - MeasurementSpec.MeasurementTypeCase.DURATION -> { - protocolConfig = internalProtocolConfig { - direct = - InternalProtocolConfigKt.direct { - noiseMechanisms += internalNoiseMechanisms - customDirectMethodology = - InternalProtocolConfig.Direct.CustomDirectMethodology.getDefaultInstance() - deterministicSum = - InternalProtocolConfig.Direct.DeterministicSum.getDefaultInstance() - } + when (protocolConfig.protocolCase) { + InternalProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2 -> { + duchyProtocolConfig = duchyProtocolConfig { + liquidLegionsV2 = Llv2ProtocolConfig.duchyProtocolConfig } } - MeasurementSpec.MeasurementTypeCase.POPULATION -> { - protocolConfig = internalProtocolConfig { - direct = InternalProtocolConfig.Direct.getDefaultInstance() + InternalProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> { + duchyProtocolConfig = duchyProtocolConfig { + reachOnlyLiquidLegionsV2 = RoLlv2ProtocolConfig.duchyProtocolConfig } } - MeasurementSpec.MeasurementTypeCase.MEASUREMENTTYPE_NOT_SET -> - error("MeasurementType not set.") + InternalProtocolConfig.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE, + InternalProtocolConfig.ProtocolCase.DIRECT -> {} + InternalProtocolConfig.ProtocolCase.PROTOCOL_NOT_SET -> error("protocol not set") } } } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/DataProvidersServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/DataProvidersServiceTest.kt index e35bfc78db1..e7bab231cbe 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/DataProvidersServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/DataProvidersServiceTest.kt @@ -44,6 +44,7 @@ import org.wfanet.measurement.internal.kingdom.copy import org.wfanet.measurement.internal.kingdom.dataProvider import org.wfanet.measurement.internal.kingdom.getDataProviderRequest import org.wfanet.measurement.internal.kingdom.replaceDataAvailabilityIntervalRequest +import org.wfanet.measurement.internal.kingdom.replaceDataProviderCapabilitiesRequest import org.wfanet.measurement.internal.kingdom.replaceDataProviderRequiredDuchiesRequest import org.wfanet.measurement.kingdom.deploy.common.testing.DuchyIdSetter import org.wfanet.measurement.kingdom.service.internal.testing.Population.Companion.DUCHIES @@ -93,7 +94,7 @@ abstract class DataProvidersServiceTest { } @Test - fun `createDataProvider succeeds`() = runBlocking { + fun `createDataProvider returns created DataProvider`() = runBlocking { val request = CREATE_DATA_PROVIDER_REQUEST val response: DataProvider = dataProvidersService.createDataProvider(request) @@ -113,7 +114,7 @@ abstract class DataProvidersServiceTest { fun `createDataProvider succeeds when requiredExternalDuchyIds is empty`() = runBlocking { val request = CREATE_DATA_PROVIDER_REQUEST.copy { requiredExternalDuchyIds.clear() } - val response = dataProvidersService.createDataProvider(request) + val response: DataProvider = dataProvidersService.createDataProvider(request) assertThat(response) .ignoringRepeatedFieldOrderOfFieldDescriptors(UNORDERED_FIELD_DESCRIPTORS) @@ -308,6 +309,31 @@ abstract class DataProvidersServiceTest { assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) } + @Test + fun `replaceDataProviderCapabilites updates DataProvider`() = runBlocking { + val dataProvider: DataProvider = + dataProvidersService.createDataProvider(CREATE_DATA_PROVIDER_REQUEST) + val capabilities = DataProviderKt.capabilities { honestMajorityShareShuffleSupported = true } + + val response: DataProvider = + dataProvidersService.replaceDataProviderCapabilities( + replaceDataProviderCapabilitiesRequest { + externalDataProviderId = dataProvider.externalDataProviderId + this.capabilities = capabilities + } + ) + + assertThat(response.details.capabilities).isEqualTo(capabilities) + // Ensure changes were persisted. + assertThat( + dataProvidersService.getDataProvider( + getDataProviderRequest { externalDataProviderId = dataProvider.externalDataProviderId } + ) + ) + .ignoringRepeatedFieldOrderOfFieldDescriptors(UNORDERED_FIELD_DESCRIPTORS) + .isEqualTo(response) + } + /** Random [IdGenerator] which records generated IDs. */ private class RecordingIdGenerator : IdGenerator { private val delegate = RandomIdGenerator() diff --git a/src/main/proto/wfa/measurement/internal/kingdom/data_provider.proto b/src/main/proto/wfa/measurement/internal/kingdom/data_provider.proto index 73e2a57cea0..60a2728f98f 100644 --- a/src/main/proto/wfa/measurement/internal/kingdom/data_provider.proto +++ b/src/main/proto/wfa/measurement/internal/kingdom/data_provider.proto @@ -30,6 +30,10 @@ message DataProvider { // verification. Certificate certificate = 2; + message Capabilities { + // Whether the Honest Majority Share Shuffle (HMSS) protocol is supported. + bool honest_majority_share_shuffle_supported = 1; + } message Details { // Version the public API for serialized message definitions. string api_version = 1; @@ -40,6 +44,7 @@ message DataProvider { string public_key_signature_algorithm_oid = 4; google.type.Interval data_availability_interval = 5; + Capabilities capabilities = 6; } Details details = 3; diff --git a/src/main/proto/wfa/measurement/internal/kingdom/data_providers_service.proto b/src/main/proto/wfa/measurement/internal/kingdom/data_providers_service.proto index f7db11a58fb..30f0a9e732a 100644 --- a/src/main/proto/wfa/measurement/internal/kingdom/data_providers_service.proto +++ b/src/main/proto/wfa/measurement/internal/kingdom/data_providers_service.proto @@ -36,6 +36,9 @@ service DataProviders { rpc ReplaceDataAvailabilityInterval(ReplaceDataAvailabilityIntervalRequest) returns (DataProvider); + + rpc ReplaceDataProviderCapabilities(ReplaceDataProviderCapabilitiesRequest) + returns (DataProvider); } message ReplaceDataProviderRequiredDuchiesRequest { @@ -59,3 +62,10 @@ message ReplaceDataAvailabilityIntervalRequest { fixed64 external_data_provider_id = 1; google.type.Interval data_availability_interval = 2; } + +message ReplaceDataProviderCapabilitiesRequest { + fixed64 external_data_provider_id = 1; + + // New value for `capabilities`. + DataProvider.Capabilities capabilities = 2; +} diff --git a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/DataProvidersServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/DataProvidersServiceTest.kt index c9a26e22e1e..9ce61387513 100644 --- a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/DataProvidersServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/DataProvidersServiceTest.kt @@ -31,13 +31,17 @@ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 import org.mockito.kotlin.any +import org.mockito.kotlin.stub import org.mockito.kotlin.whenever import org.wfanet.measurement.api.Version +import org.wfanet.measurement.api.v2alpha.DataProvider +import org.wfanet.measurement.api.v2alpha.DataProviderKt import org.wfanet.measurement.api.v2alpha.DuchyKey import org.wfanet.measurement.api.v2alpha.copy import org.wfanet.measurement.api.v2alpha.dataProvider import org.wfanet.measurement.api.v2alpha.getDataProviderRequest import org.wfanet.measurement.api.v2alpha.replaceDataAvailabilityIntervalRequest +import org.wfanet.measurement.api.v2alpha.replaceDataProviderCapabilitiesRequest import org.wfanet.measurement.api.v2alpha.replaceDataProviderRequiredDuchiesRequest import org.wfanet.measurement.api.v2alpha.setMessage import org.wfanet.measurement.api.v2alpha.signedMessage @@ -58,7 +62,7 @@ import org.wfanet.measurement.common.toProtoTime import org.wfanet.measurement.consent.client.common.toEncryptionPublicKey import org.wfanet.measurement.internal.kingdom.CertificateKt import org.wfanet.measurement.internal.kingdom.DataProvider as InternalDataProvider -import org.wfanet.measurement.internal.kingdom.DataProviderKt.details +import org.wfanet.measurement.internal.kingdom.DataProviderKt as InternalDataProviderKt import org.wfanet.measurement.internal.kingdom.DataProvidersGrpcKt.DataProvidersCoroutineImplBase as InternalDataProvidersService import org.wfanet.measurement.internal.kingdom.DataProvidersGrpcKt.DataProvidersCoroutineStub as InternalDataProvidersClient import org.wfanet.measurement.internal.kingdom.certificate as internalCertificate @@ -66,6 +70,7 @@ import org.wfanet.measurement.internal.kingdom.copy import org.wfanet.measurement.internal.kingdom.dataProvider as internalDataProvider import org.wfanet.measurement.internal.kingdom.getDataProviderRequest as internalGetDataProviderRequest import org.wfanet.measurement.internal.kingdom.replaceDataAvailabilityIntervalRequest as internalReplaceDataAvailabilityIntervalRequest +import org.wfanet.measurement.internal.kingdom.replaceDataProviderCapabilitiesRequest as internalReplaceDataProviderCapabilitiesRequest import org.wfanet.measurement.internal.kingdom.replaceDataProviderRequiredDuchiesRequest as internalReplaceDataProviderRequiredDuchiesRequest private const val DATA_PROVIDER_ID = 123L @@ -527,6 +532,61 @@ class DataProvidersServiceTest { assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) } + @Test + fun `replaceDataProviderCapabilities returns updated DataProvider`() { + val internalDataProvider = + INTERNAL_DATA_PROVIDER.copy { + details = + details.copy { + capabilities = capabilities.copy { honestMajorityShareShuffleSupported = true } + } + } + internalServiceMock.stub { + onBlocking { replaceDataProviderCapabilities(any()) }.thenReturn(internalDataProvider) + } + val request = replaceDataProviderCapabilitiesRequest { + name = DATA_PROVIDER_NAME + capabilities = DataProviderKt.capabilities { honestMajorityShareShuffleSupported = true } + } + + val response: DataProvider = runBlocking { + withDataProviderPrincipal(DATA_PROVIDER_NAME) { + service.replaceDataProviderCapabilities(request) + } + } + + assertThat(response).isEqualTo(DATA_PROVIDER.copy { capabilities = request.capabilities }) + verifyProtoArgument( + internalServiceMock, + InternalDataProvidersService::replaceDataProviderCapabilities, + ) + .isEqualTo( + internalReplaceDataProviderCapabilitiesRequest { + externalDataProviderId = internalDataProvider.externalDataProviderId + capabilities = internalDataProvider.details.capabilities + } + ) + } + + @Test + fun `replaceDataProviderCapabilities throws PERMISSION_DENIED for incorrect principal`() { + val request = replaceDataProviderCapabilitiesRequest { + name = DATA_PROVIDER_NAME + capabilities = DataProviderKt.capabilities { honestMajorityShareShuffleSupported = true } + } + + val exception = + assertFailsWith { + runBlocking { + withDataProviderPrincipal(DATA_PROVIDER_NAME_2) { + service.replaceDataProviderCapabilities(request) + } + } + } + + assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) + } + companion object { private val API_VERSION = Version.V2_ALPHA @@ -544,16 +604,17 @@ class DataProvidersServiceTest { private val INTERNAL_DATA_PROVIDER: InternalDataProvider = internalDataProvider { externalDataProviderId = DATA_PROVIDER_ID - details = details { - apiVersion = API_VERSION.string - publicKey = SIGNED_PUBLIC_KEY.message.value - publicKeySignature = SIGNED_PUBLIC_KEY.signature - publicKeySignatureAlgorithmOid = SIGNED_PUBLIC_KEY.signatureAlgorithmOid - dataAvailabilityInterval = interval { - startTime = timestamp { seconds = 100 } - endTime = timestamp { seconds = 200 } + details = + InternalDataProviderKt.details { + apiVersion = API_VERSION.string + publicKey = SIGNED_PUBLIC_KEY.message.value + publicKeySignature = SIGNED_PUBLIC_KEY.signature + publicKeySignatureAlgorithmOid = SIGNED_PUBLIC_KEY.signatureAlgorithmOid + dataAvailabilityInterval = interval { + startTime = timestamp { seconds = 100 } + endTime = timestamp { seconds = 200 } + } } - } certificate = internalCertificate { externalDataProviderId = DATA_PROVIDER_ID externalCertificateId = CERTIFICATE_ID @@ -572,6 +633,7 @@ class DataProvidersServiceTest { publicKey = SIGNED_PUBLIC_KEY requiredDuchies += DUCHY_NAMES dataAvailabilityInterval = INTERNAL_DATA_PROVIDER.details.dataAvailabilityInterval + capabilities = DataProvider.Capabilities.getDefaultInstance() } } } diff --git a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsServiceTest.kt index 0cd23e7b1a2..a483f6f2af9 100644 --- a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsServiceTest.kt @@ -17,6 +17,7 @@ package org.wfanet.measurement.kingdom.service.api.v2alpha import com.google.common.truth.Truth.assertThat +import com.google.common.truth.Truth.assertWithMessage import com.google.common.truth.extensions.proto.ProtoTruth.assertThat import com.google.protobuf.ByteString import com.google.protobuf.Timestamp @@ -102,6 +103,10 @@ import org.wfanet.measurement.common.testing.captureFirst import org.wfanet.measurement.common.testing.verifyProtoArgument import org.wfanet.measurement.common.toByteString import org.wfanet.measurement.common.toProtoTime +import org.wfanet.measurement.internal.kingdom.BatchGetDataProvidersRequest +import org.wfanet.measurement.internal.kingdom.DataProvider as InternalDataProvider +import org.wfanet.measurement.internal.kingdom.DataProviderKt as InternalDataProviderKt +import org.wfanet.measurement.internal.kingdom.DataProvidersGrpcKt import org.wfanet.measurement.internal.kingdom.DuchyProtocolConfig import org.wfanet.measurement.internal.kingdom.Measurement as InternalMeasurement import org.wfanet.measurement.internal.kingdom.Measurement.State as InternalState @@ -116,11 +121,13 @@ import org.wfanet.measurement.internal.kingdom.StreamMeasurementsRequest import org.wfanet.measurement.internal.kingdom.StreamMeasurementsRequestKt import org.wfanet.measurement.internal.kingdom.batchCreateMeasurementsRequest as internalBatchCreateMeasurementsRequest import org.wfanet.measurement.internal.kingdom.batchCreateMeasurementsResponse as internalBatchCreateMeasurementsResponse +import org.wfanet.measurement.internal.kingdom.batchGetDataProvidersResponse as internalBatchGetDataProvidersResponse import org.wfanet.measurement.internal.kingdom.batchGetMeasurementsRequest as internalBatchGetMeasurementsRequest import org.wfanet.measurement.internal.kingdom.batchGetMeasurementsResponse as internalBatchGetMeasurementsResponse import org.wfanet.measurement.internal.kingdom.cancelMeasurementRequest as internalCancelMeasurementRequest import org.wfanet.measurement.internal.kingdom.copy import org.wfanet.measurement.internal.kingdom.createMeasurementRequest as internalCreateMeasurementRequest +import org.wfanet.measurement.internal.kingdom.dataProvider as internalDataProvider import org.wfanet.measurement.internal.kingdom.differentialPrivacyParams as internalDifferentialPrivacyParams import org.wfanet.measurement.internal.kingdom.duchyProtocolConfig import org.wfanet.measurement.internal.kingdom.getMeasurementRequest as internalGetMeasurementRequest @@ -129,12 +136,11 @@ import org.wfanet.measurement.internal.kingdom.measurement as internalMeasuremen import org.wfanet.measurement.internal.kingdom.measurementKey import org.wfanet.measurement.internal.kingdom.protocolConfig as internalProtocolConfig import org.wfanet.measurement.internal.kingdom.streamMeasurementsRequest +import org.wfanet.measurement.kingdom.deploy.common.HmssProtocolConfig import org.wfanet.measurement.kingdom.deploy.common.Llv2ProtocolConfig import org.wfanet.measurement.kingdom.deploy.common.RoLlv2ProtocolConfig private const val DEFAULT_LIMIT = 50 -private const val DATA_PROVIDERS_CERTIFICATE_NAME = - "dataProviders/AAAAAAAAAHs/certificates/AAAAAAAAAHs" private const val DATA_PROVIDERS_RESULT_CERTIFICATE_NAME = "dataProviders/AAAAAAAAALs/certificates/AAAAAAAAALs" private const val MEASUREMENT_CONSUMER_NAME = "measurementConsumers/AAAAAAAAAHs" @@ -198,8 +204,24 @@ class MeasurementsServiceTest { } ) } + private val internalDataProvidersMock: DataProvidersGrpcKt.DataProvidersCoroutineImplBase = + mockService { + onBlocking { batchGetDataProviders(any()) } + .thenAnswer { invocation -> + val request: BatchGetDataProvidersRequest = invocation.getArgument(0) + val internalDataProviders: List = + request.externalDataProviderIdsList.map { + internalDataProvider { externalDataProviderId = it } + } + internalBatchGetDataProvidersResponse { dataProviders.addAll(internalDataProviders) } + } + } - @get:Rule val grpcTestServerRule = GrpcTestServerRule { addService(internalMeasurementsMock) } + @get:Rule + val grpcTestServerRule = GrpcTestServerRule { + addService(internalMeasurementsMock) + addService(internalDataProvidersMock) + } private lateinit var service: MeasurementsService @@ -208,8 +230,10 @@ class MeasurementsServiceTest { service = MeasurementsService( MeasurementsGrpcKt.MeasurementsCoroutineStub(grpcTestServerRule.channel), + DataProvidersGrpcKt.DataProvidersCoroutineStub(grpcTestServerRule.channel), NOISE_MECHANISMS, reachOnlyLlV2Enabled = true, + hmssEnabled = true, ) } @@ -738,6 +762,66 @@ class MeasurementsServiceTest { ) } + @Test + fun `createMeasurement with HMSS enabled and EDPs capable specifies HMSS protocol`() { + internalDataProvidersMock.stub { + onBlocking { batchGetDataProviders(any()) } + .thenReturn( + internalBatchGetDataProvidersResponse { + for (externalDataProviderId in EXTERNAL_DATA_PROVIDER_IDS) { + dataProviders += internalDataProvider { + this.externalDataProviderId = externalDataProviderId.value + details = + details.copy { + capabilities = + InternalDataProviderKt.capabilities { + honestMajorityShareShuffleSupported = true + } + } + } + } + } + ) + } + val measurement = + MEASUREMENT.copy { + clearFailure() + results.clear() + clearProtocolConfig() + } + val request = createMeasurementRequest { + parent = MEASUREMENT_CONSUMER_NAME + this.measurement = measurement + requestId = "foo" + } + + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) { + runBlocking { service.createMeasurement(request) } + } + + verifyProtoArgument( + internalMeasurementsMock, + MeasurementsGrpcKt.MeasurementsCoroutineImplBase::createMeasurement, + ) + .isEqualTo( + internalCreateMeasurementRequest { + this.measurement = + INTERNAL_MEASUREMENT.copy { + clearExternalMeasurementId() + clearUpdateTime() + results.clear() + details = + details.copy { + clearFailure() + protocolConfig = HMSS_INTERNAL_PROTOCOL_CONFIG + clearDuchyProtocolConfig() + } + } + requestId = request.requestId + } + ) + } + @Test fun `createMeasurement throws INVALID_ARGUMENT when model line is missing for POPULATION measurement`() { val exception = @@ -809,11 +893,19 @@ class MeasurementsServiceTest { assertFailsWith { withDataProviderPrincipal(DATA_PROVIDERS_NAME) { runBlocking { - service.createMeasurement(createMeasurementRequest { measurement = MEASUREMENT }) + service.createMeasurement( + createMeasurementRequest { + parent = MEASUREMENT_CONSUMER_NAME + measurement = MEASUREMENT + } + ) } } } - assertThat(exception.status.code).isEqualTo(Status.Code.PERMISSION_DENIED) + + assertWithMessage(exception.toString()) + .that(exception.status.code) + .isEqualTo(Status.Code.PERMISSION_DENIED) } @Test @@ -824,7 +916,9 @@ class MeasurementsServiceTest { service.createMeasurement(createMeasurementRequest { measurement = MEASUREMENT }) } } - assertThat(exception.status.code).isEqualTo(Status.Code.UNAUTHENTICATED) + assertWithMessage(exception.toString()) + .that(exception.status.code) + .isEqualTo(Status.Code.UNAUTHENTICATED) } @Test @@ -1968,6 +2062,7 @@ class MeasurementsServiceTest { assertThat(exception.status.code).isEqualTo(Status.Code.INVALID_ARGUMENT) } + @Test fun `batchCreateMeasurements throws INVALID_ARGUMENT when child parent doesn't match`() { val createMeasurementRequest = createMeasurementRequest { parent = MEASUREMENT_CONSUMER_NAME_2 @@ -2336,6 +2431,10 @@ class MeasurementsServiceTest { setOf("aggregator"), 2, ) + HmssProtocolConfig.setForTest( + HMSS_INTERNAL_PROTOCOL_CONFIG.honestMajorityShareShuffle, + setOf("aggregator", "worker1", "worker2"), + ) } private val API_VERSION = Version.V2_ALPHA @@ -2431,6 +2530,11 @@ class MeasurementsServiceTest { reachOnlyLiquidLegionsV2 = DuchyProtocolConfig.LiquidLegionsV2.getDefaultInstance() } + private val HMSS_INTERNAL_PROTOCOL_CONFIG = internalProtocolConfig { + externalProtocolConfigId = "hmss" + honestMajorityShareShuffle = InternalProtocolConfigKt.honestMajorityShareShuffle {} + } + private val DATA_PROVIDER_PUBLIC_KEY = encryptionPublicKey { data = UPDATE_TIME.toByteString() } private val MEASUREMENT_PUBLIC_KEY = encryptionPublicKey { data = UPDATE_TIME.toByteString() } @@ -2652,7 +2756,10 @@ class MeasurementsServiceTest { } private val DEFAULT_INTERNAL_DIRECT_POPULATION_PROTOCOL_CONFIG: InternalProtocolConfig.Direct = - InternalProtocolConfig.Direct.getDefaultInstance() + direct { + noiseMechanisms += DEFAULT_INTERNAL_DIRECT_NOISE_MECHANISMS + deterministicCount = InternalProtocolConfig.Direct.DeterministicCount.getDefaultInstance() + } private const val BATCH_LIMIT = 50 }