Skip to content

Commit

Permalink
Use participant count instead of gRPC stubs count in LLv2 mills. (#1197)
Browse files Browse the repository at this point in the history
Fixes #1194
  • Loading branch information
SanjayVas authored and ple13 committed Aug 16, 2024
1 parent 6e6c2cf commit 92cd0e8
Show file tree
Hide file tree
Showing 7 changed files with 347 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -473,16 +473,18 @@ class LiquidLegionsV2Mill(
private suspend fun completeSetupPhaseAtAggregator(token: ComputationToken): ComputationToken {
val llv2Details = token.computationDetails.liquidLegionsV2
require(AGGREGATOR == llv2Details.role) { "invalid role for this function." }
// TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to
// handle the case where the set of computation participants is a subset of all Duchies.
val inputBlobCount = workerStubs.size
val inputBlobCount = token.participantCount - 1
val (bytes, nextToken) =
existingOutputOr(token) {
val request =
dataClients
.readAllRequisitionBlobs(token, duchyId)
.concat(readAndCombineAllInputBlobs(token, inputBlobCount))
.toCompleteSetupPhaseRequest(llv2Details, token.requisitionsCount)
.toCompleteSetupPhaseRequest(
llv2Details,
token.requisitionsCount,
token.participantCount
)
val cryptoResult: CompleteSetupPhaseResponse = cryptoWorker.completeSetupPhase(request)
logStageDurationMetric(
token,
Expand Down Expand Up @@ -518,7 +520,11 @@ class LiquidLegionsV2Mill(
val request =
dataClients
.readAllRequisitionBlobs(token, duchyId)
.toCompleteSetupPhaseRequest(llv2Details, token.requisitionsCount)
.toCompleteSetupPhaseRequest(
llv2Details,
token.requisitionsCount,
token.participantCount
)
val cryptoResult: CompleteSetupPhaseResponse = cryptoWorker.completeSetupPhase(request)
logStageDurationMetric(
token,
Expand Down Expand Up @@ -564,7 +570,7 @@ class LiquidLegionsV2Mill(
totalSketchesCount = token.requisitionsCount
noiseMechanism = llv2Details.parameters.noise.noiseMechanism
if (llv2Parameters.noise.hasFrequencyNoiseConfig() && maximumRequestedFrequency > 1) {
noiseParameters = getFrequencyNoiseParams(llv2Parameters)
noiseParameters = getFrequencyNoiseParams(llv2Parameters, token.participantCount)
}
}
val cryptoResult: CompleteExecutionPhaseOneAtAggregatorResponse =
Expand Down Expand Up @@ -674,17 +680,13 @@ class LiquidLegionsV2Mill(
vidSamplingIntervalWidth = measurementSpec.vidSamplingInterval.width
if (llv2Parameters.noise.hasReachNoiseConfig()) {
reachDpNoiseBaseline = globalReachDpNoiseBaseline {
// TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to
// handle the case where the computation participants is a subset of all Duchies.
contributorsCount = workerStubs.size + 1
contributorsCount = token.participantCount
globalReachDpNoise = llv2Parameters.noise.reachNoiseConfig.globalReachDpNoise
}
}
if (llv2Parameters.noise.hasFrequencyNoiseConfig() && (maximumRequestedFrequency > 1)) {
frequencyNoiseParameters = flagCountTupleNoiseGenerationParameters {
// TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to
// handle the case where the computation participants is a subset of all Duchies.
contributorsCount = workerStubs.size + 1
contributorsCount = token.participantCount
maximumFrequency = maximumRequestedFrequency
dpParams = llv2Parameters.noise.frequencyNoiseConfig
}
Expand Down Expand Up @@ -768,7 +770,7 @@ class LiquidLegionsV2Mill(
if (llv2Parameters.noise.hasFrequencyNoiseConfig()) {
partialCompositeElGamalPublicKey = llv2Details.partiallyCombinedPublicKey
if (maximumRequestedFrequency > 1) {
noiseParameters = getFrequencyNoiseParams(llv2Parameters)
noiseParameters = getFrequencyNoiseParams(llv2Parameters, token.participantCount)
}
}
noiseMechanism = llv2Details.parameters.noise.noiseMechanism
Expand Down Expand Up @@ -824,9 +826,7 @@ class LiquidLegionsV2Mill(
maximumFrequency = llv2Parameters.maximumFrequency.coerceAtLeast(1)
if (llv2Parameters.noise.hasFrequencyNoiseConfig()) {
globalFrequencyDpNoisePerBucket = perBucketFrequencyDpNoiseBaseline {
// TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to
// handle the case where the computation participants is a subset of all Duchies.
contributorsCount = workerStubs.size + 1
contributorsCount = token.participantCount
dpParams = llv2Parameters.noise.frequencyNoiseConfig
}
}
Expand Down Expand Up @@ -945,7 +945,8 @@ class LiquidLegionsV2Mill(

private fun ByteString.toCompleteSetupPhaseRequest(
llv2Details: LiquidLegionsSketchAggregationV2.ComputationDetails,
totalRequisitionsCount: Int
totalRequisitionsCount: Int,
participantCount: Int,
): CompleteSetupPhaseRequest {
val noiseConfig = llv2Details.parameters.noise
return completeSetupPhaseRequest {
Expand All @@ -955,9 +956,7 @@ class LiquidLegionsV2Mill(
noiseParameters = registerNoiseGenerationParameters {
compositeElGamalPublicKey = llv2Details.combinedPublicKey
curveId = llv2Details.parameters.ellipticCurveId.toLong()
// TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to handle
// the case where the computation participants is a subset of all Duchies.
contributorsCount = workerStubs.size + 1
contributorsCount = participantCount
totalSketchesCount = totalRequisitionsCount
dpParams = reachNoiseDifferentialPrivacyParams {
blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise
Expand All @@ -972,17 +971,25 @@ class LiquidLegionsV2Mill(
}

private fun getFrequencyNoiseParams(
llv2Parameters: Parameters
llv2Parameters: Parameters,
participantCount: Int,
): FlagCountTupleNoiseGenerationParameters {
return flagCountTupleNoiseGenerationParameters {
maximumFrequency = llv2Parameters.maximumFrequency
// TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to handle the
// case where the computation participants is a subset of all Duchies.
contributorsCount = workerStubs.size + 1
contributorsCount = participantCount
dpParams = llv2Parameters.noise.frequencyNoiseConfig
}
}

private val ComputationToken.participantCount: Int
get() =
if (computationDetails.kingdomComputation.participantCount != 0) {
computationDetails.kingdomComputation.participantCount
} else {
// For legacy Computations. See world-federation-of-advertisers/cross-media-measurement#1194
workerStubs.size + 1
}

companion object {
init {
loadLibrary(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@ class ReachOnlyLiquidLegionsV2Mill(
private val executionPhaseCryptoCpuTimeDurationHistogram: LongHistogram =
meter.histogramBuilder("execution_phase_crypto_cpu_time_duration_millis").ofLongs().build()

override val endingStage =
ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE.toProtocolStage()
override val endingStage = Stage.COMPLETE.toProtocolStage()

private val actions =
mapOf(
Expand All @@ -177,8 +176,6 @@ class ReachOnlyLiquidLegionsV2Mill(
Pair(Stage.EXECUTION_PHASE, NON_AGGREGATOR) to ::completeExecutionPhaseAtNonAggregator,
)

private val kBytesPerCipherText = 66

override suspend fun processComputationImpl(token: ComputationToken) {
require(token.computationDetails.hasReachOnlyLiquidLegionsV2()) {
"Only Reach Only Liquid Legions V2 computation is supported in this mill."
Expand Down Expand Up @@ -454,16 +451,18 @@ class ReachOnlyLiquidLegionsV2Mill(
private suspend fun completeSetupPhaseAtAggregator(token: ComputationToken): ComputationToken {
val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2
require(AGGREGATOR == rollv2Details.role) { "invalid role for this function." }
// TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to
// handle the case where the set of computation participants is a subset of all Duchies.
val inputBlobCount = workerStubs.size
val inputBlobCount = token.participantCount - 1
val (bytes, nextToken) =
existingOutputOr(token) {
val request =
dataClients
.readAllRequisitionBlobs(token, duchyId)
.concat(readAndCombineAllInputBlobsSetupPhaseAtAggregator(token, inputBlobCount))
.toCompleteSetupPhaseAtAggregatorRequest(rollv2Details, token.requisitionsCount)
.toCompleteSetupPhaseAtAggregatorRequest(
rollv2Details,
token.requisitionsCount,
token.participantCount
)
val cryptoResult: CompleteReachOnlySetupPhaseResponse =
cryptoWorker.completeReachOnlySetupPhaseAtAggregator(request)
logStageDurationMetric(
Expand Down Expand Up @@ -501,7 +500,11 @@ class ReachOnlyLiquidLegionsV2Mill(
val request =
dataClients
.readAllRequisitionBlobs(token, duchyId)
.toCompleteReachOnlySetupPhaseRequest(rollv2Details, token.requisitionsCount)
.toCompleteReachOnlySetupPhaseRequest(
rollv2Details,
token.requisitionsCount,
token.participantCount
)
val cryptoResult: CompleteReachOnlySetupPhaseResponse =
cryptoWorker.completeReachOnlySetupPhase(request)
logStageDurationMetric(
Expand Down Expand Up @@ -541,24 +544,22 @@ class ReachOnlyLiquidLegionsV2Mill(
val measurementSpec =
MeasurementSpec.parseFrom(token.computationDetails.kingdomComputation.measurementSpec)
val inputBlob = readAndCombineAllInputBlobs(token, 1)
require(inputBlob.size() >= kBytesPerCipherText) {
require(inputBlob.size() >= BYTES_PER_CIPHERTEXT) {
("Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size " +
"${inputBlob.size()} which is less than ($kBytesPerCipherText).")
"${inputBlob.size()} which is less than ($BYTES_PER_CIPHERTEXT).")
}
var reach = 0L
val (bytes, nextToken) =
val (_, nextToken) =
existingOutputOr(token) {
val request = completeReachOnlyExecutionPhaseAtAggregatorRequest {
combinedRegisterVector = inputBlob.substring(0, inputBlob.size() - kBytesPerCipherText)
combinedRegisterVector = inputBlob.substring(0, inputBlob.size() - BYTES_PER_CIPHERTEXT)
localElGamalKeyPair = rollv2Details.localElgamalKey
curveId = rollv2Details.parameters.ellipticCurveId.toLong()
serializedExcessiveNoiseCiphertext =
inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size())
inputBlob.substring(inputBlob.size() - BYTES_PER_CIPHERTEXT, inputBlob.size())
if (rollv2Parameters.noise.hasReachNoiseConfig()) {
reachDpNoiseBaseline = globalReachDpNoiseBaseline {
// TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to
// handle the case where the computation participants is a subset of all Duchies.
contributorsCount = workerStubs.size + 1
contributorsCount = token.participantCount
globalReachDpNoise = rollv2Parameters.noise.reachNoiseConfig.globalReachDpNoise
}
}
Expand All @@ -571,9 +572,7 @@ class ReachOnlyLiquidLegionsV2Mill(
noiseParameters = registerNoiseGenerationParameters {
compositeElGamalPublicKey = rollv2Details.combinedPublicKey
curveId = rollv2Details.parameters.ellipticCurveId.toLong()
// TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to
// handle the case where the computation participants is a subset of all Duchies.
contributorsCount = workerStubs.size + 1
contributorsCount = token.participantCount
totalSketchesCount = token.requisitionsCount
dpParams = reachNoiseDifferentialPrivacyParams {
blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise
Expand Down Expand Up @@ -607,21 +606,21 @@ class ReachOnlyLiquidLegionsV2Mill(
val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2
require(NON_AGGREGATOR == rollv2Details.role) { "invalid role for this function." }
val inputBlob = readAndCombineAllInputBlobs(token, 1)
require(inputBlob.size() >= kBytesPerCipherText) {
require(inputBlob.size() >= BYTES_PER_CIPHERTEXT) {
("Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size " +
"${inputBlob.size()} which is less than ($kBytesPerCipherText).")
"${inputBlob.size()} which is less than ($BYTES_PER_CIPHERTEXT).")
}
val (bytes, nextToken) =
existingOutputOr(token) {
val cryptoResult: CompleteReachOnlyExecutionPhaseResponse =
cryptoWorker.completeReachOnlyExecutionPhase(
completeReachOnlyExecutionPhaseRequest {
combinedRegisterVector =
inputBlob.substring(0, inputBlob.size() - kBytesPerCipherText)
inputBlob.substring(0, inputBlob.size() - BYTES_PER_CIPHERTEXT)
localElGamalKeyPair = rollv2Details.localElgamalKey
curveId = rollv2Details.parameters.ellipticCurveId.toLong()
serializedExcessiveNoiseCiphertext =
inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size())
inputBlob.substring(inputBlob.size() - BYTES_PER_CIPHERTEXT, inputBlob.size())
parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism
}
)
Expand Down Expand Up @@ -695,7 +694,8 @@ class ReachOnlyLiquidLegionsV2Mill(

private fun ByteString.toCompleteReachOnlySetupPhaseRequest(
rollv2Details: ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails,
totalRequisitionsCount: Int
totalRequisitionsCount: Int,
participantCount: Int,
): CompleteReachOnlySetupPhaseRequest {
val noiseConfig = rollv2Details.parameters.noise
return completeReachOnlySetupPhaseRequest {
Expand All @@ -705,9 +705,7 @@ class ReachOnlyLiquidLegionsV2Mill(
noiseParameters = registerNoiseGenerationParameters {
compositeElGamalPublicKey = rollv2Details.combinedPublicKey
curveId = rollv2Details.parameters.ellipticCurveId.toLong()
// TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to handle
// the case where the set of computation participants is a subset of all Duchies.
contributorsCount = workerStubs.size + 1
contributorsCount = participantCount
totalSketchesCount = totalRequisitionsCount
dpParams = reachNoiseDifferentialPrivacyParams {
blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise
Expand All @@ -725,26 +723,21 @@ class ReachOnlyLiquidLegionsV2Mill(

private fun ByteString.toCompleteSetupPhaseAtAggregatorRequest(
rollv2Details: ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails,
totalRequisitionsCount: Int
totalRequisitionsCount: Int,
participantCount: Int,
): CompleteReachOnlySetupPhaseRequest {
val noiseConfig = rollv2Details.parameters.noise
val combinedInputBlobs = this@toCompleteSetupPhaseAtAggregatorRequest
val combinedRegisterVectorSizeBytes =
combinedInputBlobs.size() - (participantCount - 1) * BYTES_PER_CIPHERTEXT
return completeReachOnlySetupPhaseRequest {
combinedRegisterVector =
combinedInputBlobs.substring(
0,
// TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to handle
// the case where the set of computation participants is a subset of all Duchies.
combinedInputBlobs.size() - workerStubs.size * kBytesPerCipherText
)
combinedRegisterVector = combinedInputBlobs.substring(0, combinedRegisterVectorSizeBytes)
curveId = rollv2Details.parameters.ellipticCurveId.toLong()
if (noiseConfig.hasReachNoiseConfig()) {
noiseParameters = registerNoiseGenerationParameters {
compositeElGamalPublicKey = rollv2Details.combinedPublicKey
curveId = rollv2Details.parameters.ellipticCurveId.toLong()
// TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to handle
// the case where the set of computation participants is a subset of all Duchies.
contributorsCount = workerStubs.size + 1
contributorsCount = participantCount
totalSketchesCount = totalRequisitionsCount
dpParams = reachNoiseDifferentialPrivacyParams {
blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise
Expand All @@ -756,18 +749,13 @@ class ReachOnlyLiquidLegionsV2Mill(
}
compositeElGamalPublicKey = rollv2Details.combinedPublicKey
serializedExcessiveNoiseCiphertext =
combinedInputBlobs.substring(
// TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to handle
// the case where the set of computation participants is a subset of all Duchies.
combinedInputBlobs.size() - workerStubs.size * kBytesPerCipherText,
combinedInputBlobs.size()
)
combinedInputBlobs.substring(combinedRegisterVectorSizeBytes, combinedInputBlobs.size())
parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism
}
}

/** Reads all input blobs and combines all the bytes together. */
protected suspend fun readAndCombineAllInputBlobsSetupPhaseAtAggregator(
private suspend fun readAndCombineAllInputBlobsSetupPhaseAtAggregator(
token: ComputationToken,
count: Int
): ByteString {
Expand All @@ -780,19 +768,30 @@ class ReachOnlyLiquidLegionsV2Mill(
var combinedRegisterVector = ByteString.EMPTY
var combinedNoiseCiphertext = ByteString.EMPTY
for (str in blobMap.values) {
require(str.size() >= kBytesPerCipherText) {
require(str.size() >= BYTES_PER_CIPHERTEXT) {
("Invalid input blob size. Input blob ${str.toStringUtf8()} has size " +
"${str.size()} which is less than ($kBytesPerCipherText).")
"${str.size()} which is less than ($BYTES_PER_CIPHERTEXT).")
}
combinedRegisterVector =
combinedRegisterVector.concat(str.substring(0, str.size() - kBytesPerCipherText))
combinedRegisterVector.concat(str.substring(0, str.size() - BYTES_PER_CIPHERTEXT))
combinedNoiseCiphertext =
combinedNoiseCiphertext.concat(str.substring(str.size() - kBytesPerCipherText, str.size()))
combinedNoiseCiphertext.concat(str.substring(str.size() - BYTES_PER_CIPHERTEXT, str.size()))
}
return combinedRegisterVector.concat(combinedNoiseCiphertext)
}

private val ComputationToken.participantCount: Int
get() =
if (computationDetails.kingdomComputation.participantCount != 0) {
computationDetails.kingdomComputation.participantCount
} else {
// For legacy Computations. See world-federation-of-advertisers/cross-media-measurement#1194
workerStubs.size + 1
}

companion object {
private const val BYTES_PER_CIPHERTEXT = 66

private val logger: Logger = Logger.getLogger(this::class.java.name)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ fun SystemComputation.toKingdomComputationDetails(): KingdomComputationDetails {
return kingdomComputationDetails {
publicApiVersion = source.publicApiVersion
measurementSpec = source.measurementSpec
participantCount = source.computationParticipantsCount
when (Version.fromString(source.publicApiVersion)) {
Version.V2_ALPHA -> {
val measurementSpec = MeasurementSpec.parseFrom(source.measurementSpec)
Expand Down
Loading

0 comments on commit 92cd0e8

Please sign in to comment.