Skip to content

Commit

Permalink
Circulate direct computation methodology from protocol config to meas…
Browse files Browse the repository at this point in the history
…urement result.
  • Loading branch information
riemanli committed Aug 5, 2023
1 parent 7298415 commit 4e84993
Show file tree
Hide file tree
Showing 18 changed files with 566 additions and 44 deletions.
4 changes: 2 additions & 2 deletions build/repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def wfa_measurement_system_repositories():
wfa_repo_archive(
name = "wfa_measurement_proto",
repo = "cross-media-measurement-api",
sha256 = "3ccf5e4e81f2b0cd9abfc0fe9945096e6ff1c18577a9d9f67ea60470c64c3ec3",
version = "0.39.1",
sha256 = "642106dd7c10b4c8820c31c3c18f54e7a5b9480adc85b9f6e58b267fd8f7a62e",
commit = "cf1af8937a764c491e1c99a79193ac677381c03a",
)

wfa_repo_archive(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ kt_jvm_library(
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/service:data_services",
"//src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha:api_key_authentication_server_interceptor",
"//src/main/kotlin/org/wfanet/measurement/loadtest/panelmatchresourcesetup",
"//src/main/proto/wfa/measurement/api/v2alpha:protocol_config_kt_jvm_proto",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc/testing",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/testing",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ import org.wfanet.measurement.kingdom.service.system.v1alpha.ComputationLogEntri
import org.wfanet.measurement.kingdom.service.system.v1alpha.ComputationParticipantsService as SystemComputationParticipantsService
import org.wfanet.measurement.kingdom.service.system.v1alpha.ComputationsService as SystemComputationsService
import org.wfanet.measurement.kingdom.service.system.v1alpha.RequisitionsService as SystemRequisitionsService
import org.wfanet.measurement.api.v2alpha.ProtocolConfig
import org.wfanet.measurement.loadtest.panelmatchresourcesetup.PanelMatchResourceSetup

/** TestRule that starts and stops all Kingdom gRPC services. */
Expand Down Expand Up @@ -150,7 +151,7 @@ class InProcessKingdom(
EventGroupMetadataDescriptorsService(internalEventGroupMetadataDescriptorsClient)
.withMetadataPrincipalIdentities()
.withApiKeyAuthenticationServerInterceptor(internalApiKeysClient),
MeasurementsService(internalMeasurementsClient)
MeasurementsService(internalMeasurementsClient, measurementNoiseMechanisms)
.withMetadataPrincipalIdentities()
.withApiKeyAuthenticationServerInterceptor(internalApiKeysClient),
PublicKeysService(internalPublicKeysClient)
Expand Down Expand Up @@ -206,5 +207,11 @@ class InProcessKingdom(

/** Default deadline for RPCs to internal server in milliseconds. */
private const val DEFAULT_INTERNAL_DEADLINE_MILLIS = 30_000L
private val measurementNoiseMechanisms: List<ProtocolConfig.NoiseMechanism> =
listOf(
ProtocolConfig.NoiseMechanism.NONE,
ProtocolConfig.NoiseMechanism.GEOMETRIC,
ProtocolConfig.NoiseMechanism.DISCRETE_GAUSSIAN,
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package org.wfanet.measurement.kingdom.deploy.common.server
import io.grpc.ServerServiceDefinition
import java.io.File
import org.wfanet.measurement.api.v2alpha.AkidPrincipalLookup
import org.wfanet.measurement.api.v2alpha.ProtocolConfig
import org.wfanet.measurement.api.v2alpha.withPrincipalsFromX509AuthorityKeyIdentifiers
import org.wfanet.measurement.common.commandLineMain
import org.wfanet.measurement.common.crypto.SigningCerts
Expand Down Expand Up @@ -109,6 +110,19 @@ private fun run(
.withDefaultDeadline(kingdomApiServerFlags.internalApiFlags.defaultDeadlineDuration)

val principalLookup = AkidPrincipalLookup(v2alphaFlags.authorityKeyIdentifierToPrincipalMapFile)
val noiseMechanisms = mutableListOf<ProtocolConfig.NoiseMechanism>()
if (v2alphaFlags.directNoiseMechanismInput.noNoise) {
noiseMechanisms += ProtocolConfig.NoiseMechanism.NONE
}
if (v2alphaFlags.directNoiseMechanismInput.geometryNoise) {
noiseMechanisms += ProtocolConfig.NoiseMechanism.GEOMETRIC
}
if (v2alphaFlags.directNoiseMechanismInput.discreteGaussianNoise) {
noiseMechanisms += ProtocolConfig.NoiseMechanism.DISCRETE_GAUSSIAN
}
if (noiseMechanisms.size == 0) {
error("No noise mechanism is selected.")
}

val internalAccountsCoroutineStub = InternalAccountsCoroutineStub(channel)
val internalApiKeysCoroutineStub = InternalApiKeysCoroutineStub(channel)
Expand Down Expand Up @@ -141,9 +155,7 @@ private fun run(
)
.withPrincipalsFromX509AuthorityKeyIdentifiers(principalLookup)
.withApiKeyAuthenticationServerInterceptor(internalApiKeysCoroutineStub),
MeasurementsService(
InternalMeasurementsCoroutineStub(channel),
)
MeasurementsService(InternalMeasurementsCoroutineStub(channel), noiseMechanisms)
.withPrincipalsFromX509AuthorityKeyIdentifiers(principalLookup)
.withApiKeyAuthenticationServerInterceptor(internalApiKeysCoroutineStub),
MeasurementConsumersService(InternalMeasurementConsumersCoroutineStub(channel))
Expand Down Expand Up @@ -192,6 +204,32 @@ fun main(args: Array<String>) = commandLineMain(::run, args)

/** Flags specific to the V2alpha API version. */
private class V2alphaFlags {
class DirectNoiseMechanismInput {
@CommandLine.Option(
names = ["--none"],
description = ["Allow no noise added to the result of direct computation."],
required = false
)
var noNoise = false
private set

@CommandLine.Option(
names = ["--geometry"],
description = ["Allow geometry (Laplace) noise added to the result of direct computation."],
required = false
)
var geometryNoise = false
private set

@CommandLine.Option(
names = ["--discrete-gaussian"],
description = ["Allow discrete Gaussian noise added to the result of direct computation."],
required = false
)
var discreteGaussianNoise = false
private set
}

@CommandLine.Option(
names = ["--authority-key-identifier-to-principal-map-file"],
description = ["File path to a AuthorityKeyToPrincipalMap textproto"],
Expand All @@ -207,4 +245,8 @@ private class V2alphaFlags {
)
lateinit var redirectUri: String
private set

@CommandLine.ArgGroup(exclusive = true, multiplicity = "1", heading = "Direct noise mechanisms\n")
lateinit var directNoiseMechanismInput: DirectNoiseMechanismInput
private set
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ class CreateMeasurement(private val request: CreateMeasurementRequest) :
ProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> {
createComputedMeasurement(request.measurement, measurementConsumerId)
}
ProtocolConfig.ProtocolCase.PROTOCOL_NOT_SET ->
ProtocolConfig.ProtocolCase.DIRECT ->
createDirectMeasurement(request.measurement, measurementConsumerId)
ProtocolConfig.ProtocolCase.PROTOCOL_NOT_SET -> error("Protocol is not set.")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class ExchangeStepsService(private val internalExchangeSteps: InternalExchangeSt
it.toV2Alpha()
} catch (e: Throwable) {
failGrpc(Status.INVALID_ARGUMENT) {
e.message ?: "Failed to convert ProtocolConfig ExchangeStep"
e.message ?: "Failed to convert ExchangeStep"
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import org.wfanet.measurement.api.v2alpha.MeasurementKey
import org.wfanet.measurement.api.v2alpha.MeasurementPrincipal
import org.wfanet.measurement.api.v2alpha.MeasurementSpec
import org.wfanet.measurement.api.v2alpha.MeasurementsGrpcKt.MeasurementsCoroutineImplBase
import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism
import org.wfanet.measurement.api.v2alpha.copy
import org.wfanet.measurement.api.v2alpha.listMeasurementsPageToken
import org.wfanet.measurement.api.v2alpha.listMeasurementsResponse
Expand Down Expand Up @@ -72,6 +73,7 @@ private const val MISSING_RESOURCE_NAME_ERROR = "Resource name is either unspeci

class MeasurementsService(
private val internalMeasurementsStub: MeasurementsCoroutineStub,
private val noiseMechanisms: List<NoiseMechanism>
) : MeasurementsCoroutineImplBase() {

override suspend fun getMeasurement(request: GetMeasurementRequest): Measurement {
Expand Down Expand Up @@ -167,7 +169,8 @@ class MeasurementsService(
request.measurement.toInternal(
measurementConsumerCertificateKey,
dataProvidersMap,
parsedMeasurementSpec
parsedMeasurementSpec,
noiseMechanisms.map { it.toInternal() }
)
requestId = request.requestId
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package org.wfanet.measurement.kingdom.service.api.v2alpha

import com.google.protobuf.util.Timestamps
import com.google.type.date
import com.google.type.interval
import java.time.ZoneOffset
import org.wfanet.measurement.api.Version
Expand Down Expand Up @@ -61,6 +60,7 @@ import org.wfanet.measurement.api.v2alpha.ModelSuite
import org.wfanet.measurement.api.v2alpha.ModelSuiteKey
import org.wfanet.measurement.api.v2alpha.ProtocolConfig
import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism
import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt
import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.direct
import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.liquidLegionsV2
import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.protocol
Expand Down Expand Up @@ -107,6 +107,7 @@ import org.wfanet.measurement.internal.kingdom.ModelShard as InternalModelShard
import org.wfanet.measurement.internal.kingdom.ModelSuite as InternalModelSuite
import org.wfanet.measurement.internal.kingdom.ProtocolConfig as InternalProtocolConfig
import org.wfanet.measurement.internal.kingdom.ProtocolConfig.NoiseMechanism as InternalNoiseMechanism
import org.wfanet.measurement.internal.kingdom.ProtocolConfigKt as InternalProtocolConfigKt
import org.wfanet.measurement.internal.kingdom.duchyProtocolConfig
import org.wfanet.measurement.internal.kingdom.exchangeWorkflow
import org.wfanet.measurement.internal.kingdom.measurement as internalMeasurement
Expand All @@ -120,6 +121,10 @@ import org.wfanet.measurement.internal.kingdom.protocolConfig as internalProtoco
import org.wfanet.measurement.kingdom.deploy.common.Llv2ProtocolConfig
import org.wfanet.measurement.kingdom.deploy.common.RoLlv2ProtocolConfig

// (-- TODO(world-federation-of-advertisers/cross-media-measurement-api/issues/160): this value
// won't be needed once the maximum frequency field is moved to measurement spec. --)
const val DEFAULT_MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION = 20

/** Converts an internal [InternalMeasurement.State] to a public [State]. */
fun InternalMeasurement.State.toState(): State =
when (this) {
Expand Down Expand Up @@ -182,11 +187,23 @@ fun InternalNoiseMechanism.toNoiseMechanism(): NoiseMechanism {
return when (this) {
InternalNoiseMechanism.GEOMETRIC -> NoiseMechanism.GEOMETRIC
InternalNoiseMechanism.DISCRETE_GAUSSIAN -> NoiseMechanism.DISCRETE_GAUSSIAN
InternalNoiseMechanism.NONE -> NoiseMechanism.NONE
InternalNoiseMechanism.NOISE_MECHANISM_UNSPECIFIED,
InternalNoiseMechanism.UNRECOGNIZED -> error("invalid internal noise mechanism.")
}
}

/** Converts a public [NoiseMechanism] to an internal [InternalNoiseMechanism]. */
fun NoiseMechanism.toInternal(): InternalNoiseMechanism {
return when (this) {
NoiseMechanism.GEOMETRIC -> InternalNoiseMechanism.GEOMETRIC
NoiseMechanism.DISCRETE_GAUSSIAN -> InternalNoiseMechanism.DISCRETE_GAUSSIAN
NoiseMechanism.NONE -> InternalNoiseMechanism.NONE
NoiseMechanism.NOISE_MECHANISM_UNSPECIFIED,
NoiseMechanism.UNRECOGNIZED -> error("invalid internal noise mechanism.")
}
}

/** Converts an internal [InternalProtocolConfig] to a public [ProtocolConfig]. */
fun InternalProtocolConfig.toProtocolConfig(
measurementTypeCase: MeasurementSpec.MeasurementTypeCase,
Expand All @@ -211,10 +228,23 @@ fun InternalProtocolConfig.toProtocolConfig(
ProtocolConfig.MeasurementType.REACH,
ProtocolConfig.MeasurementType.REACH_AND_FREQUENCY -> {
if (dataProviderCount == 1) {
protocols += protocol { direct = direct {} }
protocols += protocol {
if (source.hasDirect()) {
direct = source.direct.toDirect()
} else {
// For backward compatibility
direct = direct {}
}
}
} else {
@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields are never null.
when (source.protocolCase) {
InternalProtocolConfig.ProtocolCase.DIRECT -> {
error(
"Direct protocol of reach computation shouldn't be used when number of data " +
"providers is greater than 1."
)
}
InternalProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2 -> {
protocols += protocol {
liquidLegionsV2 = liquidLegionsV2 {
Expand Down Expand Up @@ -288,6 +318,45 @@ fun InternalProtocolConfig.toProtocolConfig(
}
}

/**
* Converts an internal [InternalProtocolConfig.Direct] to a public [InternalProtocolConfig.Direct].
*/
private fun InternalProtocolConfig.Direct.toDirect(): ProtocolConfig.Direct {
val source = this

return direct {
noiseMechanisms +=
source.noiseMechanismsList.map { internalNoiseMechanism ->
internalNoiseMechanism.toNoiseMechanism()
}

if (source.hasDeterministicCountDistinct()) {
deterministicCountDistinct = ProtocolConfigKt.DirectKt.deterministicCountDistinct {}
}
if (source.hasDeterministicDistribution()) {
deterministicDistribution =
ProtocolConfigKt.DirectKt.deterministicDistribution {
maximumFrequency = source.deterministicDistribution.maximumFrequency
}
}
if (source.hasDeterministicCount()) {
deterministicCount = ProtocolConfigKt.DirectKt.deterministicCount {}
}
if (source.hasDeterministicSum()) {
deterministicSum = ProtocolConfigKt.DirectKt.deterministicSum {}
}
if (source.hasLiquidLegionsCountDistinct()) {
liquidLegionsCountDistinct = ProtocolConfigKt.DirectKt.liquidLegionsCountDistinct {}
}
if (source.hasLiquidLegionsDistribution()) {
liquidLegionsDistribution =
ProtocolConfigKt.DirectKt.liquidLegionsDistribution {
maximumFrequency = source.liquidLegionsDistribution.maximumFrequency
}
}
}
}

/** Converts an internal [InternalModelSuite] to a public [ModelSuite]. */
fun InternalModelSuite.toModelSuite(): ModelSuite {
val source = this
Expand Down Expand Up @@ -709,7 +778,8 @@ fun Map.Entry<Long, DataProviderValue>.toDataProviderEntry(): DataProviderEntry
fun Measurement.toInternal(
measurementConsumerCertificateKey: MeasurementConsumerCertificateKey,
dataProvidersMap: Map<Long, DataProviderValue>,
measurementSpecProto: MeasurementSpec
measurementSpecProto: MeasurementSpec,
internalNoiseMechanisms: List<InternalProtocolConfig.NoiseMechanism>
): InternalMeasurement {
val publicMeasurement = this

Expand Down Expand Up @@ -746,6 +816,17 @@ fun Measurement.toInternal(
liquidLegionsV2 = Llv2ProtocolConfig.duchyProtocolConfig
}
}
} else if (dataProvidersCount == 1) {
protocolConfig = internalProtocolConfig {
direct =
InternalProtocolConfigKt.direct {
this.noiseMechanisms += internalNoiseMechanisms
deterministicCountDistinct =
InternalProtocolConfigKt.DirectKt.deterministicCountDistinct {}
liquidLegionsCountDistinct =
InternalProtocolConfigKt.DirectKt.liquidLegionsCountDistinct {}
}
}
}
}
MeasurementSpec.MeasurementTypeCase.REACH_AND_FREQUENCY -> {
Expand All @@ -757,10 +838,49 @@ fun Measurement.toInternal(
duchyProtocolConfig = duchyProtocolConfig {
liquidLegionsV2 = Llv2ProtocolConfig.duchyProtocolConfig
}
} else if (dataProvidersCount == 1) {
protocolConfig = internalProtocolConfig {
direct =
InternalProtocolConfigKt.direct {
this.noiseMechanisms += internalNoiseMechanisms
deterministicCountDistinct =
InternalProtocolConfigKt.DirectKt.deterministicCountDistinct {}
liquidLegionsCountDistinct =
InternalProtocolConfigKt.DirectKt.liquidLegionsCountDistinct {}
deterministicDistribution =
InternalProtocolConfigKt.DirectKt.deterministicDistribution {
maximumFrequency = DEFAULT_MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION
}
liquidLegionsDistribution =
InternalProtocolConfigKt.DirectKt.liquidLegionsDistribution {
maximumFrequency = DEFAULT_MAXIMUM_FREQUENCY_DIRECT_DISTRIBUTION
}
}
}
}
}
MeasurementSpec.MeasurementTypeCase.IMPRESSION -> {
if (dataProvidersCount == 1) {
protocolConfig = internalProtocolConfig {
direct =
InternalProtocolConfigKt.direct {
this.noiseMechanisms += internalNoiseMechanisms
deterministicCount = InternalProtocolConfigKt.DirectKt.deterministicCount {}
}
}
}
}
MeasurementSpec.MeasurementTypeCase.DURATION -> {
if (dataProvidersCount == 1) {
protocolConfig = internalProtocolConfig {
direct =
InternalProtocolConfigKt.direct {
this.noiseMechanisms += internalNoiseMechanisms
deterministicSum = InternalProtocolConfigKt.DirectKt.deterministicSum {}
}
}
}
}
MeasurementSpec.MeasurementTypeCase.IMPRESSION,
MeasurementSpec.MeasurementTypeCase.DURATION, -> {}
MeasurementSpec.MeasurementTypeCase.MEASUREMENTTYPE_NOT_SET ->
error("MeasurementType not set.")
}
Expand Down
Loading

0 comments on commit 4e84993

Please sign in to comment.