From aac627dc2966111841feb471a8c1285d0469fb35 Mon Sep 17 00:00:00 2001 From: Rieman Li Date: Sat, 5 Aug 2023 00:28:44 +0000 Subject: [PATCH] Circulate direct computation methodology from protocol config to measurement result. --- build/repositories.bzl | 6 +- src/main/k8s/kingdom.cue | 8 +- .../eventdataprovider/noiser/Noiser.kt | 11 +- .../integration/common/BUILD.bazel | 1 + .../common/InProcessEdpSimulator.kt | 3 - .../integration/common/InProcessKingdom.kt | 9 +- .../common/server/V2alphaPublicApiServer.kt | 50 +++ .../spanner/writers/CreateMeasurement.kt | 3 +- .../api/v2alpha/ExchangeStepsService.kt | 4 +- .../api/v2alpha/MeasurementsService.kt | 5 +- .../service/api/v2alpha/ProtoConversions.kt | 394 ++++++++++++++--- .../testing/MeasurementsServiceTest.kt | 12 +- .../service/internal/testing/Population.kt | 2 + .../testing/RequisitionsServiceTest.kt | 2 +- .../system/v1alpha/ProtoConversions.kt | 3 + .../loadtest/dataprovider/BUILD.bazel | 1 + .../loadtest/dataprovider/EdpSimulator.kt | 282 ++++++++++-- .../dataprovider/EdpSimulatorFlags.kt | 9 - .../dataprovider/EdpSimulatorRunner.kt | 1 - .../dataprovider/MeasurementResults.kt | 5 +- .../wfa/measurement/api/v2alpha/BUILD.bazel | 15 + .../measurement/internal/kingdom/BUILD.bazel | 4 + .../internal/kingdom/direct_computation.proto | 61 +++ .../internal/kingdom/protocol_config.proto | 68 +++ .../api/v2alpha/MeasurementsServiceTest.kt | 155 ++++++- .../api/v2alpha/RequisitionsServiceTest.kt | 202 +++++++-- .../loadtest/dataprovider/EdpSimulatorTest.kt | 417 ++++++++++++++++-- 27 files changed, 1508 insertions(+), 225 deletions(-) create mode 100644 src/main/proto/wfa/measurement/internal/kingdom/direct_computation.proto diff --git a/build/repositories.bzl b/build/repositories.bzl index 20a9d7a3f94..ef3e00fbf85 100644 --- a/build/repositories.bzl +++ b/build/repositories.bzl @@ -38,11 +38,13 @@ def wfa_measurement_system_repositories(): version = "0.10.0", ) + # DO_NOT_SUBMIT(world-federation-of-advertisers/cross-media-measurement-api/#163): + # switch to a release version before submitting the PR. wfa_repo_archive( name = "wfa_measurement_proto", repo = "cross-media-measurement-api", - sha256 = "1d829e7d95e6dedea1a4ea746e5613915dd60ca095b7b35bdcf19fa067697f2a", - version = "0.39.2", + sha256 = "e9f24d78e06f5ec78fe4c4ed42d73ad5440f2af9ed7b5e560df2040ce75b592f", + version = "0.39.3", ) wfa_repo_archive( diff --git a/src/main/k8s/kingdom.cue b/src/main/k8s/kingdom.cue index 0fced806f7a..401b0981af8 100644 --- a/src/main/k8s/kingdom.cue +++ b/src/main/k8s/kingdom.cue @@ -72,6 +72,12 @@ import ("strings") _open_id_redirect_uri_flag: "--open-id-redirect-uri=https://localhost:2048" + _directNoiseMechanismFlags: [ + "--direct-noise-mechanism=NONE", + "--direct-noise-mechanism=CONTINUOUS_LAPLACE", + "--direct-noise-mechanism=CONTINUOUS_GAUSSIAN", + ] + _kingdomCompletedMeasurementsTimeToLiveFlag: "--time-to-live=\(_completedMeasurementsTimeToLive)" _kingdomCompletedMeasurementsDryRunRetentionPolicyFlag: "--dry-run=\(_completedMeasurementsDryRun)" _kingdomPendingMeasurementsTimeToLiveFlag: "--time-to-live=\(_pendingMeasurementsTimeToLive)" @@ -169,7 +175,7 @@ import ("strings") _akid_to_principal_map_file_flag, _open_id_redirect_uri_flag, _duchy_info_config_flag, - ] + Container._commonServerFlags + ] + _directNoiseMechanismFlags + Container._commonServerFlags } spec: template: spec: { _mounts: "config-files": #ConfigMapMount diff --git a/src/main/kotlin/org/wfanet/measurement/eventdataprovider/noiser/Noiser.kt b/src/main/kotlin/org/wfanet/measurement/eventdataprovider/noiser/Noiser.kt index 458af2ce82a..818b86ae868 100644 --- a/src/main/kotlin/org/wfanet/measurement/eventdataprovider/noiser/Noiser.kt +++ b/src/main/kotlin/org/wfanet/measurement/eventdataprovider/noiser/Noiser.kt @@ -17,17 +17,12 @@ package org.wfanet.measurement.eventdataprovider.noiser /** Internal Differential Privacy(DP) parameters. */ data class DpParams(val epsilon: Double, val delta: Double) -/** - * Noise mechanism for generating publisher noise for direct measurements. - * - * TODO(@iverson52000): Move this to public API if EDP needs to report back the direct noise - * mechanism for PBM tracking. NONE mechanism is testing only and should not move to public API. - */ +/** Noise mechanism for generating publisher noise for direct measurements. */ enum class DirectNoiseMechanism { /** NONE mechanism is testing only. */ NONE, - LAPLACE, - GAUSSIAN, + CONTINUOUS_LAPLACE, + CONTINUOUS_GAUSSIAN, } /** A base Noiser interface for direct measurements. */ diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel index 428e7d64bc9..a4ffb9e3474 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/BUILD.bazel @@ -77,6 +77,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/service:data_services", "//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:api_key_authentication_server_interceptor", "//src/main/kotlin/org/wfanet/measurement/loadtest/panelmatchresourcesetup", + "//src/main/proto/wfa/measurement/api/v2alpha:protocol_config_kt_jvm_proto", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing", "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing", diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessEdpSimulator.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessEdpSimulator.kt index 69cac2480e9..b1d3ca9cdc1 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessEdpSimulator.kt +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessEdpSimulator.kt @@ -41,7 +41,6 @@ import org.wfanet.measurement.api.v2alpha.RequisitionsGrpcKt.RequisitionsCorouti import org.wfanet.measurement.api.v2alpha.event_group_metadata.testing.SyntheticEventGroupSpec import org.wfanet.measurement.common.identity.withPrincipalName import org.wfanet.measurement.common.throttler.MinimumIntervalThrottler -import org.wfanet.measurement.eventdataprovider.noiser.DirectNoiseMechanism import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.CompositionMechanism import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.InMemoryBackingStore import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.PrivacyBucketFilter @@ -102,7 +101,6 @@ class InProcessEdpSimulator( 100.0f ), trustedCertificates = trustedCertificates, - DIRECT_NOISE_MECHANISM, random = random, compositionMechanism = COMPOSITION_MECHANISM, ) @@ -133,7 +131,6 @@ class InProcessEdpSimulator( private val logger: Logger = Logger.getLogger(this::class.java.name) private const val RANDOM_SEED: Long = 1 private val random = Random(RANDOM_SEED) - private val DIRECT_NOISE_MECHANISM = DirectNoiseMechanism.LAPLACE private val COMPOSITION_MECHANISM = CompositionMechanism.DP_ADVANCED } } diff --git a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessKingdom.kt b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessKingdom.kt index 325620f123b..e2f8c7dc215 100644 --- a/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessKingdom.kt +++ b/src/main/kotlin/org/wfanet/measurement/integration/common/InProcessKingdom.kt @@ -20,6 +20,7 @@ import java.util.logging.Logger import org.junit.rules.TestRule import org.junit.runner.Description import org.junit.runners.model.Statement +import org.wfanet.measurement.api.v2alpha.ProtocolConfig import org.wfanet.measurement.api.v2alpha.testing.withMetadataPrincipalIdentities import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule import org.wfanet.measurement.common.grpc.withDefaultDeadline @@ -150,7 +151,7 @@ class InProcessKingdom( EventGroupMetadataDescriptorsService(internalEventGroupMetadataDescriptorsClient) .withMetadataPrincipalIdentities() .withApiKeyAuthenticationServerInterceptor(internalApiKeysClient), - MeasurementsService(internalMeasurementsClient) + MeasurementsService(internalMeasurementsClient, MEASUREMENT_NOISE_MECHANISMS) .withMetadataPrincipalIdentities() .withApiKeyAuthenticationServerInterceptor(internalApiKeysClient), PublicKeysService(internalPublicKeysClient) @@ -206,5 +207,11 @@ class InProcessKingdom( /** Default deadline for RPCs to internal server in milliseconds. */ private const val DEFAULT_INTERNAL_DEADLINE_MILLIS = 30_000L + private val MEASUREMENT_NOISE_MECHANISMS: List = + listOf( + ProtocolConfig.NoiseMechanism.NONE, + ProtocolConfig.NoiseMechanism.CONTINUOUS_LAPLACE, + ProtocolConfig.NoiseMechanism.CONTINUOUS_GAUSSIAN, + ) } } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/server/V2alphaPublicApiServer.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/server/V2alphaPublicApiServer.kt index a90b1d7e1d4..a37029be311 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/server/V2alphaPublicApiServer.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/server/V2alphaPublicApiServer.kt @@ -17,6 +17,7 @@ package org.wfanet.measurement.kingdom.deploy.common.server import io.grpc.ServerServiceDefinition import java.io.File import org.wfanet.measurement.api.v2alpha.AkidPrincipalLookup +import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism import org.wfanet.measurement.api.v2alpha.withPrincipalsFromX509AuthorityKeyIdentifiers import org.wfanet.measurement.common.commandLineMain import org.wfanet.measurement.common.crypto.SigningCerts @@ -143,6 +144,7 @@ private fun run( .withApiKeyAuthenticationServerInterceptor(internalApiKeysCoroutineStub), MeasurementsService( InternalMeasurementsCoroutineStub(channel), + v2alphaFlags.directNoiseMechanisms ) .withPrincipalsFromX509AuthorityKeyIdentifiers(principalLookup) .withApiKeyAuthenticationServerInterceptor(internalApiKeysCoroutineStub), @@ -192,6 +194,7 @@ fun main(args: Array) = commandLineMain(::run, args) /** Flags specific to the V2alpha API version. */ private class V2alphaFlags { + @CommandLine.Option( names = ["--authority-key-identifier-to-principal-map-file"], description = ["File path to a AuthorityKeyToPrincipalMap textproto"], @@ -207,4 +210,51 @@ private class V2alphaFlags { ) lateinit var redirectUri: String private set + + lateinit var directNoiseMechanisms: List + private set + + @CommandLine.Spec + lateinit var spec: CommandLine.Model.CommandSpec // injected by picocli + private set + + @CommandLine.Option( + names = ["--direct-noise-mechanism"], + description = + [ + "Noise mechanisms that can be used in direct computation. It can be specified multiple " + + "times." + ], + required = true + ) + fun setDirectNoiseMechanisms(noiseMechanisms: List) { + for (noiseMechanism in noiseMechanisms) { + when (noiseMechanism) { + NoiseMechanism.NONE, + NoiseMechanism.CONTINUOUS_LAPLACE, + NoiseMechanism.CONTINUOUS_GAUSSIAN -> {} + NoiseMechanism.GEOMETRIC, + // TODO(@riemanli): support DISCRETE_GAUSSIAN after having a clear definition of it. + NoiseMechanism.DISCRETE_GAUSSIAN -> { + throw CommandLine.ParameterException( + spec.commandLine(), + String.format( + "Invalid noise mechanism $noiseMechanism for option '--direct-noise-mechanism'. " + + "Discrete mechanisms are not supported for direct computations." + ) + ) + } + NoiseMechanism.NOISE_MECHANISM_UNSPECIFIED, + NoiseMechanism.UNRECOGNIZED -> { + throw CommandLine.ParameterException( + spec.commandLine(), + String.format( + "Invalid noise mechanism $noiseMechanism for option '--direct-noise-mechanism'." + ) + ) + } + } + } + directNoiseMechanisms = noiseMechanisms + } } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateMeasurement.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateMeasurement.kt index 314b5636305..fd010b94bd5 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateMeasurement.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateMeasurement.kt @@ -88,8 +88,9 @@ class CreateMeasurement(private val request: CreateMeasurementRequest) : ProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> { createComputedMeasurement(request.measurement, measurementConsumerId) } - ProtocolConfig.ProtocolCase.PROTOCOL_NOT_SET -> + ProtocolConfig.ProtocolCase.DIRECT -> createDirectMeasurement(request.measurement, measurementConsumerId) + ProtocolConfig.ProtocolCase.PROTOCOL_NOT_SET -> error("Protocol is not set.") } } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ExchangeStepsService.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ExchangeStepsService.kt index a45c2eeca4f..cf0bb135c03 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ExchangeStepsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ExchangeStepsService.kt @@ -159,9 +159,7 @@ class ExchangeStepsService(private val internalExchangeSteps: InternalExchangeSt try { it.toV2Alpha() } catch (e: Throwable) { - failGrpc(Status.INVALID_ARGUMENT) { - e.message ?: "Failed to convert ProtocolConfig ExchangeStep" - } + failGrpc(Status.INVALID_ARGUMENT) { e.message ?: "Failed to convert ExchangeStep" } } } nextPageToken = results.last().updateTime.toByteArray().base64UrlEncode() diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsService.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsService.kt index 60aeb86d8d1..1a5c499f091 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsService.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsService.kt @@ -41,6 +41,7 @@ import org.wfanet.measurement.api.v2alpha.MeasurementKey import org.wfanet.measurement.api.v2alpha.MeasurementPrincipal import org.wfanet.measurement.api.v2alpha.MeasurementSpec import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineImplBase +import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism import org.wfanet.measurement.api.v2alpha.copy import org.wfanet.measurement.api.v2alpha.listMeasurementsPageToken import org.wfanet.measurement.api.v2alpha.listMeasurementsResponse @@ -72,6 +73,7 @@ private const val MISSING_RESOURCE_NAME_ERROR = "Resource name is either unspeci class MeasurementsService( private val internalMeasurementsStub: MeasurementsCoroutineStub, + private val noiseMechanisms: List ) : MeasurementsCoroutineImplBase() { override suspend fun getMeasurement(request: GetMeasurementRequest): Measurement { @@ -167,7 +169,8 @@ class MeasurementsService( request.measurement.toInternal( measurementConsumerCertificateKey, dataProvidersMap, - parsedMeasurementSpec + parsedMeasurementSpec, + noiseMechanisms.map { it.toInternal() } ) requestId = request.requestId } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt index 16f668fd284..5632d235ed1 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt @@ -15,7 +15,6 @@ package org.wfanet.measurement.kingdom.service.api.v2alpha import com.google.protobuf.util.Timestamps -import com.google.type.date import com.google.type.interval import java.time.ZoneOffset import org.wfanet.measurement.api.Version @@ -107,6 +106,7 @@ import org.wfanet.measurement.internal.kingdom.ModelShard as InternalModelShard import org.wfanet.measurement.internal.kingdom.ModelSuite as InternalModelSuite import org.wfanet.measurement.internal.kingdom.ProtocolConfig as InternalProtocolConfig import org.wfanet.measurement.internal.kingdom.ProtocolConfig.NoiseMechanism as InternalNoiseMechanism +import org.wfanet.measurement.internal.kingdom.ProtocolConfigKt as InternalProtocolConfigKt import org.wfanet.measurement.internal.kingdom.duchyProtocolConfig import org.wfanet.measurement.internal.kingdom.exchangeWorkflow import org.wfanet.measurement.internal.kingdom.measurement as internalMeasurement @@ -120,6 +120,80 @@ import org.wfanet.measurement.internal.kingdom.protocolConfig as internalProtoco import org.wfanet.measurement.kingdom.deploy.common.Llv2ProtocolConfig import org.wfanet.measurement.kingdom.deploy.common.RoLlv2ProtocolConfig +/** + * Default maximum frequency used in the direct distribution methodology. + * + * TODO(world-federation-of-advertisers/cross-media-measurement-api#160): this value won't be needed + * once the maximum frequency field is moved to measurement spec + */ +const val DEFAULT_MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION = 20 + +/** Default options of direct noise mechanisms to data providers. */ +val DEFAULT_DIRECT_NOISE_MECHANISMS: List = + listOf( + NoiseMechanism.NONE, + NoiseMechanism.GEOMETRIC, + NoiseMechanism.DISCRETE_GAUSSIAN, + NoiseMechanism.CONTINUOUS_LAPLACE, + NoiseMechanism.CONTINUOUS_GAUSSIAN + ) + +/** + * Default direct reach protocol config for backward compatibility. + * + * Used when existing direct protocol configs of reach measurements don't have methodologies. + */ +val DEFAULT_DIRECT_REACH_PROTOCOL_CONFIG: ProtocolConfig.Direct = direct { + noiseMechanisms += DEFAULT_DIRECT_NOISE_MECHANISMS + deterministicCountDistinct = ProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + liquidLegionsCountDistinct = ProtocolConfig.Direct.LiquidLegionsCountDistinct.getDefaultInstance() +} + +/** + * Default direct reach-and-freqeuncy protocol config for backward compatibility. + * + * Used when existing direct protocol configs of reach-and-freqeuncy measurements don't have + * methodologies. + */ +val DEFAULT_DIRECT_REACH_AND_FREQUENCY_PROTOCOL_CONFIG: ProtocolConfig.Direct = direct { + noiseMechanisms += DEFAULT_DIRECT_NOISE_MECHANISMS + deterministicCountDistinct = ProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + liquidLegionsCountDistinct = ProtocolConfig.Direct.LiquidLegionsCountDistinct.getDefaultInstance() + deterministicDistribution = ProtocolConfig.Direct.DeterministicDistribution.getDefaultInstance() + liquidLegionsDistribution = ProtocolConfig.Direct.LiquidLegionsDistribution.getDefaultInstance() +} + +/** + * Default direct impression protocol config for backward compatibility. + * + * Used when existing direct protocol configs of impression measurements don't have methodologies. + */ +val DEFAULT_DIRECT_IMPRESSION_PROTOCOL_CONFIG = direct { + noiseMechanisms += DEFAULT_DIRECT_NOISE_MECHANISMS + deterministicCount = ProtocolConfig.Direct.DeterministicCount.getDefaultInstance() +} + +/** + * Default direct watch duration protocol config for backward compatibility. + * + * Used when existing direct protocol configs of watch duration measurements don't have + * methodologies. + */ +val DEFAULT_DIRECT_WATCH_DURATION_PROTOCOL_CONFIG = direct { + noiseMechanisms += DEFAULT_DIRECT_NOISE_MECHANISMS + deterministicSum = ProtocolConfig.Direct.DeterministicSum.getDefaultInstance() +} + +/** + * Default direct population protocol config for backward compatibility. + * + * Used when existing direct protocol configs of population measurements don't have methodologies. + */ +val DEFAULT_DIRECT_POPULATION_PROTOCOL_CONFIG = direct { + noiseMechanisms += DEFAULT_DIRECT_NOISE_MECHANISMS + deterministicCount = ProtocolConfig.Direct.DeterministicCount.getDefaultInstance() +} + /** Converts an internal [InternalMeasurement.State] to a public [State]. */ fun InternalMeasurement.State.toState(): State = when (this) { @@ -180,13 +254,29 @@ fun InternalDifferentialPrivacyParams.toDifferentialPrivacyParams(): Differentia /** Converts an internal [InternalNoiseMechanism] to a public [NoiseMechanism]. */ fun InternalNoiseMechanism.toNoiseMechanism(): NoiseMechanism { return when (this) { + InternalNoiseMechanism.NONE -> NoiseMechanism.NONE InternalNoiseMechanism.GEOMETRIC -> NoiseMechanism.GEOMETRIC InternalNoiseMechanism.DISCRETE_GAUSSIAN -> NoiseMechanism.DISCRETE_GAUSSIAN + InternalNoiseMechanism.CONTINUOUS_LAPLACE -> NoiseMechanism.CONTINUOUS_LAPLACE + InternalNoiseMechanism.CONTINUOUS_GAUSSIAN -> NoiseMechanism.CONTINUOUS_GAUSSIAN InternalNoiseMechanism.NOISE_MECHANISM_UNSPECIFIED, InternalNoiseMechanism.UNRECOGNIZED -> error("invalid internal noise mechanism.") } } +/** Converts a public [NoiseMechanism] to an internal [InternalNoiseMechanism]. */ +fun NoiseMechanism.toInternal(): InternalNoiseMechanism { + return when (this) { + NoiseMechanism.GEOMETRIC -> InternalNoiseMechanism.GEOMETRIC + NoiseMechanism.DISCRETE_GAUSSIAN -> InternalNoiseMechanism.DISCRETE_GAUSSIAN + NoiseMechanism.NONE -> InternalNoiseMechanism.NONE + NoiseMechanism.CONTINUOUS_LAPLACE -> InternalNoiseMechanism.CONTINUOUS_LAPLACE + NoiseMechanism.CONTINUOUS_GAUSSIAN -> InternalNoiseMechanism.CONTINUOUS_GAUSSIAN + NoiseMechanism.NOISE_MECHANISM_UNSPECIFIED, + NoiseMechanism.UNRECOGNIZED -> error("invalid internal noise mechanism.") + } +} + /** Converts an internal [InternalProtocolConfig] to a public [ProtocolConfig]. */ fun InternalProtocolConfig.toProtocolConfig( measurementTypeCase: MeasurementSpec.MeasurementTypeCase, @@ -209,80 +299,72 @@ fun InternalProtocolConfig.toProtocolConfig( } when (measurementType) { - ProtocolConfig.MeasurementType.REACH, - ProtocolConfig.MeasurementType.REACH_AND_FREQUENCY -> { - if (dataProviderCount == 1) { - protocols += protocol { direct = direct {} } - } else { - @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields are never null. - when (source.protocolCase) { - InternalProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2 -> { - protocols += protocol { - liquidLegionsV2 = liquidLegionsV2 { - if (source.liquidLegionsV2.hasSketchParams()) { - val sourceSketchParams = source.liquidLegionsV2.sketchParams - sketchParams = liquidLegionsSketchParams { - decayRate = sourceSketchParams.decayRate - maxSize = sourceSketchParams.maxSize - samplingIndicatorSize = sourceSketchParams.samplingIndicatorSize - } - } - if (source.liquidLegionsV2.hasDataProviderNoise()) { - dataProviderNoise = - source.liquidLegionsV2.dataProviderNoise.toDifferentialPrivacyParams() - } - ellipticCurveId = source.liquidLegionsV2.ellipticCurveId - maximumFrequency = source.liquidLegionsV2.maximumFrequency - // Use `GEOMETRIC` for unspecified InternalNoiseMechanism for old Measurements. - noiseMechanism = - if ( - source.liquidLegionsV2.noiseMechanism == - InternalNoiseMechanism.NOISE_MECHANISM_UNSPECIFIED - ) { - NoiseMechanism.GEOMETRIC - } else { - source.liquidLegionsV2.noiseMechanism.toNoiseMechanism() - } + ProtocolConfig.MeasurementType.REACH -> { + protocols += + // Direct protocol takes precedence + if (dataProviderCount == 1) { + protocol { + direct = + if (source.hasDirect()) { + source.direct.toDirect() + } else { + // For backward compatibility + DEFAULT_DIRECT_REACH_PROTOCOL_CONFIG } - } } - InternalProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> { - protocols += protocol { - reachOnlyLiquidLegionsV2 = reachOnlyLiquidLegionsV2 { - if (source.reachOnlyLiquidLegionsV2.hasSketchParams()) { - val sourceSketchParams = source.reachOnlyLiquidLegionsV2.sketchParams - sketchParams = reachOnlyLiquidLegionsSketchParams { - decayRate = sourceSketchParams.decayRate - maxSize = sourceSketchParams.maxSize - } - } - if (source.reachOnlyLiquidLegionsV2.hasDataProviderNoise()) { - dataProviderNoise = - source.reachOnlyLiquidLegionsV2.dataProviderNoise - .toDifferentialPrivacyParams() - } - ellipticCurveId = source.reachOnlyLiquidLegionsV2.ellipticCurveId - // Use `GEOMETRIC` for unspecified InternalNoiseMechanism for old Measurements. - noiseMechanism = - if ( - source.reachOnlyLiquidLegionsV2.noiseMechanism == - InternalNoiseMechanism.NOISE_MECHANISM_UNSPECIFIED - ) { - NoiseMechanism.GEOMETRIC - } else { - source.reachOnlyLiquidLegionsV2.noiseMechanism.toNoiseMechanism() - } + } else { + buildMpcProtocolConfig(source) + } + } + ProtocolConfig.MeasurementType.REACH_AND_FREQUENCY -> { + protocols += + // Direct protocol takes precedence + if (dataProviderCount == 1) { + protocol { + direct = + if (source.hasDirect()) { + source.direct.toDirect() + } else { + // For backward compatibility + DEFAULT_DIRECT_REACH_AND_FREQUENCY_PROTOCOL_CONFIG } - } } - InternalProtocolConfig.ProtocolCase.PROTOCOL_NOT_SET -> error("Protocol not specified") + } else { + buildMpcProtocolConfig(source) } + } + ProtocolConfig.MeasurementType.IMPRESSION -> { + protocols += protocol { + direct = + if (source.hasDirect()) { + source.direct.toDirect() + } else { + // For backward compatibility + DEFAULT_DIRECT_IMPRESSION_PROTOCOL_CONFIG + } + } + } + ProtocolConfig.MeasurementType.DURATION -> { + protocols += protocol { + direct = + if (source.hasDirect()) { + source.direct.toDirect() + } else { + // For backward compatibility + DEFAULT_DIRECT_WATCH_DURATION_PROTOCOL_CONFIG + } } } - ProtocolConfig.MeasurementType.IMPRESSION, - ProtocolConfig.MeasurementType.DURATION, ProtocolConfig.MeasurementType.POPULATION -> { - protocols += protocol { direct = direct {} } + protocols += protocol { + direct = + if (source.hasDirect()) { + source.direct.toDirect() + } else { + // For backward compatibility + DEFAULT_DIRECT_POPULATION_PROTOCOL_CONFIG + } + } } ProtocolConfig.MeasurementType.MEASUREMENT_TYPE_UNSPECIFIED, ProtocolConfig.MeasurementType.UNRECOGNIZED -> error("Invalid MeasurementType") @@ -290,6 +372,118 @@ fun InternalProtocolConfig.toProtocolConfig( } } +/** + * Builds a public [ProtocolConfig.Protocol] for MPC only from an internal [InternalProtocolConfig]. + */ +private fun buildMpcProtocolConfig( + protocolConfig: InternalProtocolConfig +): ProtocolConfig.Protocol { + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields are never null. + return when (protocolConfig.protocolCase) { + InternalProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2 -> { + protocol { + liquidLegionsV2 = liquidLegionsV2 { + if (protocolConfig.liquidLegionsV2.hasSketchParams()) { + val sourceSketchParams = protocolConfig.liquidLegionsV2.sketchParams + sketchParams = liquidLegionsSketchParams { + decayRate = sourceSketchParams.decayRate + maxSize = sourceSketchParams.maxSize + samplingIndicatorSize = sourceSketchParams.samplingIndicatorSize + } + } + if (protocolConfig.liquidLegionsV2.hasDataProviderNoise()) { + dataProviderNoise = + protocolConfig.liquidLegionsV2.dataProviderNoise.toDifferentialPrivacyParams() + } + ellipticCurveId = protocolConfig.liquidLegionsV2.ellipticCurveId + maximumFrequency = protocolConfig.liquidLegionsV2.maximumFrequency + noiseMechanism = + if ( + protocolConfig.liquidLegionsV2.noiseMechanism == + InternalNoiseMechanism.NOISE_MECHANISM_UNSPECIFIED + ) { + // Use `GEOMETRIC` for unspecified InternalNoiseMechanism for old Measurements. + NoiseMechanism.GEOMETRIC + } else { + protocolConfig.liquidLegionsV2.noiseMechanism.toNoiseMechanism() + } + } + } + } + InternalProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> { + protocol { + reachOnlyLiquidLegionsV2 = reachOnlyLiquidLegionsV2 { + if (protocolConfig.reachOnlyLiquidLegionsV2.hasSketchParams()) { + val sourceSketchParams = protocolConfig.reachOnlyLiquidLegionsV2.sketchParams + sketchParams = reachOnlyLiquidLegionsSketchParams { + decayRate = sourceSketchParams.decayRate + maxSize = sourceSketchParams.maxSize + } + } + if (protocolConfig.reachOnlyLiquidLegionsV2.hasDataProviderNoise()) { + dataProviderNoise = + protocolConfig.reachOnlyLiquidLegionsV2.dataProviderNoise + .toDifferentialPrivacyParams() + } + ellipticCurveId = protocolConfig.reachOnlyLiquidLegionsV2.ellipticCurveId + noiseMechanism = + if ( + protocolConfig.reachOnlyLiquidLegionsV2.noiseMechanism == + InternalNoiseMechanism.NOISE_MECHANISM_UNSPECIFIED + ) { + NoiseMechanism.GEOMETRIC + } else { + protocolConfig.reachOnlyLiquidLegionsV2.noiseMechanism.toNoiseMechanism() + } + } + } + } + InternalProtocolConfig.ProtocolCase.DIRECT -> { + error("Direct protocol cannot be used for MPC-based Measurements") + } + InternalProtocolConfig.ProtocolCase.PROTOCOL_NOT_SET -> { + error("Protocol not specified") + } + } +} + +/** + * Converts an internal [InternalProtocolConfig.Direct] to a public [InternalProtocolConfig.Direct]. + */ +private fun InternalProtocolConfig.Direct.toDirect(): ProtocolConfig.Direct { + val source = this + + return direct { + noiseMechanisms += + source.noiseMechanismsList.map { internalNoiseMechanism -> + internalNoiseMechanism.toNoiseMechanism() + } + + if (source.hasDeterministicCountDistinct()) { + deterministicCountDistinct = + ProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + } + if (source.hasDeterministicDistribution()) { + deterministicDistribution = + ProtocolConfig.Direct.DeterministicDistribution.getDefaultInstance() + } + if (source.hasDeterministicCount()) { + deterministicCount = ProtocolConfig.Direct.DeterministicCount.getDefaultInstance() + } + if (source.hasDeterministicSum()) { + deterministicSum = ProtocolConfig.Direct.DeterministicSum.getDefaultInstance() + } + if (source.hasLiquidLegionsCountDistinct()) { + liquidLegionsCountDistinct = + ProtocolConfig.Direct.LiquidLegionsCountDistinct.getDefaultInstance() + } + if (source.hasLiquidLegionsDistribution()) { + liquidLegionsDistribution = + ProtocolConfig.Direct.LiquidLegionsDistribution.getDefaultInstance() + } + } +} + /** Converts an internal [InternalModelSuite] to a public [ModelSuite]. */ fun InternalModelSuite.toModelSuite(): ModelSuite { val source = this @@ -711,7 +905,8 @@ fun Map.Entry.toDataProviderEntry(): DataProviderEntry fun Measurement.toInternal( measurementConsumerCertificateKey: MeasurementConsumerCertificateKey, dataProvidersMap: Map, - measurementSpecProto: MeasurementSpec + measurementSpecProto: MeasurementSpec, + internalNoiseMechanisms: List ): InternalMeasurement { val publicMeasurement = this @@ -748,6 +943,17 @@ fun Measurement.toInternal( liquidLegionsV2 = Llv2ProtocolConfig.duchyProtocolConfig } } + } else if (dataProvidersCount == 1) { + protocolConfig = internalProtocolConfig { + direct = + InternalProtocolConfigKt.direct { + noiseMechanisms += internalNoiseMechanisms + deterministicCountDistinct = + InternalProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + liquidLegionsCountDistinct = + InternalProtocolConfig.Direct.LiquidLegionsCountDistinct.getDefaultInstance() + } + } } } MeasurementSpec.MeasurementTypeCase.REACH_AND_FREQUENCY -> { @@ -759,11 +965,57 @@ fun Measurement.toInternal( duchyProtocolConfig = duchyProtocolConfig { liquidLegionsV2 = Llv2ProtocolConfig.duchyProtocolConfig } + } else if (dataProvidersCount == 1) { + protocolConfig = internalProtocolConfig { + direct = + InternalProtocolConfigKt.direct { + noiseMechanisms += internalNoiseMechanisms + deterministicCountDistinct = + InternalProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + liquidLegionsCountDistinct = + InternalProtocolConfig.Direct.LiquidLegionsCountDistinct.getDefaultInstance() + deterministicDistribution = + InternalProtocolConfigKt.DirectKt.deterministicDistribution { + maximumFrequency = DEFAULT_MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION + } + liquidLegionsDistribution = + InternalProtocolConfigKt.DirectKt.liquidLegionsDistribution { + maximumFrequency = DEFAULT_MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION + } + } + } + } + } + MeasurementSpec.MeasurementTypeCase.IMPRESSION -> { + protocolConfig = internalProtocolConfig { + direct = + InternalProtocolConfigKt.direct { + noiseMechanisms += internalNoiseMechanisms + deterministicCount = + InternalProtocolConfig.Direct.DeterministicCount.getDefaultInstance() + } + } + } + MeasurementSpec.MeasurementTypeCase.DURATION -> { + protocolConfig = internalProtocolConfig { + direct = + InternalProtocolConfigKt.direct { + noiseMechanisms += internalNoiseMechanisms + deterministicSum = + InternalProtocolConfig.Direct.DeterministicSum.getDefaultInstance() + } + } + } + MeasurementSpec.MeasurementTypeCase.POPULATION -> { + protocolConfig = internalProtocolConfig { + direct = + InternalProtocolConfigKt.direct { + noiseMechanisms += internalNoiseMechanisms + deterministicCount = + InternalProtocolConfig.Direct.DeterministicCount.getDefaultInstance() + } } } - MeasurementSpec.MeasurementTypeCase.IMPRESSION, - MeasurementSpec.MeasurementTypeCase.DURATION, - MeasurementSpec.MeasurementTypeCase.POPULATION, -> {} MeasurementSpec.MeasurementTypeCase.MEASUREMENTTYPE_NOT_SET -> error("MeasurementType not set.") } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt index a33cd0566fb..9f62be51105 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt @@ -634,7 +634,7 @@ abstract class MeasurementsServiceTest { details = details.copy { clearDuchyProtocolConfig() - clearProtocolConfig() + protocolConfig = protocolConfig { direct = ProtocolConfig.Direct.getDefaultInstance() } } } @@ -673,7 +673,7 @@ abstract class MeasurementsServiceTest { details = details.copy { clearDuchyProtocolConfig() - clearProtocolConfig() + protocolConfig = protocolConfig { direct = ProtocolConfig.Direct.getDefaultInstance() } } } @@ -708,7 +708,9 @@ abstract class MeasurementsServiceTest { details = details.copy { clearDuchyProtocolConfig() - clearProtocolConfig() + protocolConfig = protocolConfig { + direct = ProtocolConfig.Direct.getDefaultInstance() + } } } @@ -1297,7 +1299,9 @@ abstract class MeasurementsServiceTest { measurementConsumer.certificate.externalCertificateId details = details.copy { - clearProtocolConfig() + protocolConfig = protocolConfig { + direct = ProtocolConfig.Direct.getDefaultInstance() + } clearDuchyProtocolConfig() } } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/Population.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/Population.kt index b73572911ef..842df89a448 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/Population.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/Population.kt @@ -57,6 +57,7 @@ import org.wfanet.measurement.internal.kingdom.ModelRelease import org.wfanet.measurement.internal.kingdom.ModelReleasesGrpcKt.ModelReleasesCoroutineImplBase import org.wfanet.measurement.internal.kingdom.ModelSuite import org.wfanet.measurement.internal.kingdom.ModelSuitesGrpcKt.ModelSuitesCoroutineImplBase +import org.wfanet.measurement.internal.kingdom.ProtocolConfig import org.wfanet.measurement.internal.kingdom.ProtocolConfigKt import org.wfanet.measurement.internal.kingdom.account import org.wfanet.measurement.internal.kingdom.activateAccountRequest @@ -312,6 +313,7 @@ class Population(val clock: Clock, val idGenerator: IdGenerator) { apiVersion = API_VERSION measurementSpec = "MeasurementSpec".toByteStringUtf8() measurementSpecSignature = "MeasurementSpec signature".toByteStringUtf8() + protocolConfig = protocolConfig { direct = ProtocolConfig.Direct.getDefaultInstance() } } return createMeasurement( measurementsService, diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/RequisitionsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/RequisitionsServiceTest.kt index 17618701c38..741404cbcf4 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/RequisitionsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/RequisitionsServiceTest.kt @@ -652,7 +652,7 @@ abstract class RequisitionsServiceTest { measurementSpec = measurement.details.measurementSpec measurementSpecSignature = measurement.details.measurementSpecSignature state = Measurement.State.PENDING_REQUISITION_FULFILLMENT - protocolConfig = protocolConfig {} + protocolConfig = protocolConfig { direct = ProtocolConfig.Direct.getDefaultInstance() } dataProvidersCount = 1 } } diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ProtoConversions.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ProtoConversions.kt index f400f2a46d7..97caf28ba7b 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ProtoConversions.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/system/v1alpha/ProtoConversions.kt @@ -390,7 +390,10 @@ fun InternalNoiseMechanism.toSystemNoiseMechanism(): NoiseMechanism { return when (this) { InternalNoiseMechanism.GEOMETRIC -> NoiseMechanism.GEOMETRIC InternalNoiseMechanism.DISCRETE_GAUSSIAN -> NoiseMechanism.DISCRETE_GAUSSIAN + InternalNoiseMechanism.CONTINUOUS_LAPLACE, + InternalNoiseMechanism.CONTINUOUS_GAUSSIAN, InternalNoiseMechanism.NOISE_MECHANISM_UNSPECIFIED, + InternalNoiseMechanism.NONE, InternalNoiseMechanism.UNRECOGNIZED -> error("invalid internal noise mechanism.") } } diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel index 3c70c6a8f2e..96fad5c1ae0 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/BUILD.bazel @@ -148,6 +148,7 @@ kt_jvm_library( "//src/main/proto/wfa/any_sketch:sketch_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:certificates_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/api/v2alpha:crypto_kt_jvm_proto", + "//src/main/proto/wfa/measurement/api/v2alpha:direct_computation_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:event_group_kt_jvm_proto", "//src/main/proto/wfa/measurement/api/v2alpha:event_group_metadata_descriptors_service_kt_jvm_grpc_proto", "//src/main/proto/wfa/measurement/api/v2alpha:event_groups_service_kt_jvm_grpc_proto", diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt index af035df457d..8d6360acf59 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulator.kt @@ -42,6 +42,8 @@ import org.wfanet.anysketch.crypto.elGamalPublicKey as anySketchElGamalPublicKey import org.wfanet.measurement.api.v2alpha.Certificate import org.wfanet.measurement.api.v2alpha.CertificatesGrpcKt.CertificatesCoroutineStub import org.wfanet.measurement.api.v2alpha.DataProviderKey +import org.wfanet.measurement.api.v2alpha.DeterministicCountDistinct +import org.wfanet.measurement.api.v2alpha.DeterministicDistribution import org.wfanet.measurement.api.v2alpha.DifferentialPrivacyParams import org.wfanet.measurement.api.v2alpha.ElGamalPublicKey import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey @@ -154,7 +156,6 @@ class EdpSimulator( private val throttler: Throttler, private val privacyBudgetManager: PrivacyBudgetManager, private val trustedCertificates: Map, - private val directNoiseMechanism: DirectNoiseMechanism, private val sketchEncrypter: SketchEncrypter = SketchEncrypter.Default, private val random: Random = Random, private val compositionMechanism: CompositionMechanism, @@ -550,26 +551,53 @@ class EdpSimulator( val protocols: List = requisition.protocolConfig.protocolsList if (protocols.any { it.hasDirect() }) { + val directProtocolConfig = + requisition.protocolConfig.protocolsList.first { it.hasDirect() }.direct + val directNoiseMechanismOptions = + directProtocolConfig.noiseMechanismsList + .mapNotNull { protocolConfigNoiseMechanism -> + protocolConfigNoiseMechanism.toDirectNoiseMechanism() + } + .toSet() + if (measurementSpec.hasReach() || measurementSpec.hasReachAndFrequency()) { + val directProtocol = + DirectProtocol( + directProtocolConfig, + selectReachAndFrequencyNoiseMechanism(directNoiseMechanismOptions) + ) fulfillDirectReachAndFrequencyMeasurement( requisition, measurementSpec, requisitionSpec.nonce, - eventGroupSpecs + eventGroupSpecs, + directProtocol ) } else if (measurementSpec.hasDuration()) { + val directProtocol = + DirectProtocol( + directProtocolConfig, + selectImpressionNoiseMechanism(directNoiseMechanismOptions) + ) fulfillDurationMeasurement( requisition, requisitionSpec, measurementSpec, - eventGroupSpecs + eventGroupSpecs, + directProtocol ) } else if (measurementSpec.hasImpression()) { + val directProtocol = + DirectProtocol( + directProtocolConfig, + selectWatchDurationNoiseMechanism(directNoiseMechanismOptions) + ) fulfillImpressionMeasurement( requisition, requisitionSpec, measurementSpec, - eventGroupSpecs + eventGroupSpecs, + directProtocol ) } else { logger.log( @@ -644,6 +672,11 @@ class EdpSimulator( } } + private data class DirectProtocol( + val directProtocolConfig: ProtocolConfig.Direct, + val selectedDirectNoiseMechanism: DirectNoiseMechanism + ) + /** * Builds [EventQuery.EventGroupSpec]s from a [requisitionSpec] by fetching [EventGroup]s. * @@ -769,6 +802,7 @@ class EdpSimulator( requisitionName: String, measurementSpec: MeasurementSpec, eventSpecs: Iterable, + directNoiseMechanism: DirectNoiseMechanism ) { logger.info( "chargeDirectPrivacyBudget with $compositionMechanism composition mechanism...", @@ -785,7 +819,7 @@ class EdpSimulator( ) ) CompositionMechanism.ACDP -> { - if (directNoiseMechanism != DirectNoiseMechanism.GAUSSIAN) { + if (directNoiseMechanism != DirectNoiseMechanism.CONTINUOUS_GAUSSIAN) { throw PrivacyBudgetManagerException( PrivacyBudgetManagerExceptionType.INCORRECT_NOISE_MECHANISM ) @@ -1044,12 +1078,14 @@ class EdpSimulator( requisition: Requisition, measurementSpec: MeasurementSpec, nonce: Long, - eventGroupSpecs: Iterable + eventGroupSpecs: Iterable, + directProtocol: DirectProtocol ) { chargeDirectPrivacyBudget( requisition.name, measurementSpec, eventGroupSpecs.map { it.spec }, + directProtocol.selectedDirectNoiseMechanism ) logger.info("Calculating direct reach and frequency...") @@ -1093,7 +1129,8 @@ class EdpSimulator( ) } - val measurementResult = buildDirectMeasurementResult(measurementSpec, sampledVids.asIterable()) + val measurementResult = + buildDirectMeasurementResult(directProtocol, measurementSpec, sampledVids.asIterable()) fulfillDirectMeasurement(requisition, measurementSpec, nonce, measurementResult) } @@ -1110,9 +1147,9 @@ class EdpSimulator( override val variance: Double get() = distribution.numericalVariance } - DirectNoiseMechanism.LAPLACE -> + DirectNoiseMechanism.CONTINUOUS_LAPLACE -> LaplaceNoiser(DpParams(privacyParams.epsilon, privacyParams.delta), random.asJavaRandom()) - DirectNoiseMechanism.GAUSSIAN -> + DirectNoiseMechanism.CONTINUOUS_GAUSSIAN -> GaussianNoiser(DpParams(privacyParams.epsilon, privacyParams.delta), random.asJavaRandom()) } @@ -1121,11 +1158,13 @@ class EdpSimulator( * * @param reachValue Direct reach value. * @param privacyParams Differential privacy params for reach. + * @param directNoiseMechanism Selected noise mechanism for direct reach. * @return Noised reach value. */ private fun addReachPublisherNoise( reachValue: Int, - privacyParams: DifferentialPrivacyParams + privacyParams: DifferentialPrivacyParams, + directNoiseMechanism: DirectNoiseMechanism ): Int { val reachNoiser: AbstractNoiser = getPublisherNoiser(privacyParams, directNoiseMechanism, random) @@ -1139,12 +1178,14 @@ class EdpSimulator( * @param reachValue Direct reach value. * @param frequencyMap Direct frequency. * @param privacyParams Differential privacy params for frequency map. + * @param directNoiseMechanism Selected noise mechanism for direct frequency. * @return Noised frequency map. */ private fun addFrequencyPublisherNoise( reachValue: Int, frequencyMap: Map, privacyParams: DifferentialPrivacyParams, + directNoiseMechanism: DirectNoiseMechanism ): Map { val frequencyNoiser: AbstractNoiser = getPublisherNoiser(privacyParams, directNoiseMechanism, random) @@ -1157,57 +1198,127 @@ class EdpSimulator( /** * Build [Measurement.Result] of the measurement type specified in [MeasurementSpec]. * + * @param requisition Requisition. * @param measurementSpec Measurement spec. - * @param sampledVids sampled event VIDs + * @param samples sampled events. * @return [Measurement.Result]. */ private fun buildDirectMeasurementResult( + directProtocol: DirectProtocol, measurementSpec: MeasurementSpec, - sampledVids: Iterable, + samples: Iterable, ): Measurement.Result { + val directProtocolConfig = directProtocol.directProtocolConfig + val directNoiseMechanism = directProtocol.selectedDirectNoiseMechanism + val protocolConfigNoiseMechanism = directNoiseMechanism.toProtocolConfigNoiseMechanism() + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields cannot be null. return when (measurementSpec.measurementTypeCase) { MeasurementSpec.MeasurementTypeCase.REACH_AND_FREQUENCY -> { - val (sampledReachValue, frequencyMap) = - MeasurementResults.computeReachAndFrequency(sampledVids) + if (!directProtocolConfig.hasDeterministicCountDistinct()) { + throw RequisitionRefusalException( + Requisition.Refusal.Justification.DECLINED, + "No valid methodologies for direct reach computation." + ) + } + if (!directProtocolConfig.hasDeterministicDistribution()) { + throw RequisitionRefusalException( + Requisition.Refusal.Justification.DECLINED, + "No valid methodologies for direct frequency distribution computation." + ) + } + + val (sampledReachValue, frequencyMap) = MeasurementResults.computeReachAndFrequency(samples) logger.info("Adding $directNoiseMechanism publisher noise to direct reach and frequency...") val sampledNoisedReachValue = addReachPublisherNoise( sampledReachValue, - measurementSpec.reachAndFrequency.reachPrivacyParams + measurementSpec.reachAndFrequency.reachPrivacyParams, + directNoiseMechanism ) val noisedFrequencyMap = addFrequencyPublisherNoise( sampledReachValue, frequencyMap, measurementSpec.reachAndFrequency.frequencyPrivacyParams, + directNoiseMechanism ) val scaledNoisedReachValue = (sampledNoisedReachValue / measurementSpec.vidSamplingInterval.width).toLong() MeasurementKt.result { - reach = reach { value = scaledNoisedReachValue } + reach = reach { + value = scaledNoisedReachValue + this.noiseMechanism = protocolConfigNoiseMechanism + deterministicCountDistinct = DeterministicCountDistinct.getDefaultInstance() + } frequency = frequency { relativeFrequencyDistribution.putAll(noisedFrequencyMap.mapKeys { it.key.toLong() }) + this.noiseMechanism = protocolConfigNoiseMechanism + deterministicDistribution = DeterministicDistribution.getDefaultInstance() + } + } + } + MeasurementSpec.MeasurementTypeCase.IMPRESSION -> { + MeasurementKt.result { + impression = impression { + // Use externalDataProviderId since it's a known value the FrontendSimulator can verify. + // TODO: Calculate impression from data. + value = apiIdToExternalId(DataProviderKey.fromName(edpData.name)!!.dataProviderId) + noiseMechanism = protocolConfigNoiseMechanism + // TODO(@riemanli): specify impression computation methodology once the real impression + // calculation is done. + } + } + } + MeasurementSpec.MeasurementTypeCase.DURATION -> { + val externalDataProviderId = + apiIdToExternalId(DataProviderKey.fromName(edpData.name)!!.dataProviderId) + MeasurementKt.result { + watchDuration = watchDuration { + value = duration { + // Use a value based on the externalDataProviderId since it's a known value the + // MeasurementConsumerSimulator can verify. + seconds = log2(externalDataProviderId.toDouble()).toLong() + } + noiseMechanism = protocolConfigNoiseMechanism + // TODO(@riemanli): specify duration computation methodology once the real duration + // calculation is done. } } } - MeasurementSpec.MeasurementTypeCase.IMPRESSION, - MeasurementSpec.MeasurementTypeCase.DURATION, MeasurementSpec.MeasurementTypeCase.POPULATION -> { error("Measurement type not supported.") } MeasurementSpec.MeasurementTypeCase.REACH -> { - val sampledReachValue = MeasurementResults.computeReach(sampledVids) - logger.info("Adding $directNoiseMechanism publisher noise to direct reach...") + if (!directProtocolConfig.hasDeterministicCountDistinct()) { + throw RequisitionRefusalException( + Requisition.Refusal.Justification.DECLINED, + "No valid methodologies for direct reach computation." + ) + } + + val sampledReachValue = MeasurementResults.computeReach(samples) + + logger.info("Adding $directNoiseMechanism publisher noise to direct reach for reach-only") val sampledNoisedReachValue = - addReachPublisherNoise(sampledReachValue, measurementSpec.reach.privacyParams) + addReachPublisherNoise( + sampledReachValue, + measurementSpec.reach.privacyParams, + directNoiseMechanism + ) val scaledNoisedReachValue = (sampledNoisedReachValue / measurementSpec.vidSamplingInterval.width).toLong() - MeasurementKt.result { reach = reach { value = scaledNoisedReachValue } } + MeasurementKt.result { + reach = reach { + value = scaledNoisedReachValue + this.noiseMechanism = protocolConfigNoiseMechanism + deterministicCountDistinct = DeterministicCountDistinct.getDefaultInstance() + } + } } MeasurementSpec.MeasurementTypeCase.MEASUREMENTTYPE_NOT_SET -> { error("Measurement type not set.") @@ -1215,26 +1326,80 @@ class EdpSimulator( } } + /** + * Selects the most preferred [DirectNoiseMechanism] for reach and frequency measurements from the + * overlap of a list of preferred [DirectNoiseMechanism] and a set of [DirectNoiseMechanism] + * [options]. + */ + private fun selectReachAndFrequencyNoiseMechanism( + options: Set + ): DirectNoiseMechanism { + val preferences = + when (compositionMechanism) { + CompositionMechanism.DP_ADVANCED -> { + DIRECT_REACH_AND_FREQUENCY_NOISE_MECHANISM_PREFERENCES + } + CompositionMechanism.ACDP -> { + DIRECT_REACH_AND_FREQUENCY_ACDP_NOISE_MECHANISM_PREFERENCES + } + } + + return preferences.firstOrNull { preference -> options.contains(preference) } + ?: throw RequisitionRefusalException( + Requisition.Refusal.Justification.SPEC_INVALID, + "No valid noise mechanism option for reach or frequency measurements." + ) + } + + /** + * Selects the most preferred [DirectNoiseMechanism] for impression measurements from the overlap + * of a list of preferred [DirectNoiseMechanism] and a set of [DirectNoiseMechanism] [options]. + */ + private fun selectImpressionNoiseMechanism( + options: Set + ): DirectNoiseMechanism { + val preferences = listOf(DirectNoiseMechanism.NONE) + + return preferences.firstOrNull { preference -> options.contains(preference) } + ?: throw RequisitionRefusalException( + Requisition.Refusal.Justification.SPEC_INVALID, + "No valid noise mechanism option for impression measurements." + ) + } + + /** + * Selects the most preferred [DirectNoiseMechanism] for watch duration measurements from the + * overlap of a list of preferred [DirectNoiseMechanism] and a set of [DirectNoiseMechanism] + * [options]. + */ + private fun selectWatchDurationNoiseMechanism( + options: Set + ): DirectNoiseMechanism { + val preferences = listOf(DirectNoiseMechanism.NONE) + + return preferences.firstOrNull { preference -> options.contains(preference) } + ?: throw RequisitionRefusalException( + Requisition.Refusal.Justification.SPEC_INVALID, + "No valid noise mechanism option for watch duration measurements." + ) + } + private suspend fun fulfillImpressionMeasurement( requisition: Requisition, requisitionSpec: RequisitionSpec, measurementSpec: MeasurementSpec, - eventGroupSpecs: Iterable + eventGroupSpecs: Iterable, + directProtocol: DirectProtocol ) { chargeDirectPrivacyBudget( requisition.name, measurementSpec, eventGroupSpecs.map { it.spec }, + directProtocol.selectedDirectNoiseMechanism ) val measurementResult = - MeasurementKt.result { - impression = impression { - // Use externalDataProviderId since it's a known value the FrontendSimulator can verify. - // TODO: Calculate impression from data. - value = apiIdToExternalId(DataProviderKey.fromName(edpData.name)!!.dataProviderId) - } - } + buildDirectMeasurementResult(directProtocol, measurementSpec, listOf().asIterable()) fulfillDirectMeasurement(requisition, measurementSpec, requisitionSpec.nonce, measurementResult) } @@ -1243,26 +1408,18 @@ class EdpSimulator( requisition: Requisition, requisitionSpec: RequisitionSpec, measurementSpec: MeasurementSpec, - eventGroupSpecs: Iterable + eventGroupSpecs: Iterable, + directProtocol: DirectProtocol ) { chargeDirectPrivacyBudget( requisition.name, measurementSpec, eventGroupSpecs.map { it.spec }, + directProtocol.selectedDirectNoiseMechanism ) - val externalDataProviderId = - apiIdToExternalId(DataProviderKey.fromName(edpData.name)!!.dataProviderId) val measurementResult = - MeasurementKt.result { - watchDuration = watchDuration { - value = duration { - // Use a value based on the externalDataProviderId since it's a known value the - // MeasurementConsumerSimulator can verify. - seconds = log2(externalDataProviderId.toDouble()).toLong() - } - } - } + buildDirectMeasurementResult(directProtocol, measurementSpec, listOf().asIterable()) fulfillDirectMeasurement(requisition, measurementSpec, requisitionSpec.nonce, measurementResult) } @@ -1309,6 +1466,47 @@ class EdpSimulator( EVENT_TEMPLATE_TYPES.map { eventTemplate { type = it.fullName } } private val logger: Logger = Logger.getLogger(this::class.java.name) + + // The noise mechanisms for reach and frequency are in order of preference. + private val DIRECT_REACH_AND_FREQUENCY_NOISE_MECHANISM_PREFERENCES = + listOf( + DirectNoiseMechanism.CONTINUOUS_LAPLACE, + DirectNoiseMechanism.CONTINUOUS_GAUSSIAN, + ) + // The direct noise mechanisms for ACDP composition in PBM for reach and frequency in order + // of preference. Currently, ACDP composition only supports CONTINUOUS_GAUSSIAN noise for direct + // measurements + private val DIRECT_REACH_AND_FREQUENCY_ACDP_NOISE_MECHANISM_PREFERENCES = + listOf( + DirectNoiseMechanism.CONTINUOUS_GAUSSIAN, + ) + } +} + +private fun DirectNoiseMechanism.toProtocolConfigNoiseMechanism(): NoiseMechanism { + return when (this) { + DirectNoiseMechanism.NONE -> NoiseMechanism.NONE + DirectNoiseMechanism.CONTINUOUS_LAPLACE -> NoiseMechanism.CONTINUOUS_LAPLACE + DirectNoiseMechanism.CONTINUOUS_GAUSSIAN -> NoiseMechanism.CONTINUOUS_GAUSSIAN + } +} + +/** + * Converts a [NoiseMechanism] to a nullable [DirectNoiseMechanism]. + * + * @return [DirectNoiseMechanism] when there is a matched, otherwise null. + */ +private fun NoiseMechanism.toDirectNoiseMechanism(): DirectNoiseMechanism? { + return when (this) { + NoiseMechanism.NONE -> DirectNoiseMechanism.NONE + NoiseMechanism.CONTINUOUS_LAPLACE -> DirectNoiseMechanism.CONTINUOUS_LAPLACE + NoiseMechanism.CONTINUOUS_GAUSSIAN -> DirectNoiseMechanism.CONTINUOUS_GAUSSIAN + NoiseMechanism.NOISE_MECHANISM_UNSPECIFIED, + NoiseMechanism.GEOMETRIC, + NoiseMechanism.DISCRETE_GAUSSIAN, + NoiseMechanism.UNRECOGNIZED -> { + null + } } } diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulatorFlags.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulatorFlags.kt index c7b6a737a8e..f3053349586 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulatorFlags.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulatorFlags.kt @@ -17,7 +17,6 @@ package org.wfanet.measurement.loadtest.dataprovider import java.io.File import java.time.Duration import org.wfanet.measurement.common.grpc.TlsFlags -import org.wfanet.measurement.eventdataprovider.noiser.DirectNoiseMechanism import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.CompositionMechanism import org.wfanet.measurement.loadtest.KingdomPublicApiFlags import org.wfanet.measurement.loadtest.RequisitionFulfillmentServiceFlags @@ -102,14 +101,6 @@ class EdpSimulatorFlags { var randomSeed: Long? = null private set - @CommandLine.Option( - names = ["--direct-noise-mechanism"], - description = ["Differential privacy noise mechanism for direct measurements"], - defaultValue = "LAPLACE", - ) - lateinit var directNoiseMechanism: DirectNoiseMechanism - private set - @CommandLine.Option( names = ["--composition-mechanism"], description = ["Composition mechanism in Privacy Budget Manager"], diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulatorRunner.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulatorRunner.kt index a80c3dd989f..d0a5fb98d26 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulatorRunner.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulatorRunner.kt @@ -99,7 +99,6 @@ abstract class EdpSimulatorRunner : Runnable { MinimumIntervalThrottler(Clock.systemUTC(), flags.throttlerMinimumInterval), createNoOpPrivacyBudgetManager(), clientCerts.trustedCertificates, - flags.directNoiseMechanism, random = random, compositionMechanism = flags.compositionMechanism, ) diff --git a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/MeasurementResults.kt b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/MeasurementResults.kt index d6943299b84..e9dd51ffbbb 100644 --- a/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/MeasurementResults.kt +++ b/src/main/kotlin/org/wfanet/measurement/loadtest/dataprovider/MeasurementResults.kt @@ -20,7 +20,10 @@ package org.wfanet.measurement.loadtest.dataprovider object MeasurementResults { data class ReachAndFrequency(val reach: Int, val relativeFrequencyDistribution: Map) - /** Computes reach and frequency using the "deterministic count distinct" methodology. */ + /** + * Computes reach and frequency using the "deterministic count distinct" methodology and the + * "deterministic distribution" methodology. + */ fun computeReachAndFrequency( sampledVids: Iterable, maxFrequency: Int = Int.MAX_VALUE diff --git a/src/main/proto/wfa/measurement/api/v2alpha/BUILD.bazel b/src/main/proto/wfa/measurement/api/v2alpha/BUILD.bazel index 877f6b1364d..fd7d93b6be6 100644 --- a/src/main/proto/wfa/measurement/api/v2alpha/BUILD.bazel +++ b/src/main/proto/wfa/measurement/api/v2alpha/BUILD.bazel @@ -694,3 +694,18 @@ kt_jvm_proto_library( ], deps = [":date_interval_java_proto"], ) + +java_proto_library( + name = "direct_computation_java_proto", + deps = [ + "@wfa_measurement_proto//src/main/proto/wfa/measurement/api/v2alpha:direct_computation_proto", + ], +) + +kt_jvm_proto_library( + name = "direct_computation_kt_jvm_proto", + srcs = [ + "@wfa_measurement_proto//src/main/proto/wfa/measurement/api/v2alpha:direct_computation_proto", + ], + deps = [":direct_computation_java_proto"], +) diff --git a/src/main/proto/wfa/measurement/internal/kingdom/BUILD.bazel b/src/main/proto/wfa/measurement/internal/kingdom/BUILD.bazel index 3e80850c74c..514ed11c5b2 100644 --- a/src/main/proto/wfa/measurement/internal/kingdom/BUILD.bazel +++ b/src/main/proto/wfa/measurement/internal/kingdom/BUILD.bazel @@ -135,6 +135,10 @@ proto_and_java_proto_library( name = "differential_privacy", ) +proto_and_java_proto_library( + name = "direct_computation", +) + proto_and_java_proto_library( name = "duchy_protocol_config", deps = [ diff --git a/src/main/proto/wfa/measurement/internal/kingdom/direct_computation.proto b/src/main/proto/wfa/measurement/internal/kingdom/direct_computation.proto new file mode 100644 index 00000000000..c02b36cf198 --- /dev/null +++ b/src/main/proto/wfa/measurement/internal/kingdom/direct_computation.proto @@ -0,0 +1,61 @@ +// Copyright 2023 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package wfa.measurement.internal.kingdom; + +option java_package = "org.wfanet.measurement.internal.kingdom"; +option java_multiple_files = true; + +// Parameters used when applying the deterministic count distinct methodology. +message DeterministicCountDistinct {} + +// Parameters used when applying the deterministic distribution methodology. +message DeterministicDistribution {} + +// Parameters used when applying the deterministic count methodology. +message DeterministicCount {} + +// Parameters used when applying the deterministic sum methodology. +message DeterministicSum {} + +// Parameters used when applying the Liquid Legions count distinct methodology. +// +// May only be set when the measurement type is REACH. +// To obtain differentially private result, one should add a DP noise to the +// estimate number of sampled registers instead of the target estimate. +message LiquidLegionsCountDistinct { + // The decay rate of the Liquid Legions sketch. Required. + double decay_rate = 1; + + // The maximum size of the Liquid Legions sketch. Required. + int64 max_size = 2; +} + +// Parameters used when applying the Liquid Legions distribution methodology. +// +// May only be set when the measurement type is REACH_AND_FREQUENCY. +// `Requisition`s using this protocol can be fulfilled by calling +// RequisitionFulfillment/FulfillRequisition with an encrypted sketch. +message LiquidLegionsDistribution { + // The decay rate of the Liquid Legions sketch. Required. + double decay_rate = 1; + + // The maximum size of the Liquid Legions sketch. Required. + int64 max_size = 2; + + // The size of the distribution of the sampling indicator value. Required. + int64 sampling_indicator_size = 3; +} diff --git a/src/main/proto/wfa/measurement/internal/kingdom/protocol_config.proto b/src/main/proto/wfa/measurement/internal/kingdom/protocol_config.proto index d8f447eeb4e..a235c52975c 100644 --- a/src/main/proto/wfa/measurement/internal/kingdom/protocol_config.proto +++ b/src/main/proto/wfa/measurement/internal/kingdom/protocol_config.proto @@ -32,8 +32,73 @@ message ProtocolConfig { // The mechanism used to generate noise in computations. enum NoiseMechanism { NOISE_MECHANISM_UNSPECIFIED = 0; + NONE = 3; GEOMETRIC = 1; DISCRETE_GAUSSIAN = 2; + CONTINUOUS_LAPLACE = 4; + CONTINUOUS_GAUSSIAN = 5; + } + + // Configuration for the Direct protocol. + // + // The `DataProvider` may choose from the specified noise mechanisms and + // methodologies. + message Direct { + // Configuration parameters for the deterministic count distinct + // methodology. + message DeterministicCountDistinct {} + // Configuration parameters for the deterministic distribution methodology. + message DeterministicDistribution { + // The maximum frequency to reveal in the distribution. + int32 maximum_frequency = 1; + } + // Configuration parameters for the deterministic count methodology. + message DeterministicCount {} + // Configuration parameters for the deterministic sum methodology. + message DeterministicSum {} + // Configuration parameters for the direct Liquid Legions distribution + // methodology. + message LiquidLegionsDistribution { + // The maximum frequency to reveal in the distribution. + int32 maximum_frequency = 1; + } + // Configuration parameters for the direct Liquid Legions count distinct + // methodology. + message LiquidLegionsCountDistinct {} + + // The set of mechanisms that can be used to generate noise during + // computation. + repeated NoiseMechanism noise_mechanisms = 1; + + // Deterministic count distinct methodology. + // + // Can be used in reach computations. + DeterministicCountDistinct deterministic_count_distinct = 2; + + // Deterministic distribution methodology. + // + // Can be used in frequency computations. + DeterministicDistribution deterministic_distribution = 3; + + // Deterministic count methodology. + // + // Can be used in impression computations. + DeterministicCount deterministic_count = 4; + + // Deterministic sum methodology. + // + // Can be used in watch duration computations. + DeterministicSum deterministic_sum = 5; + + // Liquid Legions count distinct methodology. + // + // Can be used in reach computations. + LiquidLegionsCountDistinct liquid_legions_count_distinct = 6; + + // Liquid Legions distribution methodology. + // + // Can be used in frequency computations. + LiquidLegionsDistribution liquid_legions_distribution = 7; } // Configuration for Liquid Legions v2 protocols. @@ -72,6 +137,9 @@ message ProtocolConfig { // // Must only be set when the measurement type is REACH. LiquidLegionsV2 reach_only_liquid_legions_v2 = 4; + + // Direct protocol. + Direct direct = 5; } } diff --git a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsServiceTest.kt index 6cbeefdeba0..1abef7bfe7b 100644 --- a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/MeasurementsServiceTest.kt @@ -98,7 +98,9 @@ import org.wfanet.measurement.internal.kingdom.MeasurementKt as InternalMeasurem import org.wfanet.measurement.internal.kingdom.MeasurementKt.resultInfo import org.wfanet.measurement.internal.kingdom.MeasurementsGrpcKt import org.wfanet.measurement.internal.kingdom.ProtocolConfig as InternalProtocolConfig +import org.wfanet.measurement.internal.kingdom.ProtocolConfig.NoiseMechanism as InternalNoiseMechanism import org.wfanet.measurement.internal.kingdom.ProtocolConfigKt as InternalProtocolConfigKt +import org.wfanet.measurement.internal.kingdom.ProtocolConfigKt.direct import org.wfanet.measurement.internal.kingdom.StreamMeasurementsRequest import org.wfanet.measurement.internal.kingdom.StreamMeasurementsRequestKt import org.wfanet.measurement.internal.kingdom.cancelMeasurementRequest as internalCancelMeasurementRequest @@ -171,6 +173,7 @@ class MeasurementsServiceTest { service = MeasurementsService( MeasurementsGrpcKt.MeasurementsCoroutineStub(grpcTestServerRule.channel), + NOISE_MECHANISMS ) } @@ -409,7 +412,20 @@ class MeasurementsServiceTest { runBlocking { service.createMeasurement(request) } } - assertThat(response).ignoringRepeatedFieldOrder().isEqualTo(measurement) + assertThat(response) + .ignoringRepeatedFieldOrder() + .isEqualTo( + measurement.copy { + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { + direct = DEFAULT_DIRECT_REACH_AND_FREQUENCY_PROTOCOL_CONFIG + } + } + } + ) verifyProtoArgument( internalMeasurementsMock, MeasurementsGrpcKt.MeasurementsCoroutineImplBase::createMeasurement @@ -420,6 +436,13 @@ class MeasurementsServiceTest { internalMeasurement.copy { clearExternalMeasurementId() clearUpdateTime() + + details = + details.copy { + protocolConfig = internalProtocolConfig { + direct = DEFAULT_INTERNAL_DIRECT_REACH_AND_FREQUENCY_PROTOCOL_CONFIG + } + } } } ) @@ -486,7 +509,18 @@ class MeasurementsServiceTest { runBlocking { service.createMeasurement(request) } } - assertThat(response).ignoringRepeatedFieldOrder().isEqualTo(measurement) + assertThat(response) + .ignoringRepeatedFieldOrder() + .isEqualTo( + measurement.copy { + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { direct = DEFAULT_DIRECT_IMPRESSION_PROTOCOL_CONFIG } + } + } + ) verifyProtoArgument( internalMeasurementsMock, MeasurementsGrpcKt.MeasurementsCoroutineImplBase::createMeasurement @@ -497,6 +531,13 @@ class MeasurementsServiceTest { internalMeasurement.copy { clearExternalMeasurementId() clearUpdateTime() + + details = + details.copy { + protocolConfig = internalProtocolConfig { + direct = DEFAULT_INTERNAL_DIRECT_IMPRESSION_PROTOCOL_CONFIG + } + } } } ) @@ -563,7 +604,18 @@ class MeasurementsServiceTest { runBlocking { service.createMeasurement(request) } } - assertThat(response).ignoringRepeatedFieldOrder().isEqualTo(measurement) + assertThat(response) + .ignoringRepeatedFieldOrder() + .isEqualTo( + measurement.copy { + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { direct = DEFAULT_DIRECT_WATCH_DURATION_PROTOCOL_CONFIG } + } + } + ) verifyProtoArgument( internalMeasurementsMock, MeasurementsGrpcKt.MeasurementsCoroutineImplBase::createMeasurement @@ -574,6 +626,13 @@ class MeasurementsServiceTest { internalMeasurement.copy { clearExternalMeasurementId() clearUpdateTime() + + details = + details.copy { + protocolConfig = internalProtocolConfig { + direct = DEFAULT_INTERNAL_DIRECT_WATCH_DURATION_PROTOCOL_CONFIG + } + } } } ) @@ -635,7 +694,18 @@ class MeasurementsServiceTest { runBlocking { service.createMeasurement(request) } } - assertThat(response).ignoringRepeatedFieldOrder().isEqualTo(measurement) + assertThat(response) + .ignoringRepeatedFieldOrder() + .isEqualTo( + measurement.copy { + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { direct = DEFAULT_DIRECT_POPULATION_PROTOCOL_CONFIG } + } + } + ) verifyProtoArgument( internalMeasurementsMock, MeasurementsGrpcKt.MeasurementsCoroutineImplBase::createMeasurement @@ -646,6 +716,13 @@ class MeasurementsServiceTest { internalMeasurement.copy { clearExternalMeasurementId() clearUpdateTime() + + details = + details.copy { + protocolConfig = internalProtocolConfig { + direct = DEFAULT_INTERNAL_DIRECT_POPULATION_PROTOCOL_CONFIG + } + } } } ) @@ -1771,6 +1848,20 @@ class MeasurementsServiceTest { ) } + private val NOISE_MECHANISMS = + listOf( + ProtocolConfig.NoiseMechanism.NONE, + ProtocolConfig.NoiseMechanism.CONTINUOUS_LAPLACE, + ProtocolConfig.NoiseMechanism.CONTINUOUS_GAUSSIAN, + ) + + private val INTERNAL_NOISE_MECHANISMS = + listOf( + InternalNoiseMechanism.NONE, + InternalNoiseMechanism.CONTINUOUS_LAPLACE, + InternalNoiseMechanism.CONTINUOUS_GAUSSIAN, + ) + private val DIFFERENTIAL_PRIVACY_PARAMS = differentialPrivacyParams { epsilon = 1.0 delta = 1.0 @@ -1789,7 +1880,7 @@ class MeasurementsServiceTest { epsilon = 2.1 delta = 3.3 } - noiseMechanism = InternalProtocolConfig.NoiseMechanism.GEOMETRIC + noiseMechanism = InternalNoiseMechanism.GEOMETRIC } } @@ -1829,7 +1920,7 @@ class MeasurementsServiceTest { epsilon = 2.1 delta = 3.3 } - noiseMechanism = InternalProtocolConfig.NoiseMechanism.GEOMETRIC + noiseMechanism = InternalNoiseMechanism.GEOMETRIC } } @@ -2003,5 +2094,57 @@ class MeasurementsServiceTest { } } } + private val DEFAULT_INTERNAL_DIRECT_NOISE_MECHANISMS: List = + listOf( + InternalNoiseMechanism.NONE, + InternalNoiseMechanism.CONTINUOUS_LAPLACE, + InternalNoiseMechanism.CONTINUOUS_GAUSSIAN + ) + + private val DEFAULT_INTERNAL_DIRECT_REACH_PROTOCOL_CONFIG: InternalProtocolConfig.Direct = + direct { + noiseMechanisms += DEFAULT_INTERNAL_DIRECT_NOISE_MECHANISMS + deterministicCountDistinct = + InternalProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + liquidLegionsCountDistinct = + InternalProtocolConfig.Direct.LiquidLegionsCountDistinct.getDefaultInstance() + } + + private val DEFAULT_INTERNAL_DIRECT_REACH_AND_FREQUENCY_PROTOCOL_CONFIG: + InternalProtocolConfig.Direct = + direct { + noiseMechanisms += DEFAULT_INTERNAL_DIRECT_NOISE_MECHANISMS + deterministicCountDistinct = + InternalProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + liquidLegionsCountDistinct = + InternalProtocolConfig.Direct.LiquidLegionsCountDistinct.getDefaultInstance() + deterministicDistribution = + InternalProtocolConfigKt.DirectKt.deterministicDistribution { + maximumFrequency = DEFAULT_MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION + } + liquidLegionsDistribution = + InternalProtocolConfigKt.DirectKt.liquidLegionsDistribution { + maximumFrequency = DEFAULT_MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION + } + } + + private val DEFAULT_INTERNAL_DIRECT_IMPRESSION_PROTOCOL_CONFIG: InternalProtocolConfig.Direct = + direct { + noiseMechanisms += DEFAULT_INTERNAL_DIRECT_NOISE_MECHANISMS + deterministicCount = InternalProtocolConfig.Direct.DeterministicCount.getDefaultInstance() + } + + private val DEFAULT_INTERNAL_DIRECT_WATCH_DURATION_PROTOCOL_CONFIG: + InternalProtocolConfig.Direct = + direct { + noiseMechanisms += DEFAULT_INTERNAL_DIRECT_NOISE_MECHANISMS + deterministicSum = InternalProtocolConfig.Direct.DeterministicSum.getDefaultInstance() + } + + private val DEFAULT_INTERNAL_DIRECT_POPULATION_PROTOCOL_CONFIG: InternalProtocolConfig.Direct = + direct { + noiseMechanisms += DEFAULT_INTERNAL_DIRECT_NOISE_MECHANISMS + deterministicCount = InternalProtocolConfig.Direct.DeterministicCount.getDefaultInstance() + } } } diff --git a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/RequisitionsServiceTest.kt b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/RequisitionsServiceTest.kt index cdbf309faed..856305727a4 100644 --- a/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/RequisitionsServiceTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/RequisitionsServiceTest.kt @@ -84,6 +84,7 @@ import org.wfanet.measurement.internal.kingdom.ComputationParticipantKt.liquidLe import org.wfanet.measurement.internal.kingdom.FulfillRequisitionRequestKt.directRequisitionParams import org.wfanet.measurement.internal.kingdom.Measurement as InternalMeasurement import org.wfanet.measurement.internal.kingdom.ProtocolConfig as InternalProtocolConfig +import org.wfanet.measurement.internal.kingdom.ProtocolConfigKt as InternalProtocolConfigKt import org.wfanet.measurement.internal.kingdom.Requisition as InternalRequisition import org.wfanet.measurement.internal.kingdom.Requisition.Refusal as InternalRefusal import org.wfanet.measurement.internal.kingdom.Requisition.State as InternalState @@ -209,6 +210,65 @@ class RequisitionsServiceTest { assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) } + @Test + fun `listRequisitions requests internal Requisitions with direct protocol`() { + val internalRequisition = + INTERNAL_REQUISITION.copy { + parentMeasurement = + parentMeasurement.copy { + protocolConfig = internalProtocolConfig { + externalProtocolConfigId = "direct" + direct = INTERNAL_DIRECT_RF_PROTOCOL_CONFIG + } + } + } + + val requisition = + REQUISITION.copy { + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += ProtocolConfigKt.protocol { direct = DIRECT_RF_PROTOCOL_CONFIG } + } + } + + whenever(internalRequisitionMock.streamRequisitions(any())) + .thenReturn(flowOf(internalRequisition, internalRequisition)) + + val request = listRequisitionsRequest { parent = MEASUREMENT_NAME } + + val result = + withMeasurementConsumerPrincipal(MEASUREMENT_CONSUMER_NAME) { + runBlocking { service.listRequisitions(request) } + } + + val expected = listRequisitionsResponse { + requisitions += requisition + requisitions += requisition + } + + val streamRequisitionRequest = + captureFirst { + verify(internalRequisitionMock).streamRequisitions(capture()) + } + + assertThat(streamRequisitionRequest) + .ignoringRepeatedFieldOrder() + .isEqualTo( + streamRequisitionsRequest { + limit = DEFAULT_LIMIT + 1 + filter = + StreamRequisitionsRequestKt.filter { + externalMeasurementConsumerId = EXTERNAL_MEASUREMENT_CONSUMER_ID + externalMeasurementId = EXTERNAL_MEASUREMENT_ID + states += VISIBLE_REQUISITION_STATES + } + } + ) + + assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) + } + @Test fun `listRequisitions with page token returns next page`() { whenever(internalRequisitionMock.streamRequisitions(any())) @@ -619,43 +679,97 @@ class RequisitionsServiceTest { } @Test - fun `fulfillDirectRequisition fulfills the requisition`() = runBlocking { - whenever(internalRequisitionMock.fulfillRequisition(any())) - .thenReturn( - INTERNAL_REQUISITION.copy { - state = InternalState.FULFILLED - details = details { encryptedData = REQUISITION_ENCRYPTED_DATA } + fun `fulfillDirectRequisition fulfills the requisition when direct protocol config is not specified`() = + runBlocking { + whenever(internalRequisitionMock.fulfillRequisition(any())) + .thenReturn( + INTERNAL_REQUISITION.copy { + state = InternalState.FULFILLED + details = details { encryptedData = REQUISITION_ENCRYPTED_DATA } + } + ) + + val request = fulfillDirectRequisitionRequest { + name = REQUISITION_NAME + encryptedData = REQUISITION_ENCRYPTED_DATA + nonce = NONCE + } + + val result = + withDataProviderPrincipal(DATA_PROVIDER_NAME) { + runBlocking { service.fulfillDirectRequisition(request) } } - ) - val request = fulfillDirectRequisitionRequest { - name = REQUISITION_NAME - encryptedData = REQUISITION_ENCRYPTED_DATA - nonce = NONCE + val expected = fulfillDirectRequisitionResponse { state = State.FULFILLED } + + verifyProtoArgument( + internalRequisitionMock, + RequisitionsCoroutineImplBase::fulfillRequisition + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + internalFulfillRequisitionRequest { + externalRequisitionId = EXTERNAL_REQUISITION_ID + nonce = NONCE + directParams = directRequisitionParams { + externalDataProviderId = EXTERNAL_DATA_PROVIDER_ID + encryptedData = REQUISITION_ENCRYPTED_DATA + } + } + ) + + assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) } - val result = - withDataProviderPrincipal(DATA_PROVIDER_NAME) { - runBlocking { service.fulfillDirectRequisition(request) } + @Test + fun `fulfillDirectRequisition fulfills the requisition when direct protocol config is specified`() = + runBlocking { + whenever(internalRequisitionMock.fulfillRequisition(any())) + .thenReturn( + INTERNAL_REQUISITION.copy { + state = InternalState.FULFILLED + details = details { encryptedData = REQUISITION_ENCRYPTED_DATA } + parentMeasurement = + parentMeasurement.copy { + protocolConfig = internalProtocolConfig { + externalProtocolConfigId = "direct" + direct = INTERNAL_DIRECT_RF_PROTOCOL_CONFIG + } + } + } + ) + + val request = fulfillDirectRequisitionRequest { + name = REQUISITION_NAME + encryptedData = REQUISITION_ENCRYPTED_DATA + nonce = NONCE } - val expected = fulfillDirectRequisitionResponse { state = State.FULFILLED } + val result = + withDataProviderPrincipal(DATA_PROVIDER_NAME) { + runBlocking { service.fulfillDirectRequisition(request) } + } + + val expected = fulfillDirectRequisitionResponse { state = State.FULFILLED } - verifyProtoArgument(internalRequisitionMock, RequisitionsCoroutineImplBase::fulfillRequisition) - .comparingExpectedFieldsOnly() - .isEqualTo( - internalFulfillRequisitionRequest { - externalRequisitionId = EXTERNAL_REQUISITION_ID - nonce = NONCE - directParams = directRequisitionParams { - externalDataProviderId = EXTERNAL_DATA_PROVIDER_ID - encryptedData = REQUISITION_ENCRYPTED_DATA + verifyProtoArgument( + internalRequisitionMock, + RequisitionsCoroutineImplBase::fulfillRequisition + ) + .comparingExpectedFieldsOnly() + .isEqualTo( + internalFulfillRequisitionRequest { + externalRequisitionId = EXTERNAL_REQUISITION_ID + nonce = NONCE + directParams = directRequisitionParams { + externalDataProviderId = EXTERNAL_DATA_PROVIDER_ID + encryptedData = REQUISITION_ENCRYPTED_DATA + } } - } - ) + ) - assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) - } + assertThat(result).ignoringRepeatedFieldOrder().isEqualTo(expected) + } @Test fun `fulfillDirectRequisition throw INVALID_ARGUMENT when name is unspecified`() = runBlocking { @@ -737,6 +851,7 @@ class RequisitionsServiceTest { } companion object { + private const val MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION = 10 private val MEASUREMENT_SPEC = measurementSpec { measurementPublicKey = UPDATE_TIME.toByteString() reachAndFrequency = @@ -754,6 +869,34 @@ class RequisitionsServiceTest { nonceHashes += ByteString.copyFromUtf8("foo") } + private val DIRECT_RF_PROTOCOL_CONFIG = + ProtocolConfigKt.direct { + noiseMechanisms += ProtocolConfig.NoiseMechanism.CONTINUOUS_LAPLACE + noiseMechanisms += ProtocolConfig.NoiseMechanism.CONTINUOUS_GAUSSIAN + deterministicCountDistinct = ProtocolConfigKt.DirectKt.deterministicCountDistinct {} + liquidLegionsCountDistinct = ProtocolConfigKt.DirectKt.liquidLegionsCountDistinct {} + deterministicDistribution = + ProtocolConfig.Direct.DeterministicDistribution.getDefaultInstance() + liquidLegionsDistribution = + ProtocolConfig.Direct.LiquidLegionsDistribution.getDefaultInstance() + } + + private val INTERNAL_DIRECT_RF_PROTOCOL_CONFIG = + InternalProtocolConfigKt.direct { + noiseMechanisms += InternalProtocolConfig.NoiseMechanism.CONTINUOUS_LAPLACE + noiseMechanisms += InternalProtocolConfig.NoiseMechanism.CONTINUOUS_GAUSSIAN + deterministicCountDistinct = InternalProtocolConfigKt.DirectKt.deterministicCountDistinct {} + liquidLegionsCountDistinct = InternalProtocolConfigKt.DirectKt.liquidLegionsCountDistinct {} + deterministicDistribution = + InternalProtocolConfigKt.DirectKt.deterministicDistribution { + maximumFrequency = MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION + } + liquidLegionsDistribution = + InternalProtocolConfigKt.DirectKt.liquidLegionsDistribution { + maximumFrequency = MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION + } + } + private val INTERNAL_REQUISITION: InternalRequisition = internalRequisition { externalMeasurementConsumerId = EXTERNAL_MEASUREMENT_CONSUMER_ID externalMeasurementId = EXTERNAL_MEASUREMENT_ID @@ -815,7 +958,8 @@ class RequisitionsServiceTest { } protocolConfig = protocolConfig { measurementType = ProtocolConfig.MeasurementType.REACH_AND_FREQUENCY - protocols += ProtocolConfigKt.protocol { direct = ProtocolConfigKt.direct {} } + protocols += + ProtocolConfigKt.protocol { direct = DEFAULT_DIRECT_REACH_AND_FREQUENCY_PROTOCOL_CONFIG } } dataProviderCertificate = DataProviderCertificateKey( diff --git a/src/test/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulatorTest.kt b/src/test/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulatorTest.kt index 14821445245..01ce1ae1165 100644 --- a/src/test/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulatorTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/loadtest/dataprovider/EdpSimulatorTest.kt @@ -69,7 +69,6 @@ import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reach import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.reachAndFrequency import org.wfanet.measurement.api.v2alpha.MeasurementSpecKt.vidSamplingInterval import org.wfanet.measurement.api.v2alpha.ProtocolConfig -import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt import org.wfanet.measurement.api.v2alpha.RefuseRequisitionRequest import org.wfanet.measurement.api.v2alpha.Requisition @@ -145,7 +144,6 @@ import org.wfanet.measurement.consent.client.measurementconsumer.encryptRequisit import org.wfanet.measurement.consent.client.measurementconsumer.signEncryptionPublicKey import org.wfanet.measurement.consent.client.measurementconsumer.signMeasurementSpec import org.wfanet.measurement.consent.client.measurementconsumer.signRequisitionSpec -import org.wfanet.measurement.eventdataprovider.noiser.DirectNoiseMechanism import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.AcdpCharge import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.AgeGroup as PrivacyLandscapeAge import org.wfanet.measurement.eventdataprovider.privacybudgetmanagement.CompositionMechanism @@ -175,7 +173,7 @@ private const val EDP_NAME = "dataProviders/someDataProvider" private const val LLV2_DECAY_RATE = 12.0 private const val LLV2_MAX_SIZE = 100_000L -private val NOISE_MECHANISM = NoiseMechanism.GEOMETRIC +private val NOISE_MECHANISM = ProtocolConfig.NoiseMechanism.GEOMETRIC private val MEASUREMENT_CONSUMER_CERTIFICATE_DER = SECRET_FILES_PATH.resolve("mc_cs_cert.der").toFile().readByteString() @@ -200,7 +198,6 @@ private val TIME_RANGE = OpenEndTimeRange.fromClosedDateRange(FIRST_EVENT_DATE.. private const val DUCHY_ID = "worker1" private const val RANDOM_SEED: Long = 0 -private val DIRECT_NOISE_MECHANISM = DirectNoiseMechanism.LAPLACE private val COMPOSITION_MECHANISM = CompositionMechanism.DP_ADVANCED @RunWith(JUnit4::class) @@ -321,7 +318,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, compositionMechanism = COMPOSITION_MECHANISM ) @@ -388,7 +384,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, compositionMechanism = COMPOSITION_MECHANISM ) @@ -439,7 +434,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, compositionMechanism = COMPOSITION_MECHANISM ) @@ -502,7 +496,6 @@ class EdpSimulatorTest { MinimumIntervalThrottler(Clock.systemUTC(), Duration.ofMillis(1000)), privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, compositionMechanism = COMPOSITION_MECHANISM ) @@ -582,7 +575,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, compositionMechanism = COMPOSITION_MECHANISM ) runBlocking { @@ -664,7 +656,7 @@ class EdpSimulatorTest { ProtocolConfigKt.protocol { liquidLegionsV2 = ProtocolConfigKt.liquidLegionsV2 { - noiseMechanism = NoiseMechanism.DISCRETE_GAUSSIAN + noiseMechanism = ProtocolConfig.NoiseMechanism.DISCRETE_GAUSSIAN sketchParams = liquidLegionsSketchParams { decayRate = LLV2_DECAY_RATE maxSize = LLV2_MAX_SIZE @@ -731,7 +723,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, compositionMechanism = CompositionMechanism.ACDP ) runBlocking { @@ -802,9 +793,26 @@ class EdpSimulatorTest { width = PRIVACY_BUCKET_VID_SAMPLE_WIDTH } } + val noiseMechanismOption = ProtocolConfig.NoiseMechanism.CONTINUOUS_GAUSSIAN + val requisition = - DIRECT_REQUISITION.copy { + REQUISITION.copy { this.measurementSpec = signMeasurementSpec(measurementSpec, MC_SIGNING_KEY) + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { + direct = + ProtocolConfigKt.direct { + noiseMechanisms += noiseMechanismOption + deterministicCountDistinct = + ProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + deterministicDistribution = + ProtocolConfig.Direct.DeterministicDistribution.getDefaultInstance() + } + } + } } requisitionsServiceMock.stub { @@ -861,7 +869,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DirectNoiseMechanism.GAUSSIAN, compositionMechanism = CompositionMechanism.ACDP ) runBlocking { @@ -968,7 +975,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, sketchEncrypter = fakeSketchEncrypter, compositionMechanism = COMPOSITION_MECHANISM ) @@ -1026,7 +1032,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, compositionMechanism = COMPOSITION_MECHANISM ) val requisition = @@ -1086,7 +1091,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, compositionMechanism = COMPOSITION_MECHANISM ) eventGroupsServiceMock.stub { @@ -1126,7 +1130,7 @@ class EdpSimulatorTest { ProtocolConfigKt.protocol { liquidLegionsV2 = ProtocolConfigKt.liquidLegionsV2 { - noiseMechanism = NoiseMechanism.GEOMETRIC + noiseMechanism = ProtocolConfig.NoiseMechanism.GEOMETRIC sketchParams = liquidLegionsSketchParams { decayRate = LLV2_DECAY_RATE maxSize = LLV2_MAX_SIZE @@ -1158,7 +1162,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, compositionMechanism = CompositionMechanism.ACDP ) @@ -1178,15 +1181,33 @@ class EdpSimulatorTest { refusal = refusal { justification = Refusal.Justification.SPEC_INVALID } } ) - assertThat(refuseRequest.refusal.message).contains(NoiseMechanism.GEOMETRIC.toString()) + assertThat(refuseRequest.refusal.message) + .contains(ProtocolConfig.NoiseMechanism.GEOMETRIC.toString()) assertThat(fakeRequisitionFulfillmentService.fullfillRequisitionInvocations).isEmpty() verifyBlocking(requisitionsServiceMock, never()) { fulfillDirectRequisition(any()) } } @Test - fun `refuses Requisition when directNoiseMechanism is LAPLACE and compositionMechanism is ACDP`() { - val requisition = DIRECT_REQUISITION - + fun `refuses Requisition when directNoiseMechanism option provided by Kingdom is not CONTINUOUS_GAUSSIAN and compositionMechanism is ACDP`() { + val noiseMechanismOption = ProtocolConfig.NoiseMechanism.CONTINUOUS_LAPLACE + val requisition = + REQUISITION.copy { + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { + direct = + ProtocolConfigKt.direct { + noiseMechanisms += noiseMechanismOption + deterministicCountDistinct = + ProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + deterministicDistribution = + ProtocolConfig.Direct.DeterministicDistribution.getDefaultInstance() + } + } + } + } requisitionsServiceMock.stub { onBlocking { listRequisitions(any()) } .thenReturn(listRequisitionsResponse { requisitions += requisition }) @@ -1206,7 +1227,7 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DirectNoiseMechanism.LAPLACE, + random = Random(RANDOM_SEED), compositionMechanism = CompositionMechanism.ACDP ) @@ -1226,14 +1247,32 @@ class EdpSimulatorTest { refusal = refusal { justification = Refusal.Justification.SPEC_INVALID } } ) - assertThat(refuseRequest.refusal.message).contains(DirectNoiseMechanism.LAPLACE.toString()) + assertThat(refuseRequest.refusal.message).contains("No valid noise mechanism option") assertThat(fakeRequisitionFulfillmentService.fullfillRequisitionInvocations).isEmpty() verifyBlocking(requisitionsServiceMock, never()) { fulfillDirectRequisition(any()) } } @Test fun `fulfills direct reach and frequency Requisition`() { - val requisition = DIRECT_REQUISITION + val noiseMechanismOption = ProtocolConfig.NoiseMechanism.CONTINUOUS_LAPLACE + val requisition = + REQUISITION.copy { + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { + direct = + ProtocolConfigKt.direct { + noiseMechanisms += noiseMechanismOption + deterministicCountDistinct = + ProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + deterministicDistribution = + ProtocolConfig.Direct.DeterministicDistribution.getDefaultInstance() + } + } + } + } requisitionsServiceMock.stub { onBlocking { listRequisitions(any()) } .thenReturn(listRequisitionsResponse { requisitions += requisition }) @@ -1252,7 +1291,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, random = Random(RANDOM_SEED), compositionMechanism = COMPOSITION_MECHANISM ) @@ -1266,6 +1304,10 @@ class EdpSimulatorTest { ) val result = Measurement.Result.parseFrom(decryptResult(request.encryptedData, MC_PRIVATE_KEY).data) + assertThat(result.reach.noiseMechanism == noiseMechanismOption) + assertThat(result.reach.hasDeterministicCountDistinct()) + assertThat(result.frequency.noiseMechanism == noiseMechanismOption) + assertThat(result.frequency.hasDeterministicDistribution()) assertThat(result).reachValue().isEqualTo(2000L) assertThat(result).frequencyDistribution().isWithin(0.001).of(mapOf(2L to 0.5, 4L to 0.5)) } @@ -1274,9 +1316,25 @@ class EdpSimulatorTest { fun `fulfills direct reach and frequency Requisition with sampling rate less than 1`() { val measurementSpec = MEASUREMENT_SPEC.copy { vidSamplingInterval = vidSamplingInterval.copy { width = 0.1f } } + val noiseMechanismOption = ProtocolConfig.NoiseMechanism.CONTINUOUS_LAPLACE val requisition = - DIRECT_REQUISITION.copy { + REQUISITION.copy { this.measurementSpec = signMeasurementSpec(measurementSpec, MC_SIGNING_KEY) + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { + direct = + ProtocolConfigKt.direct { + noiseMechanisms += noiseMechanismOption + deterministicCountDistinct = + ProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + deterministicDistribution = + ProtocolConfig.Direct.DeterministicDistribution.getDefaultInstance() + } + } + } } requisitionsServiceMock.stub { onBlocking { listRequisitions(any()) } @@ -1296,7 +1354,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, random = Random(RANDOM_SEED), compositionMechanism = COMPOSITION_MECHANISM ) @@ -1310,6 +1367,11 @@ class EdpSimulatorTest { ) val result = Measurement.Result.parseFrom(decryptResult(request.encryptedData, MC_PRIVATE_KEY).data) + + assertThat(result.reach.noiseMechanism == noiseMechanismOption) + assertThat(result.reach.hasDeterministicCountDistinct()) + assertThat(result.frequency.noiseMechanism == noiseMechanismOption) + assertThat(result.frequency.hasDeterministicDistribution()) assertThat(result).reachValue().isEqualTo(1920) assertThat(result) .frequencyDistribution() @@ -1317,12 +1379,147 @@ class EdpSimulatorTest { .of(mapOf(2L to 0.5010227687681921, 4L to 0.5072032690534161)) } + @Test + fun `fails to fulfill direct reach and frequency Requisition when no direct noise mechanism is picked by EDP`() { + val noiseMechanismOption = ProtocolConfig.NoiseMechanism.NONE + val requisition = + REQUISITION.copy { + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { + direct = + ProtocolConfigKt.direct { + noiseMechanisms += noiseMechanismOption + deterministicCountDistinct = + ProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + deterministicDistribution = + ProtocolConfig.Direct.DeterministicDistribution.getDefaultInstance() + } + } + } + } + requisitionsServiceMock.stub { + onBlocking { listRequisitions(any()) } + .thenReturn(listRequisitionsResponse { requisitions += requisition }) + } + val simulator = + EdpSimulator( + EDP_DATA, + MC_NAME, + measurementConsumersStub, + certificatesStub, + eventGroupsStub, + eventGroupMetadataDescriptorsStub, + requisitionsStub, + requisitionFulfillmentStub, + syntheticGeneratorEventQuery, + dummyThrottler, + privacyBudgetManager, + TRUSTED_CERTIFICATES, + random = Random(RANDOM_SEED), + compositionMechanism = COMPOSITION_MECHANISM + ) + + runBlocking { simulator.executeRequisitionFulfillingWorkflow() } + + val refuseRequest: RefuseRequisitionRequest = + verifyAndCapture(requisitionsServiceMock, RequisitionsCoroutineImplBase::refuseRequisition) + assertThat(refuseRequest) + .ignoringFieldScope( + FieldScopes.allowingFieldDescriptors( + Refusal.getDescriptor().findFieldByNumber(Refusal.MESSAGE_FIELD_NUMBER) + ) + ) + .isEqualTo( + refuseRequisitionRequest { + name = REQUISITION.name + refusal = refusal { justification = Refusal.Justification.SPEC_INVALID } + } + ) + assertThat(refuseRequest.refusal.message).contains("No valid noise mechanism option") + assertThat(fakeRequisitionFulfillmentService.fullfillRequisitionInvocations).isEmpty() + verifyBlocking(requisitionsServiceMock, never()) { fulfillDirectRequisition(any()) } + } + + @Test + fun `fails to fulfill direct reach and frequency Requisition when no direct methodology is picked by EDP`() { + val noiseMechanismOption = ProtocolConfig.NoiseMechanism.CONTINUOUS_LAPLACE + val requisition = + REQUISITION.copy { + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { + direct = ProtocolConfigKt.direct { noiseMechanisms += noiseMechanismOption } + } + } + } + requisitionsServiceMock.stub { + onBlocking { listRequisitions(any()) } + .thenReturn(listRequisitionsResponse { requisitions += requisition }) + } + val simulator = + EdpSimulator( + EDP_DATA, + MC_NAME, + measurementConsumersStub, + certificatesStub, + eventGroupsStub, + eventGroupMetadataDescriptorsStub, + requisitionsStub, + requisitionFulfillmentStub, + syntheticGeneratorEventQuery, + dummyThrottler, + privacyBudgetManager, + TRUSTED_CERTIFICATES, + random = Random(RANDOM_SEED), + compositionMechanism = COMPOSITION_MECHANISM + ) + + runBlocking { simulator.executeRequisitionFulfillingWorkflow() } + + val refuseRequest: RefuseRequisitionRequest = + verifyAndCapture(requisitionsServiceMock, RequisitionsCoroutineImplBase::refuseRequisition) + assertThat(refuseRequest) + .ignoringFieldScope( + FieldScopes.allowingFieldDescriptors( + Refusal.getDescriptor().findFieldByNumber(Refusal.MESSAGE_FIELD_NUMBER) + ) + ) + .isEqualTo( + refuseRequisitionRequest { + name = REQUISITION.name + refusal = refusal { justification = Refusal.Justification.DECLINED } + } + ) + assertThat(refuseRequest.refusal.message).contains("No valid methodologies") + assertThat(fakeRequisitionFulfillmentService.fullfillRequisitionInvocations).isEmpty() + verifyBlocking(requisitionsServiceMock, never()) { fulfillDirectRequisition(any()) } + } + @Test fun `fulfills direct reach-only Requisition`() { val measurementSpec = REACH_ONLY_MEASUREMENT_SPEC + val noiseMechanismOption = ProtocolConfig.NoiseMechanism.CONTINUOUS_LAPLACE val requisition = - DIRECT_REQUISITION.copy { + REQUISITION.copy { this.measurementSpec = signMeasurementSpec(measurementSpec, MC_SIGNING_KEY) + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { + direct = + ProtocolConfigKt.direct { + noiseMechanisms += noiseMechanismOption + deterministicCountDistinct = + ProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + } + } + } } requisitionsServiceMock.stub { onBlocking { listRequisitions(any()) } @@ -1342,7 +1539,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, random = Random(RANDOM_SEED), compositionMechanism = COMPOSITION_MECHANISM ) @@ -1356,6 +1552,9 @@ class EdpSimulatorTest { ) val result = Measurement.Result.parseFrom(decryptResult(request.encryptedData, MC_PRIVATE_KEY).data) + + assertThat(result.reach.noiseMechanism == noiseMechanismOption) + assertThat(result.reach.hasDeterministicCountDistinct()) assertThat(result).reachValue().isEqualTo(2000L) assertThat(result.hasFrequency()).isFalse() } @@ -1366,9 +1565,23 @@ class EdpSimulatorTest { REACH_ONLY_MEASUREMENT_SPEC.copy { vidSamplingInterval = vidSamplingInterval.copy { width = 0.1f } } + val noiseMechanismOption = ProtocolConfig.NoiseMechanism.CONTINUOUS_LAPLACE val requisition = - DIRECT_REQUISITION.copy { + REQUISITION.copy { this.measurementSpec = signMeasurementSpec(measurementSpec, MC_SIGNING_KEY) + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { + direct = + ProtocolConfigKt.direct { + noiseMechanisms += noiseMechanismOption + deterministicCountDistinct = + ProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + } + } + } } requisitionsServiceMock.stub { onBlocking { listRequisitions(any()) } @@ -1388,7 +1601,6 @@ class EdpSimulatorTest { dummyThrottler, privacyBudgetManager, TRUSTED_CERTIFICATES, - DIRECT_NOISE_MECHANISM, random = Random(RANDOM_SEED), compositionMechanism = COMPOSITION_MECHANISM ) @@ -1402,10 +1614,142 @@ class EdpSimulatorTest { ) val result = Measurement.Result.parseFrom(decryptResult(request.encryptedData, MC_PRIVATE_KEY).data) + + assertThat(result.reach.noiseMechanism == noiseMechanismOption) + assertThat(result.reach.hasDeterministicCountDistinct()) assertThat(result).reachValue().isEqualTo(1920) assertThat(result.hasFrequency()).isFalse() } + @Test + fun `fails to fulfill direct reach-only Requisition when no direct noise mechanism is picked by EDP`() { + val measurementSpec = + REACH_ONLY_MEASUREMENT_SPEC.copy { + vidSamplingInterval = vidSamplingInterval.copy { width = 0.1f } + } + val noiseMechanismOption = ProtocolConfig.NoiseMechanism.NONE + val requisition = + REQUISITION.copy { + this.measurementSpec = signMeasurementSpec(measurementSpec, MC_SIGNING_KEY) + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { + direct = + ProtocolConfigKt.direct { + noiseMechanisms += noiseMechanismOption + deterministicCountDistinct = + ProtocolConfig.Direct.DeterministicCountDistinct.getDefaultInstance() + } + } + } + } + requisitionsServiceMock.stub { + onBlocking { listRequisitions(any()) } + .thenReturn(listRequisitionsResponse { requisitions += requisition }) + } + val simulator = + EdpSimulator( + EDP_DATA, + MC_NAME, + measurementConsumersStub, + certificatesStub, + eventGroupsStub, + eventGroupMetadataDescriptorsStub, + requisitionsStub, + requisitionFulfillmentStub, + syntheticGeneratorEventQuery, + dummyThrottler, + privacyBudgetManager, + TRUSTED_CERTIFICATES, + random = Random(RANDOM_SEED), + compositionMechanism = COMPOSITION_MECHANISM + ) + + runBlocking { simulator.executeRequisitionFulfillingWorkflow() } + + val refuseRequest: RefuseRequisitionRequest = + verifyAndCapture(requisitionsServiceMock, RequisitionsCoroutineImplBase::refuseRequisition) + assertThat(refuseRequest) + .ignoringFieldScope( + FieldScopes.allowingFieldDescriptors( + Refusal.getDescriptor().findFieldByNumber(Refusal.MESSAGE_FIELD_NUMBER) + ) + ) + .isEqualTo( + refuseRequisitionRequest { + name = REQUISITION.name + refusal = refusal { justification = Refusal.Justification.SPEC_INVALID } + } + ) + assertThat(refuseRequest.refusal.message).contains("No valid noise mechanism option") + assertThat(fakeRequisitionFulfillmentService.fullfillRequisitionInvocations).isEmpty() + verifyBlocking(requisitionsServiceMock, never()) { fulfillDirectRequisition(any()) } + } + + @Test + fun `fails to fulfill direct reach-only Requisition when no direct methodology is picked by EDP`() { + val measurementSpec = + REACH_ONLY_MEASUREMENT_SPEC.copy { + vidSamplingInterval = vidSamplingInterval.copy { width = 0.1f } + } + val noiseMechanismOption = ProtocolConfig.NoiseMechanism.CONTINUOUS_LAPLACE + val requisition = + REQUISITION.copy { + this.measurementSpec = signMeasurementSpec(measurementSpec, MC_SIGNING_KEY) + protocolConfig = + protocolConfig.copy { + protocols.clear() + protocols += + ProtocolConfigKt.protocol { + direct = ProtocolConfigKt.direct { noiseMechanisms += noiseMechanismOption } + } + } + } + requisitionsServiceMock.stub { + onBlocking { listRequisitions(any()) } + .thenReturn(listRequisitionsResponse { requisitions += requisition }) + } + val simulator = + EdpSimulator( + EDP_DATA, + MC_NAME, + measurementConsumersStub, + certificatesStub, + eventGroupsStub, + eventGroupMetadataDescriptorsStub, + requisitionsStub, + requisitionFulfillmentStub, + syntheticGeneratorEventQuery, + dummyThrottler, + privacyBudgetManager, + TRUSTED_CERTIFICATES, + random = Random(RANDOM_SEED), + compositionMechanism = COMPOSITION_MECHANISM + ) + + runBlocking { simulator.executeRequisitionFulfillingWorkflow() } + + val refuseRequest: RefuseRequisitionRequest = + verifyAndCapture(requisitionsServiceMock, RequisitionsCoroutineImplBase::refuseRequisition) + assertThat(refuseRequest) + .ignoringFieldScope( + FieldScopes.allowingFieldDescriptors( + Refusal.getDescriptor().findFieldByNumber(Refusal.MESSAGE_FIELD_NUMBER) + ) + ) + .isEqualTo( + refuseRequisitionRequest { + name = REQUISITION.name + refusal = refusal { justification = Refusal.Justification.DECLINED } + } + ) + assertThat(refuseRequest.refusal.message).contains("No valid methodologies") + assertThat(fakeRequisitionFulfillmentService.fullfillRequisitionInvocations).isEmpty() + verifyBlocking(requisitionsServiceMock, never()) { fulfillDirectRequisition(any()) } + } + private class FakeRequisitionFulfillmentService : RequisitionFulfillmentCoroutineImplBase() { data class FulfillRequisitionInvocation(val requests: List) @@ -1540,15 +1884,6 @@ class EdpSimulatorTest { } } } - private val DIRECT_REQUISITION = - REQUISITION.copy { - protocolConfig = - protocolConfig.copy { - protocols.clear() - protocols += - ProtocolConfigKt.protocol { direct = ProtocolConfig.Direct.getDefaultInstance() } - } - } private val TRUSTED_CERTIFICATES: Map = readCertificateCollection(SECRET_FILES_PATH.resolve("edp_trusted_certs.pem").toFile())