diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2Mill.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2Mill.kt index f2a8569cc2a..cfa5cb64165 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2Mill.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2Mill.kt @@ -473,16 +473,18 @@ class LiquidLegionsV2Mill( private suspend fun completeSetupPhaseAtAggregator(token: ComputationToken): ComputationToken { val llv2Details = token.computationDetails.liquidLegionsV2 require(AGGREGATOR == llv2Details.role) { "invalid role for this function." } - // TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to - // handle the case where the set of computation participants is a subset of all Duchies. - val inputBlobCount = workerStubs.size + val inputBlobCount = token.participantCount - 1 val (bytes, nextToken) = existingOutputOr(token) { val request = dataClients .readAllRequisitionBlobs(token, duchyId) .concat(readAndCombineAllInputBlobs(token, inputBlobCount)) - .toCompleteSetupPhaseRequest(llv2Details, token.requisitionsCount) + .toCompleteSetupPhaseRequest( + llv2Details, + token.requisitionsCount, + token.participantCount + ) val cryptoResult: CompleteSetupPhaseResponse = cryptoWorker.completeSetupPhase(request) logStageDurationMetric( token, @@ -518,7 +520,11 @@ class LiquidLegionsV2Mill( val request = dataClients .readAllRequisitionBlobs(token, duchyId) - .toCompleteSetupPhaseRequest(llv2Details, token.requisitionsCount) + .toCompleteSetupPhaseRequest( + llv2Details, + token.requisitionsCount, + token.participantCount + ) val cryptoResult: CompleteSetupPhaseResponse = cryptoWorker.completeSetupPhase(request) logStageDurationMetric( token, @@ -564,7 +570,7 @@ class LiquidLegionsV2Mill( totalSketchesCount = token.requisitionsCount noiseMechanism = llv2Details.parameters.noise.noiseMechanism if (llv2Parameters.noise.hasFrequencyNoiseConfig() && maximumRequestedFrequency > 1) { - noiseParameters = getFrequencyNoiseParams(llv2Parameters) + noiseParameters = getFrequencyNoiseParams(llv2Parameters, token.participantCount) } } val cryptoResult: CompleteExecutionPhaseOneAtAggregatorResponse = @@ -674,17 +680,13 @@ class LiquidLegionsV2Mill( vidSamplingIntervalWidth = measurementSpec.vidSamplingInterval.width if (llv2Parameters.noise.hasReachNoiseConfig()) { reachDpNoiseBaseline = globalReachDpNoiseBaseline { - // TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to - // handle the case where the computation participants is a subset of all Duchies. - contributorsCount = workerStubs.size + 1 + contributorsCount = token.participantCount globalReachDpNoise = llv2Parameters.noise.reachNoiseConfig.globalReachDpNoise } } if (llv2Parameters.noise.hasFrequencyNoiseConfig() && (maximumRequestedFrequency > 1)) { frequencyNoiseParameters = flagCountTupleNoiseGenerationParameters { - // TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to - // handle the case where the computation participants is a subset of all Duchies. - contributorsCount = workerStubs.size + 1 + contributorsCount = token.participantCount maximumFrequency = maximumRequestedFrequency dpParams = llv2Parameters.noise.frequencyNoiseConfig } @@ -768,7 +770,7 @@ class LiquidLegionsV2Mill( if (llv2Parameters.noise.hasFrequencyNoiseConfig()) { partialCompositeElGamalPublicKey = llv2Details.partiallyCombinedPublicKey if (maximumRequestedFrequency > 1) { - noiseParameters = getFrequencyNoiseParams(llv2Parameters) + noiseParameters = getFrequencyNoiseParams(llv2Parameters, token.participantCount) } } noiseMechanism = llv2Details.parameters.noise.noiseMechanism @@ -824,9 +826,7 @@ class LiquidLegionsV2Mill( maximumFrequency = llv2Parameters.maximumFrequency.coerceAtLeast(1) if (llv2Parameters.noise.hasFrequencyNoiseConfig()) { globalFrequencyDpNoisePerBucket = perBucketFrequencyDpNoiseBaseline { - // TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to - // handle the case where the computation participants is a subset of all Duchies. - contributorsCount = workerStubs.size + 1 + contributorsCount = token.participantCount dpParams = llv2Parameters.noise.frequencyNoiseConfig } } @@ -945,7 +945,8 @@ class LiquidLegionsV2Mill( private fun ByteString.toCompleteSetupPhaseRequest( llv2Details: LiquidLegionsSketchAggregationV2.ComputationDetails, - totalRequisitionsCount: Int + totalRequisitionsCount: Int, + participantCount: Int, ): CompleteSetupPhaseRequest { val noiseConfig = llv2Details.parameters.noise return completeSetupPhaseRequest { @@ -955,9 +956,7 @@ class LiquidLegionsV2Mill( noiseParameters = registerNoiseGenerationParameters { compositeElGamalPublicKey = llv2Details.combinedPublicKey curveId = llv2Details.parameters.ellipticCurveId.toLong() - // TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to handle - // the case where the computation participants is a subset of all Duchies. - contributorsCount = workerStubs.size + 1 + contributorsCount = participantCount totalSketchesCount = totalRequisitionsCount dpParams = reachNoiseDifferentialPrivacyParams { blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise @@ -972,17 +971,25 @@ class LiquidLegionsV2Mill( } private fun getFrequencyNoiseParams( - llv2Parameters: Parameters + llv2Parameters: Parameters, + participantCount: Int, ): FlagCountTupleNoiseGenerationParameters { return flagCountTupleNoiseGenerationParameters { maximumFrequency = llv2Parameters.maximumFrequency - // TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to handle the - // case where the computation participants is a subset of all Duchies. - contributorsCount = workerStubs.size + 1 + contributorsCount = participantCount dpParams = llv2Parameters.noise.frequencyNoiseConfig } } + private val ComputationToken.participantCount: Int + get() = + if (computationDetails.kingdomComputation.participantCount != 0) { + computationDetails.kingdomComputation.participantCount + } else { + // For legacy Computations. See world-federation-of-advertisers/cross-media-measurement#1194 + workerStubs.size + 1 + } + companion object { init { loadLibrary( diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt index f64b4b9f5f6..6d4f02f372a 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2Mill.kt @@ -162,8 +162,7 @@ class ReachOnlyLiquidLegionsV2Mill( private val executionPhaseCryptoCpuTimeDurationHistogram: LongHistogram = meter.histogramBuilder("execution_phase_crypto_cpu_time_duration_millis").ofLongs().build() - override val endingStage = - ReachOnlyLiquidLegionsSketchAggregationV2.Stage.COMPLETE.toProtocolStage() + override val endingStage = Stage.COMPLETE.toProtocolStage() private val actions = mapOf( @@ -177,8 +176,6 @@ class ReachOnlyLiquidLegionsV2Mill( Pair(Stage.EXECUTION_PHASE, NON_AGGREGATOR) to ::completeExecutionPhaseAtNonAggregator, ) - private val kBytesPerCipherText = 66 - override suspend fun processComputationImpl(token: ComputationToken) { require(token.computationDetails.hasReachOnlyLiquidLegionsV2()) { "Only Reach Only Liquid Legions V2 computation is supported in this mill." @@ -454,16 +451,18 @@ class ReachOnlyLiquidLegionsV2Mill( private suspend fun completeSetupPhaseAtAggregator(token: ComputationToken): ComputationToken { val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 require(AGGREGATOR == rollv2Details.role) { "invalid role for this function." } - // TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to - // handle the case where the set of computation participants is a subset of all Duchies. - val inputBlobCount = workerStubs.size + val inputBlobCount = token.participantCount - 1 val (bytes, nextToken) = existingOutputOr(token) { val request = dataClients .readAllRequisitionBlobs(token, duchyId) .concat(readAndCombineAllInputBlobsSetupPhaseAtAggregator(token, inputBlobCount)) - .toCompleteSetupPhaseAtAggregatorRequest(rollv2Details, token.requisitionsCount) + .toCompleteSetupPhaseAtAggregatorRequest( + rollv2Details, + token.requisitionsCount, + token.participantCount + ) val cryptoResult: CompleteReachOnlySetupPhaseResponse = cryptoWorker.completeReachOnlySetupPhaseAtAggregator(request) logStageDurationMetric( @@ -501,7 +500,11 @@ class ReachOnlyLiquidLegionsV2Mill( val request = dataClients .readAllRequisitionBlobs(token, duchyId) - .toCompleteReachOnlySetupPhaseRequest(rollv2Details, token.requisitionsCount) + .toCompleteReachOnlySetupPhaseRequest( + rollv2Details, + token.requisitionsCount, + token.participantCount + ) val cryptoResult: CompleteReachOnlySetupPhaseResponse = cryptoWorker.completeReachOnlySetupPhase(request) logStageDurationMetric( @@ -541,24 +544,22 @@ class ReachOnlyLiquidLegionsV2Mill( val measurementSpec = MeasurementSpec.parseFrom(token.computationDetails.kingdomComputation.measurementSpec) val inputBlob = readAndCombineAllInputBlobs(token, 1) - require(inputBlob.size() >= kBytesPerCipherText) { + require(inputBlob.size() >= BYTES_PER_CIPHERTEXT) { ("Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size " + - "${inputBlob.size()} which is less than ($kBytesPerCipherText).") + "${inputBlob.size()} which is less than ($BYTES_PER_CIPHERTEXT).") } var reach = 0L - val (bytes, nextToken) = + val (_, nextToken) = existingOutputOr(token) { val request = completeReachOnlyExecutionPhaseAtAggregatorRequest { - combinedRegisterVector = inputBlob.substring(0, inputBlob.size() - kBytesPerCipherText) + combinedRegisterVector = inputBlob.substring(0, inputBlob.size() - BYTES_PER_CIPHERTEXT) localElGamalKeyPair = rollv2Details.localElgamalKey curveId = rollv2Details.parameters.ellipticCurveId.toLong() serializedExcessiveNoiseCiphertext = - inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size()) + inputBlob.substring(inputBlob.size() - BYTES_PER_CIPHERTEXT, inputBlob.size()) if (rollv2Parameters.noise.hasReachNoiseConfig()) { reachDpNoiseBaseline = globalReachDpNoiseBaseline { - // TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to - // handle the case where the computation participants is a subset of all Duchies. - contributorsCount = workerStubs.size + 1 + contributorsCount = token.participantCount globalReachDpNoise = rollv2Parameters.noise.reachNoiseConfig.globalReachDpNoise } } @@ -571,9 +572,7 @@ class ReachOnlyLiquidLegionsV2Mill( noiseParameters = registerNoiseGenerationParameters { compositeElGamalPublicKey = rollv2Details.combinedPublicKey curveId = rollv2Details.parameters.ellipticCurveId.toLong() - // TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to - // handle the case where the computation participants is a subset of all Duchies. - contributorsCount = workerStubs.size + 1 + contributorsCount = token.participantCount totalSketchesCount = token.requisitionsCount dpParams = reachNoiseDifferentialPrivacyParams { blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise @@ -607,9 +606,9 @@ class ReachOnlyLiquidLegionsV2Mill( val rollv2Details = token.computationDetails.reachOnlyLiquidLegionsV2 require(NON_AGGREGATOR == rollv2Details.role) { "invalid role for this function." } val inputBlob = readAndCombineAllInputBlobs(token, 1) - require(inputBlob.size() >= kBytesPerCipherText) { + require(inputBlob.size() >= BYTES_PER_CIPHERTEXT) { ("Invalid input blob size. Input blob ${inputBlob.toStringUtf8()} has size " + - "${inputBlob.size()} which is less than ($kBytesPerCipherText).") + "${inputBlob.size()} which is less than ($BYTES_PER_CIPHERTEXT).") } val (bytes, nextToken) = existingOutputOr(token) { @@ -617,11 +616,11 @@ class ReachOnlyLiquidLegionsV2Mill( cryptoWorker.completeReachOnlyExecutionPhase( completeReachOnlyExecutionPhaseRequest { combinedRegisterVector = - inputBlob.substring(0, inputBlob.size() - kBytesPerCipherText) + inputBlob.substring(0, inputBlob.size() - BYTES_PER_CIPHERTEXT) localElGamalKeyPair = rollv2Details.localElgamalKey curveId = rollv2Details.parameters.ellipticCurveId.toLong() serializedExcessiveNoiseCiphertext = - inputBlob.substring(inputBlob.size() - kBytesPerCipherText, inputBlob.size()) + inputBlob.substring(inputBlob.size() - BYTES_PER_CIPHERTEXT, inputBlob.size()) parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism } ) @@ -695,7 +694,8 @@ class ReachOnlyLiquidLegionsV2Mill( private fun ByteString.toCompleteReachOnlySetupPhaseRequest( rollv2Details: ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails, - totalRequisitionsCount: Int + totalRequisitionsCount: Int, + participantCount: Int, ): CompleteReachOnlySetupPhaseRequest { val noiseConfig = rollv2Details.parameters.noise return completeReachOnlySetupPhaseRequest { @@ -705,9 +705,7 @@ class ReachOnlyLiquidLegionsV2Mill( noiseParameters = registerNoiseGenerationParameters { compositeElGamalPublicKey = rollv2Details.combinedPublicKey curveId = rollv2Details.parameters.ellipticCurveId.toLong() - // TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to handle - // the case where the set of computation participants is a subset of all Duchies. - contributorsCount = workerStubs.size + 1 + contributorsCount = participantCount totalSketchesCount = totalRequisitionsCount dpParams = reachNoiseDifferentialPrivacyParams { blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise @@ -725,26 +723,21 @@ class ReachOnlyLiquidLegionsV2Mill( private fun ByteString.toCompleteSetupPhaseAtAggregatorRequest( rollv2Details: ReachOnlyLiquidLegionsSketchAggregationV2.ComputationDetails, - totalRequisitionsCount: Int + totalRequisitionsCount: Int, + participantCount: Int, ): CompleteReachOnlySetupPhaseRequest { val noiseConfig = rollv2Details.parameters.noise val combinedInputBlobs = this@toCompleteSetupPhaseAtAggregatorRequest + val combinedRegisterVectorSizeBytes = + combinedInputBlobs.size() - (participantCount - 1) * BYTES_PER_CIPHERTEXT return completeReachOnlySetupPhaseRequest { - combinedRegisterVector = - combinedInputBlobs.substring( - 0, - // TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to handle - // the case where the set of computation participants is a subset of all Duchies. - combinedInputBlobs.size() - workerStubs.size * kBytesPerCipherText - ) + combinedRegisterVector = combinedInputBlobs.substring(0, combinedRegisterVectorSizeBytes) curveId = rollv2Details.parameters.ellipticCurveId.toLong() if (noiseConfig.hasReachNoiseConfig()) { noiseParameters = registerNoiseGenerationParameters { compositeElGamalPublicKey = rollv2Details.combinedPublicKey curveId = rollv2Details.parameters.ellipticCurveId.toLong() - // TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to handle - // the case where the set of computation participants is a subset of all Duchies. - contributorsCount = workerStubs.size + 1 + contributorsCount = participantCount totalSketchesCount = totalRequisitionsCount dpParams = reachNoiseDifferentialPrivacyParams { blindHistogram = noiseConfig.reachNoiseConfig.blindHistogramNoise @@ -756,18 +749,13 @@ class ReachOnlyLiquidLegionsV2Mill( } compositeElGamalPublicKey = rollv2Details.combinedPublicKey serializedExcessiveNoiseCiphertext = - combinedInputBlobs.substring( - // TODO(world-federation-of-advertisers/cross-media-measurement#1194): Fix this to handle - // the case where the set of computation participants is a subset of all Duchies. - combinedInputBlobs.size() - workerStubs.size * kBytesPerCipherText, - combinedInputBlobs.size() - ) + combinedInputBlobs.substring(combinedRegisterVectorSizeBytes, combinedInputBlobs.size()) parallelism = this@ReachOnlyLiquidLegionsV2Mill.parallelism } } /** Reads all input blobs and combines all the bytes together. */ - protected suspend fun readAndCombineAllInputBlobsSetupPhaseAtAggregator( + private suspend fun readAndCombineAllInputBlobsSetupPhaseAtAggregator( token: ComputationToken, count: Int ): ByteString { @@ -780,19 +768,30 @@ class ReachOnlyLiquidLegionsV2Mill( var combinedRegisterVector = ByteString.EMPTY var combinedNoiseCiphertext = ByteString.EMPTY for (str in blobMap.values) { - require(str.size() >= kBytesPerCipherText) { + require(str.size() >= BYTES_PER_CIPHERTEXT) { ("Invalid input blob size. Input blob ${str.toStringUtf8()} has size " + - "${str.size()} which is less than ($kBytesPerCipherText).") + "${str.size()} which is less than ($BYTES_PER_CIPHERTEXT).") } combinedRegisterVector = - combinedRegisterVector.concat(str.substring(0, str.size() - kBytesPerCipherText)) + combinedRegisterVector.concat(str.substring(0, str.size() - BYTES_PER_CIPHERTEXT)) combinedNoiseCiphertext = - combinedNoiseCiphertext.concat(str.substring(str.size() - kBytesPerCipherText, str.size())) + combinedNoiseCiphertext.concat(str.substring(str.size() - BYTES_PER_CIPHERTEXT, str.size())) } return combinedRegisterVector.concat(combinedNoiseCiphertext) } + private val ComputationToken.participantCount: Int + get() = + if (computationDetails.kingdomComputation.participantCount != 0) { + computationDetails.kingdomComputation.participantCount + } else { + // For legacy Computations. See world-federation-of-advertisers/cross-media-measurement#1194 + workerStubs.size + 1 + } + companion object { + private const val BYTES_PER_CIPHERTEXT = 66 + private val logger: Logger = Logger.getLogger(this::class.java.name) } } diff --git a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/utils/ComputationConversions.kt b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/utils/ComputationConversions.kt index ac98ab25e9b..df0adb63916 100644 --- a/src/main/kotlin/org/wfanet/measurement/duchy/daemon/utils/ComputationConversions.kt +++ b/src/main/kotlin/org/wfanet/measurement/duchy/daemon/utils/ComputationConversions.kt @@ -87,6 +87,7 @@ fun SystemComputation.toKingdomComputationDetails(): KingdomComputationDetails { return kingdomComputationDetails { publicApiVersion = source.publicApiVersion measurementSpec = source.measurementSpec + participantCount = source.computationParticipantsCount when (Version.fromString(source.publicApiVersion)) { Version.V2_ALPHA -> { val measurementSpec = MeasurementSpec.parseFrom(source.measurementSpec) diff --git a/src/main/proto/wfa/measurement/internal/duchy/computation_details.proto b/src/main/proto/wfa/measurement/internal/duchy/computation_details.proto index 89397f13c36..4c0f88c378a 100644 --- a/src/main/proto/wfa/measurement/internal/duchy/computation_details.proto +++ b/src/main/proto/wfa/measurement/internal/duchy/computation_details.proto @@ -55,6 +55,11 @@ message ComputationDetails { // Public key for asymmetric encryption. Used when encrypting the final // result. EncryptionPublicKey measurement_public_key = 4; + + // Count of Duchy participants in this Computation. + // + // This may not be set for legacy Computations. + int32 participant_count = 5; } KingdomComputationDetails kingdom_computation = 3; diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt index bce92309db9..a6608c0dfd9 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/herald/HeraldTest.kt @@ -413,7 +413,7 @@ class HeraldTest { buildComputationAtKingdom( "2", Computation.State.PENDING_REQUISITION_PARAMS, - listOf(systemApiRequisitions1, systemApiRequisitions2) + systemApiRequisitions = listOf(systemApiRequisitions1, systemApiRequisitions2) ) mockStreamActiveComputationsToReturn(confirmingKnown, confirmingUnknown) @@ -458,6 +458,7 @@ class HeraldTest { publicApiVersion = PUBLIC_API_VERSION measurementSpec = SERIALIZED_MEASUREMENT_SPEC measurementPublicKey = PUBLIC_API_ENCRYPTION_PUBLIC_KEY.toDuchyEncryptionPublicKey() + participantCount = 3 } liquidLegionsV2 = LiquidLegionsSketchAggregationV2Kt.computationDetails { @@ -560,6 +561,7 @@ class HeraldTest { publicApiVersion = PUBLIC_API_VERSION measurementSpec = SERIALIZED_REACH_ONLY_MEASUREMENT_SPEC measurementPublicKey = PUBLIC_API_ENCRYPTION_PUBLIC_KEY.toDuchyEncryptionPublicKey() + participantCount = 3 } liquidLegionsV2 = LiquidLegionsSketchAggregationV2Kt.computationDetails { @@ -660,6 +662,7 @@ class HeraldTest { publicApiVersion = PUBLIC_API_VERSION measurementSpec = SERIALIZED_REACH_ONLY_MEASUREMENT_SPEC measurementPublicKey = PUBLIC_API_ENCRYPTION_PUBLIC_KEY.toDuchyEncryptionPublicKey() + participantCount = 3 } reachOnlyLiquidLegionsV2 = ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { @@ -1332,29 +1335,44 @@ class HeraldTest { ) } - /** - * Builds a kingdom system Api Computation using default values for fields not included in the - * parameters. - */ - private fun buildComputationAtKingdom( - globalId: String, - stateAtKingdom: Computation.State, - systemApiRequisitions: List = listOf(), - systemComputationParticipant: List = listOf(), - serializedMeasurementSpec: ByteString = SERIALIZED_MEASUREMENT_SPEC, - mpcProtocolConfig: MpcProtocolConfig = LLV2_MPC_PROTOCOL_CONFIG - ): Computation { - return computation { - name = ComputationKey(globalId).toName() - publicApiVersion = PUBLIC_API_VERSION - measurementSpec = serializedMeasurementSpec - state = stateAtKingdom - requisitions += systemApiRequisitions - computationParticipants += systemComputationParticipant - this.mpcProtocolConfig = mpcProtocolConfig + private fun Computation.continuationToken(): String = key.computationId + + companion object { + private val ALL_COMPUTATION_PARTICIPANTS = + listOf( + computationParticipant { + name = ComputationParticipantKey(COMPUTATION_GLOBAL_ID, DUCHY_ONE).toName() + }, + computationParticipant { + name = ComputationParticipantKey(COMPUTATION_GLOBAL_ID, DUCHY_TWO).toName() + }, + computationParticipant { + name = ComputationParticipantKey(COMPUTATION_GLOBAL_ID, DUCHY_THREE).toName() + }, + ) + + /** + * Builds a kingdom system Api Computation using default values for fields not included in the + * parameters. + */ + private fun buildComputationAtKingdom( + globalId: String, + stateAtKingdom: Computation.State, + systemApiRequisitions: List = listOf(), + systemComputationParticipant: List = + ALL_COMPUTATION_PARTICIPANTS, + serializedMeasurementSpec: ByteString = SERIALIZED_MEASUREMENT_SPEC, + mpcProtocolConfig: MpcProtocolConfig = LLV2_MPC_PROTOCOL_CONFIG + ): Computation { + return computation { + name = ComputationKey(globalId).toName() + publicApiVersion = PUBLIC_API_VERSION + measurementSpec = serializedMeasurementSpec + state = stateAtKingdom + requisitions += systemApiRequisitions + computationParticipants += systemComputationParticipant + this.mpcProtocolConfig = mpcProtocolConfig + } } } - - private var continuationTokenTimeSeq = 0L - private fun Computation.continuationToken(): String = key.computationId.toString() } diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt index 15ccc126fbf..539812883e4 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/LiquidLegionsV2MillTest.kt @@ -360,6 +360,7 @@ private val AGGREGATOR_COMPUTATION_DETAILS = computationDetails { publicApiVersion = PUBLIC_API_VERSION measurementPublicKey = ENCRYPTION_PUBLIC_KEY.toDuchyEncryptionPublicKey() measurementSpec = SERIALIZED_MEASUREMENT_SPEC + participantCount = 3 } liquidLegionsV2 = LiquidLegionsSketchAggregationV2Kt.computationDetails { @@ -1297,6 +1298,111 @@ class LiquidLegionsV2MillTest { ) } + @Test + fun `setup phase at aggregator using calculated result with fewer participants`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val computationDetails = + AGGREGATOR_COMPUTATION_DETAILS.copy { + kingdomComputation = + kingdomComputation.copy { + // Indicates that the Kingdom selected only 2 of the 3 Duchies. + participantCount = 2 + } + } + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = SETUP_PHASE.toProtocolStage() + ) + .build() + val requisitionBlobContext = + RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) + requisitionStore.writeString(requisitionBlobContext, "local_requisition") + val inputBlob0Context = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 0L) + computationStore.writeString(inputBlob0Context, "-duchy_two_sketch") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = computationDetails, + blobs = + listOf(newInputBlobMetadata(0L, inputBlob0Context.blobKey), newEmptyOutputBlobMetadata(3L)), + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3) + ) + + var cryptoRequest = CompleteSetupPhaseRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeSetupPhase(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + val postFix = ByteString.copyFromUtf8("-completeSetupPhase_done") + CompleteSetupPhaseResponse.newBuilder() + .apply { combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) } + .build() + } + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + val blobKey = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 3L).blobKey + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + ComputationToken.newBuilder() + .apply { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_ONE_INPUTS.toProtocolStage() + addBlobsBuilder().apply { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0 + path = blobKey + } + addBlobsBuilder().apply { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1 + } + version = 3 // claimTask + writeOutputBlob + transitionStage + this.computationDetails = computationDetails + addAllRequisitions(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + .build() + ) + + assertThat(computationStore.get(blobKey)?.readToString()) + .isEqualTo("local_requisition-duchy_two_sketch-completeSetupPhase_done") + + assertThat(computationControlRequests) + .containsExactlyElementsIn( + buildAdvanceComputationRequests( + GLOBAL_ID, + EXECUTION_PHASE_ONE_INPUT, + "local_requisition-du", + "chy_two_sketch-compl", + "eteSetupPhase_done", + ) + ) + .inOrder() + + assertThat(cryptoRequest) + .isEqualTo( + completeSetupPhaseRequest { + combinedRegisterVector = ByteString.copyFromUtf8("local_requisition-duchy_two_sketch") + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = COMBINED_PUBLIC_KEY + curveId = CURVE_ID + contributorsCount = 2 + totalSketchesCount = REQUISITIONS.size + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = TEST_NOISE_CONFIG.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = TEST_NOISE_CONFIG.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = TEST_NOISE_CONFIG.reachNoiseConfig.globalReachDpNoise + } + } + maximumFrequency = MAX_FREQUENCY + parallelism = PARALLELISM + } + ) + } + @Test fun `execution phase one at non-aggregator using cached result`() = runBlocking { // Stage 0. preparing the storage and set up mock diff --git a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt index 77d197a943e..5e1bccb34d8 100644 --- a/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt +++ b/src/test/kotlin/org/wfanet/measurement/duchy/daemon/mill/liquidlegionsv2/ReachOnlyLiquidLegionsV2MillTest.kt @@ -320,6 +320,7 @@ private val AGGREGATOR_COMPUTATION_DETAILS = computationDetails { publicApiVersion = PUBLIC_API_VERSION measurementPublicKey = ENCRYPTION_PUBLIC_KEY.toDuchyEncryptionPublicKey() measurementSpec = SERIALIZED_MEASUREMENT_SPEC + participantCount = 3 } reachOnlyLiquidLegionsV2 = ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { @@ -339,6 +340,7 @@ private val NON_AGGREGATOR_COMPUTATION_DETAILS = computationDetails { publicApiVersion = PUBLIC_API_VERSION measurementPublicKey = ENCRYPTION_PUBLIC_KEY.toDuchyEncryptionPublicKey() measurementSpec = SERIALIZED_MEASUREMENT_SPEC + participantCount = 3 } reachOnlyLiquidLegionsV2 = ReachOnlyLiquidLegionsSketchAggregationV2Kt.computationDetails { @@ -1410,6 +1412,117 @@ class ReachOnlyLiquidLegionsV2MillTest { ) } + @Test + fun `setup phase at aggregator using calculated result with fewer participants`() = runBlocking { + // Stage 0. preparing the storage and set up mock + val computationDetails = + AGGREGATOR_COMPUTATION_DETAILS.copy { + kingdomComputation = + kingdomComputation.copy { + // Indicates that the Kingdom selected only 2 of the 3 Duchies. + participantCount = 2 + } + } + val partialToken = + FakeComputationsDatabase.newPartialToken( + localId = LOCAL_ID, + stage = SETUP_PHASE.toProtocolStage() + ) + .build() + val requisitionBlobContext = + RequisitionBlobContext(GLOBAL_ID, REQUISITION_1.externalKey.externalRequisitionId) + requisitionStore.writeString(requisitionBlobContext, "local_requisition") + val inputBlob0Context = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 0L) + computationStore.writeString(inputBlob0Context, "-duchy_2_sketch$NOISE_CIPHERTEXT") + fakeComputationDb.addComputation( + partialToken.localComputationId, + partialToken.computationStage, + computationDetails = computationDetails, + blobs = + listOf(newInputBlobMetadata(0L, inputBlob0Context.blobKey), newEmptyOutputBlobMetadata(3L)), + requisitions = listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3) + ) + + var cryptoRequest = CompleteReachOnlySetupPhaseRequest.getDefaultInstance() + whenever(mockCryptoWorker.completeReachOnlySetupPhaseAtAggregator(any())).thenAnswer { + cryptoRequest = it.getArgument(0) + val postFix = ByteString.copyFromUtf8("-completeReachOnlySetupPhase") + completeReachOnlySetupPhaseResponse { + combinedRegisterVector = cryptoRequest.combinedRegisterVector.concat(postFix) + serializedExcessiveNoiseCiphertext = ByteString.copyFromUtf8("-encryptedNoise") + } + } + + // Stage 1. Process the above computation + aggregatorMill.pollAndProcessNextComputation() + + // Stage 2. Check the status of the computation + val blobKey = ComputationBlobContext(GLOBAL_ID, SETUP_PHASE.toProtocolStage(), 3L).blobKey + assertThat(fakeComputationDb[LOCAL_ID]) + .isEqualTo( + computationToken { + globalComputationId = GLOBAL_ID + localComputationId = LOCAL_ID + attempt = 1 + computationStage = WAIT_EXECUTION_PHASE_INPUTS.toProtocolStage() + blobs.addAll( + listOf( + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.INPUT + blobId = 0 + path = blobKey + }, + computationStageBlobMetadata { + dependencyType = ComputationBlobDependency.OUTPUT + blobId = 1 + } + ) + ) + version = 3 // claimTask + writeOutputBlob + transitionStage + this.computationDetails = computationDetails + requisitions.addAll(listOf(REQUISITION_1, REQUISITION_2, REQUISITION_3)) + } + ) + + assertThat(computationStore.get(blobKey)?.readToString()) + .isEqualTo("local_requisition-duchy_2_sketch-completeReachOnlySetupPhase-encryptedNoise") + + assertThat(computationControlRequests) + .containsExactlyElementsIn( + buildAdvanceComputationRequests( + GLOBAL_ID, + EXECUTION_PHASE_INPUT, + "local_requisition-du", + "chy_2_sketch-complet", + "eReachOnlySetupPhase", + "-encryptedNoise", + ) + ) + .inOrder() + + assertThat(cryptoRequest) + .isEqualTo( + completeReachOnlySetupPhaseRequest { + combinedRegisterVector = ByteString.copyFromUtf8("local_requisition-duchy_2_sketch") + curveId = CURVE_ID + noiseParameters = registerNoiseGenerationParameters { + compositeElGamalPublicKey = COMBINED_PUBLIC_KEY + curveId = CURVE_ID + contributorsCount = 2 + totalSketchesCount = REQUISITIONS.size + dpParams = reachNoiseDifferentialPrivacyParams { + blindHistogram = TEST_NOISE_CONFIG.reachNoiseConfig.blindHistogramNoise + noiseForPublisherNoise = TEST_NOISE_CONFIG.reachNoiseConfig.noiseForPublisherNoise + globalReachDpNoise = TEST_NOISE_CONFIG.reachNoiseConfig.globalReachDpNoise + } + } + compositeElGamalPublicKey = COMBINED_PUBLIC_KEY + serializedExcessiveNoiseCiphertext = SERIALIZED_NOISE_CIPHERTEXT + parallelism = PARALLELISM + } + ) + } + @Test fun `setup phase at aggregator, failed due to invalid input blob size`() = runBlocking { // Stage 0. preparing the storage and set up mock