Skip to content

Commit

Permalink
Circulate direct computation methodology from protocol config to meas…
Browse files Browse the repository at this point in the history
…urement result.
  • Loading branch information
riemanli committed Aug 9, 2023
1 parent 764a02f commit 4e9e76d
Show file tree
Hide file tree
Showing 22 changed files with 634 additions and 48 deletions.
5 changes: 3 additions & 2 deletions build/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ def wfa_measurement_system_repositories():
version = "0.10.0",
)

# DO_NOT_SUBMIT: switch to release version before submitting the PR.
wfa_repo_archive(
name = "wfa_measurement_proto",
repo = "cross-media-measurement-api",
sha256 = "3ccf5e4e81f2b0cd9abfc0fe9945096e6ff1c18577a9d9f67ea60470c64c3ec3",
version = "0.39.1",
sha256 = "808a4ee3056ccd2804c2a310f88d882855aa6f431a51b546d83bb7c3cc8e1f0e",
commit = "fde2e7bc5bfc855e0f27a259b1540e3150b774a0",
)

wfa_repo_archive(
Expand Down
3 changes: 3 additions & 0 deletions src/main/k8s/kingdom.cue
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ import ("strings")

_open_id_redirect_uri_flag: "--open-id-redirect-uri=https://localhost:2048"

_directNoiseMechanismFlag: "--direct-noise-mechanism=NONE,CONTINUOUS_LAPLACE,CONTINUOUS_GAUSSIAN"

_kingdomCompletedMeasurementsTimeToLiveFlag: "--time-to-live=\(_completedMeasurementsTimeToLive)"
_kingdomCompletedMeasurementsDryRunRetentionPolicyFlag: "--dry-run=\(_completedMeasurementsDryRun)"
_kingdomPendingMeasurementsTimeToLiveFlag: "--time-to-live=\(_pendingMeasurementsTimeToLive)"
Expand Down Expand Up @@ -168,6 +170,7 @@ import ("strings")
_akid_to_principal_map_file_flag,
_open_id_redirect_uri_flag,
_duchy_info_config_flag,
_directNoiseMechanismFlag,
] + Container._commonServerFlags
}
spec: template: spec: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ kt_jvm_library(
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/service:data_services",
"//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:api_key_authentication_server_interceptor",
"//src/main/kotlin/org/wfanet/measurement/loadtest/panelmatchresourcesetup",
"//src/main/proto/wfa/measurement/api/v2alpha:protocol_config_kt_jvm_proto",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.util.logging.Logger
import org.junit.rules.TestRule
import org.junit.runner.Description
import org.junit.runners.model.Statement
import org.wfanet.measurement.api.v2alpha.ProtocolConfig
import org.wfanet.measurement.api.v2alpha.testing.withMetadataPrincipalIdentities
import org.wfanet.measurement.common.grpc.testing.GrpcTestServerRule
import org.wfanet.measurement.common.grpc.withDefaultDeadline
Expand Down Expand Up @@ -150,7 +151,7 @@ class InProcessKingdom(
EventGroupMetadataDescriptorsService(internalEventGroupMetadataDescriptorsClient)
.withMetadataPrincipalIdentities()
.withApiKeyAuthenticationServerInterceptor(internalApiKeysClient),
MeasurementsService(internalMeasurementsClient)
MeasurementsService(internalMeasurementsClient, MEASUREMENT_NOISE_MECHANISMS)
.withMetadataPrincipalIdentities()
.withApiKeyAuthenticationServerInterceptor(internalApiKeysClient),
PublicKeysService(internalPublicKeysClient)
Expand Down Expand Up @@ -206,5 +207,11 @@ class InProcessKingdom(

/** Default deadline for RPCs to internal server in milliseconds. */
private const val DEFAULT_INTERNAL_DEADLINE_MILLIS = 30_000L
private val MEASUREMENT_NOISE_MECHANISMS: List<ProtocolConfig.NoiseMechanism> =
listOf(
ProtocolConfig.NoiseMechanism.NONE,
ProtocolConfig.NoiseMechanism.CONTINUOUS_LAPLACE,
ProtocolConfig.NoiseMechanism.CONTINUOUS_GAUSSIAN,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package org.wfanet.measurement.kingdom.deploy.common.server
import io.grpc.ServerServiceDefinition
import java.io.File
import org.wfanet.measurement.api.v2alpha.AkidPrincipalLookup
import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism
import org.wfanet.measurement.api.v2alpha.withPrincipalsFromX509AuthorityKeyIdentifiers
import org.wfanet.measurement.common.commandLineMain
import org.wfanet.measurement.common.crypto.SigningCerts
Expand Down Expand Up @@ -143,6 +144,7 @@ private fun run(
.withApiKeyAuthenticationServerInterceptor(internalApiKeysCoroutineStub),
MeasurementsService(
InternalMeasurementsCoroutineStub(channel),
v2alphaFlags.directNoiseMechanisms
)
.withPrincipalsFromX509AuthorityKeyIdentifiers(principalLookup)
.withApiKeyAuthenticationServerInterceptor(internalApiKeysCoroutineStub),
Expand Down Expand Up @@ -192,6 +194,7 @@ fun main(args: Array<String>) = commandLineMain(::run, args)

/** Flags specific to the V2alpha API version. */
private class V2alphaFlags {

@CommandLine.Option(
names = ["--authority-key-identifier-to-principal-map-file"],
description = ["File path to a AuthorityKeyToPrincipalMap textproto"],
Expand All @@ -207,4 +210,45 @@ private class V2alphaFlags {
)
lateinit var redirectUri: String
private set

val directNoiseMechanisms: MutableList<NoiseMechanism> = mutableListOf()

@CommandLine.Spec lateinit var spec: CommandLine.Model.CommandSpec // injected by picocli

@CommandLine.Option(
names = ["--direct-noise-mechanisms"],
split = ",",
description =
["Noise mechanisms that can be used in direct computation. Options are separated by `,`."],
required = true
)
fun setDirectNoiseMechanisms(noiseMechanisms: List<NoiseMechanism>) {
for (noiseMechanism in noiseMechanisms) {
when (noiseMechanism) {
NoiseMechanism.NONE,
NoiseMechanism.CONTINUOUS_LAPLACE,
NoiseMechanism.CONTINUOUS_GAUSSIAN -> {}
NoiseMechanism.GEOMETRIC,
NoiseMechanism.DISCRETE_GAUSSIAN -> {
throw CommandLine.ParameterException(
spec.commandLine(),
String.format(
"Invalid noise mechanism $noiseMechanism for option '--direct-noise-mechanism'. " +
"Discrete mechanisms are not supported for direct computations."
)
)
}
NoiseMechanism.NOISE_MECHANISM_UNSPECIFIED,
NoiseMechanism.UNRECOGNIZED -> {
throw CommandLine.ParameterException(
spec.commandLine(),
String.format(
"Invalid noise mechanism $noiseMechanism for option '--direct-noise-mechanism'."
)
)
}
}
}
directNoiseMechanisms += noiseMechanisms
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ class CreateMeasurement(private val request: CreateMeasurementRequest) :
ProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> {
createComputedMeasurement(request.measurement, measurementConsumerId)
}
ProtocolConfig.ProtocolCase.PROTOCOL_NOT_SET ->
ProtocolConfig.ProtocolCase.DIRECT ->
createDirectMeasurement(request.measurement, measurementConsumerId)
ProtocolConfig.ProtocolCase.PROTOCOL_NOT_SET -> error("Protocol is not set.")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,7 @@ class ExchangeStepsService(private val internalExchangeSteps: InternalExchangeSt
try {
it.toV2Alpha()
} catch (e: Throwable) {
failGrpc(Status.INVALID_ARGUMENT) {
e.message ?: "Failed to convert ProtocolConfig ExchangeStep"
}
failGrpc(Status.INVALID_ARGUMENT) { e.message ?: "Failed to convert ExchangeStep" }
}
}
nextPageToken = results.last().updateTime.toByteArray().base64UrlEncode()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import org.wfanet.measurement.api.v2alpha.MeasurementKey
import org.wfanet.measurement.api.v2alpha.MeasurementPrincipal
import org.wfanet.measurement.api.v2alpha.MeasurementSpec
import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineImplBase
import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism
import org.wfanet.measurement.api.v2alpha.copy
import org.wfanet.measurement.api.v2alpha.listMeasurementsPageToken
import org.wfanet.measurement.api.v2alpha.listMeasurementsResponse
Expand Down Expand Up @@ -72,6 +73,7 @@ private const val MISSING_RESOURCE_NAME_ERROR = "Resource name is either unspeci

class MeasurementsService(
private val internalMeasurementsStub: MeasurementsCoroutineStub,
private val noiseMechanisms: List<NoiseMechanism>
) : MeasurementsCoroutineImplBase() {

override suspend fun getMeasurement(request: GetMeasurementRequest): Measurement {
Expand Down Expand Up @@ -167,7 +169,8 @@ class MeasurementsService(
request.measurement.toInternal(
measurementConsumerCertificateKey,
dataProvidersMap,
parsedMeasurementSpec
parsedMeasurementSpec,
noiseMechanisms.map { it.toInternal() }
)
requestId = request.requestId
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package org.wfanet.measurement.kingdom.service.api.v2alpha

import com.google.protobuf.util.Timestamps
import com.google.type.date
import com.google.type.interval
import java.time.ZoneOffset
import org.wfanet.measurement.api.Version
Expand Down Expand Up @@ -61,6 +60,7 @@ import org.wfanet.measurement.api.v2alpha.ModelSuite
import org.wfanet.measurement.api.v2alpha.ModelSuiteKey
import org.wfanet.measurement.api.v2alpha.ProtocolConfig
import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism
import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt
import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.direct
import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.liquidLegionsV2
import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.protocol
Expand Down Expand Up @@ -107,6 +107,7 @@ import org.wfanet.measurement.internal.kingdom.ModelShard as InternalModelShard
import org.wfanet.measurement.internal.kingdom.ModelSuite as InternalModelSuite
import org.wfanet.measurement.internal.kingdom.ProtocolConfig as InternalProtocolConfig
import org.wfanet.measurement.internal.kingdom.ProtocolConfig.NoiseMechanism as InternalNoiseMechanism
import org.wfanet.measurement.internal.kingdom.ProtocolConfigKt as InternalProtocolConfigKt
import org.wfanet.measurement.internal.kingdom.duchyProtocolConfig
import org.wfanet.measurement.internal.kingdom.exchangeWorkflow
import org.wfanet.measurement.internal.kingdom.measurement as internalMeasurement
Expand All @@ -120,6 +121,10 @@ import org.wfanet.measurement.internal.kingdom.protocolConfig as internalProtoco
import org.wfanet.measurement.kingdom.deploy.common.Llv2ProtocolConfig
import org.wfanet.measurement.kingdom.deploy.common.RoLlv2ProtocolConfig

// (-- TODO(world-federation-of-advertisers/cross-media-measurement-api/issues/160): this value
// won't be needed once the maximum frequency field is moved to measurement spec. --)
const val DEFAULT_MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION = 20

/** Converts an internal [InternalMeasurement.State] to a public [State]. */
fun InternalMeasurement.State.toState(): State =
when (this) {
Expand Down Expand Up @@ -180,13 +185,29 @@ fun InternalDifferentialPrivacyParams.toDifferentialPrivacyParams(): Differentia
/** Converts an internal [InternalNoiseMechanism] to a public [NoiseMechanism]. */
fun InternalNoiseMechanism.toNoiseMechanism(): NoiseMechanism {
return when (this) {
InternalNoiseMechanism.NONE -> NoiseMechanism.NONE
InternalNoiseMechanism.GEOMETRIC -> NoiseMechanism.GEOMETRIC
InternalNoiseMechanism.DISCRETE_GAUSSIAN -> NoiseMechanism.DISCRETE_GAUSSIAN
InternalNoiseMechanism.CONTINUOUS_LAPLACE -> NoiseMechanism.CONTINUOUS_LAPLACE
InternalNoiseMechanism.CONTINUOUS_GAUSSIAN -> NoiseMechanism.CONTINUOUS_GAUSSIAN
InternalNoiseMechanism.NOISE_MECHANISM_UNSPECIFIED,
InternalNoiseMechanism.UNRECOGNIZED -> error("invalid internal noise mechanism.")
}
}

/** Converts a public [NoiseMechanism] to an internal [InternalNoiseMechanism]. */
fun NoiseMechanism.toInternal(): InternalNoiseMechanism {
return when (this) {
NoiseMechanism.GEOMETRIC -> InternalNoiseMechanism.GEOMETRIC
NoiseMechanism.DISCRETE_GAUSSIAN -> InternalNoiseMechanism.DISCRETE_GAUSSIAN
NoiseMechanism.NONE -> InternalNoiseMechanism.NONE
NoiseMechanism.CONTINUOUS_LAPLACE -> InternalNoiseMechanism.CONTINUOUS_LAPLACE
NoiseMechanism.CONTINUOUS_GAUSSIAN -> InternalNoiseMechanism.CONTINUOUS_GAUSSIAN
NoiseMechanism.NOISE_MECHANISM_UNSPECIFIED,
NoiseMechanism.UNRECOGNIZED -> error("invalid internal noise mechanism.")
}
}

/** Converts an internal [InternalProtocolConfig] to a public [ProtocolConfig]. */
fun InternalProtocolConfig.toProtocolConfig(
measurementTypeCase: MeasurementSpec.MeasurementTypeCase,
Expand All @@ -211,10 +232,23 @@ fun InternalProtocolConfig.toProtocolConfig(
ProtocolConfig.MeasurementType.REACH,
ProtocolConfig.MeasurementType.REACH_AND_FREQUENCY -> {
if (dataProviderCount == 1) {
protocols += protocol { direct = direct {} }
protocols += protocol {
if (source.hasDirect()) {
direct = source.direct.toDirect()
} else {
// For backward compatibility
direct = direct {}
}
}
} else {
@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields are never null.
when (source.protocolCase) {
InternalProtocolConfig.ProtocolCase.DIRECT -> {
error(
"Direct protocol of reach computation shouldn't be used when number of data " +
"providers is greater than 1."
)
}
InternalProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2 -> {
protocols += protocol {
liquidLegionsV2 = liquidLegionsV2 {
Expand Down Expand Up @@ -288,6 +322,45 @@ fun InternalProtocolConfig.toProtocolConfig(
}
}

/**
* Converts an internal [InternalProtocolConfig.Direct] to a public [InternalProtocolConfig.Direct].
*/
private fun InternalProtocolConfig.Direct.toDirect(): ProtocolConfig.Direct {
val source = this

return direct {
noiseMechanisms +=
source.noiseMechanismsList.map { internalNoiseMechanism ->
internalNoiseMechanism.toNoiseMechanism()
}

if (source.hasDeterministicCountDistinct()) {
deterministicCountDistinct = ProtocolConfigKt.DirectKt.deterministicCountDistinct {}
}
if (source.hasDeterministicDistribution()) {
deterministicDistribution =
ProtocolConfigKt.DirectKt.deterministicDistribution {
maximumFrequency = source.deterministicDistribution.maximumFrequency
}
}
if (source.hasDeterministicCount()) {
deterministicCount = ProtocolConfigKt.DirectKt.deterministicCount {}
}
if (source.hasDeterministicSum()) {
deterministicSum = ProtocolConfigKt.DirectKt.deterministicSum {}
}
if (source.hasLiquidLegionsCountDistinct()) {
liquidLegionsCountDistinct = ProtocolConfigKt.DirectKt.liquidLegionsCountDistinct {}
}
if (source.hasLiquidLegionsDistribution()) {
liquidLegionsDistribution =
ProtocolConfigKt.DirectKt.liquidLegionsDistribution {
maximumFrequency = source.liquidLegionsDistribution.maximumFrequency
}
}
}
}

/** Converts an internal [InternalModelSuite] to a public [ModelSuite]. */
fun InternalModelSuite.toModelSuite(): ModelSuite {
val source = this
Expand Down Expand Up @@ -709,7 +782,8 @@ fun Map.Entry<Long, DataProviderValue>.toDataProviderEntry(): DataProviderEntry
fun Measurement.toInternal(
measurementConsumerCertificateKey: MeasurementConsumerCertificateKey,
dataProvidersMap: Map<Long, DataProviderValue>,
measurementSpecProto: MeasurementSpec
measurementSpecProto: MeasurementSpec,
internalNoiseMechanisms: List<InternalProtocolConfig.NoiseMechanism>
): InternalMeasurement {
val publicMeasurement = this

Expand Down Expand Up @@ -746,6 +820,17 @@ fun Measurement.toInternal(
liquidLegionsV2 = Llv2ProtocolConfig.duchyProtocolConfig
}
}
} else if (dataProvidersCount == 1) {
protocolConfig = internalProtocolConfig {
direct =
InternalProtocolConfigKt.direct {
this.noiseMechanisms += internalNoiseMechanisms
deterministicCountDistinct =
InternalProtocolConfigKt.DirectKt.deterministicCountDistinct {}
liquidLegionsCountDistinct =
InternalProtocolConfigKt.DirectKt.liquidLegionsCountDistinct {}
}
}
}
}
MeasurementSpec.MeasurementTypeCase.REACH_AND_FREQUENCY -> {
Expand All @@ -757,10 +842,45 @@ fun Measurement.toInternal(
duchyProtocolConfig = duchyProtocolConfig {
liquidLegionsV2 = Llv2ProtocolConfig.duchyProtocolConfig
}
} else if (dataProvidersCount == 1) {
protocolConfig = internalProtocolConfig {
direct =
InternalProtocolConfigKt.direct {
this.noiseMechanisms += internalNoiseMechanisms
deterministicCountDistinct =
InternalProtocolConfigKt.DirectKt.deterministicCountDistinct {}
liquidLegionsCountDistinct =
InternalProtocolConfigKt.DirectKt.liquidLegionsCountDistinct {}
deterministicDistribution =
InternalProtocolConfigKt.DirectKt.deterministicDistribution {
maximumFrequency = DEFAULT_MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION
}
liquidLegionsDistribution =
InternalProtocolConfigKt.DirectKt.liquidLegionsDistribution {
maximumFrequency = DEFAULT_MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION
}
}
}
}
}
MeasurementSpec.MeasurementTypeCase.IMPRESSION -> {
protocolConfig = internalProtocolConfig {
direct =
InternalProtocolConfigKt.direct {
this.noiseMechanisms += internalNoiseMechanisms
deterministicCount = InternalProtocolConfigKt.DirectKt.deterministicCount {}
}
}
}
MeasurementSpec.MeasurementTypeCase.DURATION -> {
protocolConfig = internalProtocolConfig {
direct =
InternalProtocolConfigKt.direct {
this.noiseMechanisms += internalNoiseMechanisms
deterministicSum = InternalProtocolConfigKt.DirectKt.deterministicSum {}
}
}
}
MeasurementSpec.MeasurementTypeCase.IMPRESSION,
MeasurementSpec.MeasurementTypeCase.DURATION, -> {}
MeasurementSpec.MeasurementTypeCase.MEASUREMENTTYPE_NOT_SET ->
error("MeasurementType not set.")
}
Expand Down
Loading

0 comments on commit 4e9e76d

Please sign in to comment.