Skip to content

Commit

Permalink
Update Kingdom measurement creation to support HMSS protocol. (#1404)
Browse files Browse the repository at this point in the history
  • Loading branch information
renjiezh authored Jan 11, 2024
1 parent 1b0f32a commit 58036da
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,16 @@ kt_jvm_library(
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc",
],
)

kt_jvm_library(
name = "hmss_protocol_config",
srcs = ["HmssProtocolConfig.kt"],
deps = [
"//src/main/proto/wfa/measurement/internal/kingdom:protocol_config_config_kt_jvm_proto",
"@wfa_common_jvm//imports/java/io/grpc:api",
"@wfa_common_jvm//imports/java/io/grpc/stub",
"@wfa_common_jvm//imports/java/picocli",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common",
"@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright 2024 The Cross-Media Measurement Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package org.wfanet.measurement.kingdom.deploy.common

import java.io.File
import org.wfanet.measurement.common.parseTextProto
import org.wfanet.measurement.internal.kingdom.HmssProtocolConfigConfig
import org.wfanet.measurement.internal.kingdom.ProtocolConfig
import picocli.CommandLine

object HmssProtocolConfig {
const val name = "hmss"
lateinit var protocolConfig: ProtocolConfig.HonestMajorityShareShuffle
private set

lateinit var requiredExternalDuchyIds: Set<String>
private set

fun initializeFromFlags(flags: HmssProtocolConfigFlags) {
require(!HmssProtocolConfig::protocolConfig.isInitialized)
require(!HmssProtocolConfig::requiredExternalDuchyIds.isInitialized)
val configMessage =
flags.config.reader().use {
parseTextProto(it, HmssProtocolConfigConfig.getDefaultInstance())
}

protocolConfig = configMessage.protocolConfig
requiredExternalDuchyIds = configMessage.requiredExternalDuchyIdsList.toSet()
}

fun setForTest(
protocolConfig: ProtocolConfig.HonestMajorityShareShuffle,
requiredExternalDuchyIds: Set<String>,
) {
require(!HmssProtocolConfig::protocolConfig.isInitialized)

HmssProtocolConfig.protocolConfig = protocolConfig
HmssProtocolConfig.requiredExternalDuchyIds = requiredExternalDuchyIds
}
}

class HmssProtocolConfigFlags {
@CommandLine.Option(
names = ["--hmss-protocol-config-config"],
description = ["HmssProtocolConfigConfig proto message in text format."],
required = true
)
lateinit var config: File
private set
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ kt_jvm_library(
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:duchy_ids",
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:llv2_protocol_config",
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:ro_llv2_protocol_config",
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:hmss_protocol_config",
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/common",
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries",
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import org.wfanet.measurement.internal.kingdom.Requisition
import org.wfanet.measurement.internal.kingdom.RequisitionKt
import org.wfanet.measurement.internal.kingdom.copy
import org.wfanet.measurement.kingdom.deploy.common.DuchyIds
import org.wfanet.measurement.kingdom.deploy.common.HmssProtocolConfig
import org.wfanet.measurement.kingdom.deploy.common.Llv2ProtocolConfig
import org.wfanet.measurement.kingdom.deploy.common.RoLlv2ProtocolConfig
import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.CertificateIsInvalidException
Expand Down Expand Up @@ -82,7 +83,8 @@ class CreateMeasurements(private val requests: List<CreateMeasurementRequest>) :
?: @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields are never null.
when (it.measurement.details.protocolConfig.protocolCase) {
ProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2,
ProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> {
ProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2,
ProtocolConfig.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> {
createComputedMeasurement(it, measurementConsumerId)
}
ProtocolConfig.ProtocolCase.DIRECT -> createDirectMeasurement(it, measurementConsumerId)
Expand All @@ -98,12 +100,16 @@ class CreateMeasurements(private val requests: List<CreateMeasurementRequest>) :
val initialMeasurementState = Measurement.State.PENDING_REQUISITION_PARAMS

val requiredExternalDuchyIds =
if (
createMeasurementRequest.measurement.details.protocolConfig.protocolCase ==
ProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2
)
Llv2ProtocolConfig.requiredExternalDuchyIds
else RoLlv2ProtocolConfig.requiredExternalDuchyIds
@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields are never null.
when (createMeasurementRequest.measurement.details.protocolConfig.protocolCase) {
ProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2 -> Llv2ProtocolConfig.requiredExternalDuchyIds
ProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 ->
RoLlv2ProtocolConfig.requiredExternalDuchyIds
ProtocolConfig.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE ->
HmssProtocolConfig.requiredExternalDuchyIds
ProtocolConfig.ProtocolCase.DIRECT,
ProtocolConfig.ProtocolCase.PROTOCOL_NOT_SET -> error("Invalid protocol.")
}
val requiredDuchyIds =
requiredExternalDuchyIds +
readDataProviderRequiredDuchies(
Expand All @@ -117,12 +123,17 @@ class CreateMeasurements(private val requests: List<CreateMeasurementRequest>) :
}
}
val minimumNumberOfRequiredDuchies =
if (
createMeasurementRequest.measurement.details.protocolConfig.protocolCase ==
ProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2
)
Llv2ProtocolConfig.minimumNumberOfRequiredDuchies
else RoLlv2ProtocolConfig.minimumNumberOfRequiredDuchies
@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields are never null.
when (createMeasurementRequest.measurement.details.protocolConfig.protocolCase) {
ProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2 ->
Llv2ProtocolConfig.minimumNumberOfRequiredDuchies
ProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 ->
RoLlv2ProtocolConfig.minimumNumberOfRequiredDuchies
ProtocolConfig.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE ->
HmssProtocolConfig.requiredExternalDuchyIds.size
ProtocolConfig.ProtocolCase.DIRECT,
ProtocolConfig.ProtocolCase.PROTOCOL_NOT_SET -> error("Invalid protocol.")
}

val includedDuchyEntries =
if (requiredDuchyEntries.size < minimumNumberOfRequiredDuchies) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ import org.wfanet.measurement.api.v2alpha.PopulationKt.populationBlob
import org.wfanet.measurement.api.v2alpha.ProtocolConfig
import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism
import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.direct
import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.honestMajorityShareShuffle
import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.liquidLegionsV2
import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.protocol
import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.reachOnlyLiquidLegionsV2
Expand All @@ -89,6 +90,7 @@ import org.wfanet.measurement.api.v2alpha.population
import org.wfanet.measurement.api.v2alpha.protocolConfig
import org.wfanet.measurement.api.v2alpha.reachOnlyLiquidLegionsSketchParams
import org.wfanet.measurement.api.v2alpha.setMessage
import org.wfanet.measurement.api.v2alpha.shareShuffleSketchParams
import org.wfanet.measurement.api.v2alpha.signedMessage
import org.wfanet.measurement.api.v2alpha.unpack
import org.wfanet.measurement.common.ProtoReflection
Expand Down Expand Up @@ -452,6 +454,19 @@ private fun buildMpcProtocolConfig(
}
}
}
InternalProtocolConfig.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> {
protocol {
honestMajorityShareShuffle = honestMajorityShareShuffle {
sketchParams = shareShuffleSketchParams {
registerCount = protocolConfig.honestMajorityShareShuffle.sketchParams.registerCount
bytesPerRegister =
protocolConfig.honestMajorityShareShuffle.sketchParams.bytesPerRegister
}
noiseMechanism =
protocolConfig.honestMajorityShareShuffle.noiseMechanism.toNoiseMechanism()
}
}
}
InternalProtocolConfig.ProtocolCase.DIRECT -> {
error("Direct protocol cannot be used for MPC-based Measurements")
}
Expand Down Expand Up @@ -977,6 +992,8 @@ fun Map.Entry<Long, DataProviderValue>.toDataProviderEntry(apiVersion: Version):
* Converts a public [Measurement] to an internal [InternalMeasurement] for creation.
*
* @throws [IllegalStateException] if MeasurementType not specified
*
* TODO(@renjie): Enable HMSS protocol based on feature flag.
*/
fun Measurement.toInternal(
measurementConsumerCertificateKey: MeasurementConsumerCertificateKey,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ kt_jvm_library(
deps = KINGDOM_INTERNAL_PROTOS + [
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:llv2_protocol_config",
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:ro_llv2_protocol_config",
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:hmss_protocol_config",
"//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/testing",
"@wfa_common_jvm//imports/java/com/google/common/truth",
"@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ import org.wfanet.measurement.internal.kingdom.setMeasurementResultRequest
import org.wfanet.measurement.internal.kingdom.streamMeasurementsRequest
import org.wfanet.measurement.internal.kingdom.streamRequisitionsRequest
import org.wfanet.measurement.kingdom.deploy.common.DuchyIds
import org.wfanet.measurement.kingdom.deploy.common.HmssProtocolConfig
import org.wfanet.measurement.kingdom.deploy.common.Llv2ProtocolConfig
import org.wfanet.measurement.kingdom.deploy.common.RoLlv2ProtocolConfig
import org.wfanet.measurement.kingdom.deploy.common.testing.DuchyIdSetter
Expand Down Expand Up @@ -121,6 +122,20 @@ private val REACH_ONLY_MEASUREMENT =
}
}

private val HMSS_MEASUREMENT = measurement {
providedMeasurementId = PROVIDED_MEASUREMENT_ID
details =
MeasurementKt.details {
apiVersion = API_VERSION
measurementSpec = ByteString.copyFromUtf8("MeasurementSpec")
measurementSpecSignature = ByteString.copyFromUtf8("MeasurementSpec signature")
measurementSpecSignatureAlgorithmOid = "2.9999"
protocolConfig = protocolConfig {
honestMajorityShareShuffle = ProtocolConfig.HonestMajorityShareShuffle.getDefaultInstance()
}
}
}

private val INVALID_WORKER_DUCHY =
DuchyIds.Entry(4, "worker3", Instant.now().minusSeconds(100L)..Instant.now().minusSeconds(50L))

Expand Down Expand Up @@ -489,6 +504,39 @@ abstract class MeasurementsServiceTest<T : MeasurementsCoroutineImplBase> {
.isEqualTo(measurement.copy { state = Measurement.State.PENDING_REQUISITION_PARAMS })
}

@Test
fun `createMeasurement for duchy HMSS measurement succeeds`() = runBlocking {
val measurementConsumer =
population.createMeasurementConsumer(measurementConsumersService, accountsService)
val dataProvider = population.createDataProvider(dataProvidersService)

val measurement =
HMSS_MEASUREMENT.copy {
externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId
externalMeasurementConsumerCertificateId =
measurementConsumer.certificate.externalCertificateId
dataProviders[dataProvider.externalDataProviderId] = dataProvider.toDataProviderValue()
}

val createdMeasurement =
measurementsService.createMeasurement(
createMeasurementRequest { this.measurement = measurement }
)
assertThat(createdMeasurement.externalMeasurementId).isNotEqualTo(0L)
assertThat(createdMeasurement.externalComputationId).isNotEqualTo(0L)
assertThat(createdMeasurement.createTime.seconds).isGreaterThan(0L)
assertThat(createdMeasurement.updateTime).isEqualTo(createdMeasurement.createTime)
assertThat(createdMeasurement)
.ignoringFields(
Measurement.EXTERNAL_MEASUREMENT_ID_FIELD_NUMBER,
Measurement.EXTERNAL_COMPUTATION_ID_FIELD_NUMBER,
Measurement.CREATE_TIME_FIELD_NUMBER,
Measurement.UPDATE_TIME_FIELD_NUMBER,
Measurement.ETAG_FIELD_NUMBER,
)
.isEqualTo(measurement.copy { state = Measurement.State.PENDING_REQUISITION_PARAMS })
}

@Test
fun `createMeasurement for duchy measurement contains required duchies and the aggregator`() =
runBlocking {
Expand Down Expand Up @@ -2527,6 +2575,14 @@ abstract class MeasurementsServiceTest<T : MeasurementsCoroutineImplBase> {
),
2
)
HmssProtocolConfig.setForTest(
ProtocolConfig.HonestMajorityShareShuffle.getDefaultInstance(),
setOf(
Population.AGGREGATOR_DUCHY.externalDuchyId,
Population.WORKER1_DUCHY.externalDuchyId,
Population.WORKER2_DUCHY.externalDuchyId,
)
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,15 @@ message ProtocolConfig {
NoiseMechanism noise_mechanism = 5;
}

// Configuration for Honest Majority Share Shuffle protocols.
message HonestMajorityShareShuffle {
// Parameters for sketches.
ShareShuffleSketchParams sketch_params = 1;

// The mechanism to generate noise by MPC nodes during computation.
NoiseMechanism noise_mechanism = 2;
}

// Configuration for the specific protocol.
oneof protocol {
// Liquid Legions v2 config.
Expand All @@ -145,6 +154,9 @@ message ProtocolConfig {

// Direct protocol.
Direct direct = 5;

// Honest Majority Share Shuffle protocol.
HonestMajorityShareShuffle honest_majority_share_shuffle = 6;
}
}

Expand All @@ -160,3 +172,15 @@ message LiquidLegionsSketchParams {
// Reach-Only Liquid Legions protocol ignores this field.
int64 sampling_indicator_size = 3;
}

// Parameters for a honest majority share shuffle sketch.
message ShareShuffleSketchParams {
// The number of registers in the sketch.
int64 register_count = 1;

// Length of each register in bytes.
int32 bytes_per_register = 2;

// Secret share modulus.
uint32 ring_modulus = 3;
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,12 @@ message Llv2ProtocolConfigConfig {
repeated string required_external_duchy_ids = 3;
int32 minimum_duchy_participant_count = 4;
}

// The kingdom local config proto used to initialize the `ProtocolConfig` for
// the Honest Majority Share Shuffle protocol.
message HmssProtocolConfigConfig {
// The HonestMajorityShareShuffle protocol Config.
ProtocolConfig.HonestMajorityShareShuffle protocol_config = 1;
// List of required duchies for this protocol.
repeated string required_external_duchy_ids = 2;
}

0 comments on commit 58036da

Please sign in to comment.