From 58036da848d7e53b0719dda5c84b542894f35d05 Mon Sep 17 00:00:00 2001 From: renjiezh <94721804+renjiezh@users.noreply.github.com> Date: Thu, 11 Jan 2024 13:59:48 -0800 Subject: [PATCH] Update Kingdom measurement creation to support HMSS protocol. (#1404) --- .../kingdom/deploy/common/BUILD.bazel | 13 ++++ .../deploy/common/HmssProtocolConfig.kt | 62 +++++++++++++++++++ .../deploy/gcloud/spanner/writers/BUILD.bazel | 1 + .../spanner/writers/CreateMeasurements.kt | 37 +++++++---- .../service/api/v2alpha/ProtoConversions.kt | 17 +++++ .../service/internal/testing/BUILD.bazel | 1 + .../testing/MeasurementsServiceTest.kt | 56 +++++++++++++++++ .../internal/kingdom/protocol_config.proto | 24 +++++++ .../kingdom/protocol_config_config.proto | 9 +++ 9 files changed, 207 insertions(+), 13 deletions(-) create mode 100644 src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/HmssProtocolConfig.kt diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/BUILD.bazel index 00070f4ca29..ddf0f2a6208 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/BUILD.bazel @@ -51,3 +51,16 @@ kt_jvm_library( "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", ], ) + +kt_jvm_library( + name = "hmss_protocol_config", + srcs = ["HmssProtocolConfig.kt"], + deps = [ + "//src/main/proto/wfa/measurement/internal/kingdom:protocol_config_config_kt_jvm_proto", + "@wfa_common_jvm//imports/java/io/grpc:api", + "@wfa_common_jvm//imports/java/io/grpc/stub", + "@wfa_common_jvm//imports/java/picocli", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common", + "@wfa_common_jvm//src/main/kotlin/org/wfanet/measurement/common/grpc", + ], +) diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/HmssProtocolConfig.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/HmssProtocolConfig.kt new file mode 100644 index 00000000000..5e1c46c342b --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/HmssProtocolConfig.kt @@ -0,0 +1,62 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.kingdom.deploy.common + +import java.io.File +import org.wfanet.measurement.common.parseTextProto +import org.wfanet.measurement.internal.kingdom.HmssProtocolConfigConfig +import org.wfanet.measurement.internal.kingdom.ProtocolConfig +import picocli.CommandLine + +object HmssProtocolConfig { + const val name = "hmss" + lateinit var protocolConfig: ProtocolConfig.HonestMajorityShareShuffle + private set + + lateinit var requiredExternalDuchyIds: Set + private set + + fun initializeFromFlags(flags: HmssProtocolConfigFlags) { + require(!HmssProtocolConfig::protocolConfig.isInitialized) + require(!HmssProtocolConfig::requiredExternalDuchyIds.isInitialized) + val configMessage = + flags.config.reader().use { + parseTextProto(it, HmssProtocolConfigConfig.getDefaultInstance()) + } + + protocolConfig = configMessage.protocolConfig + requiredExternalDuchyIds = configMessage.requiredExternalDuchyIdsList.toSet() + } + + fun setForTest( + protocolConfig: ProtocolConfig.HonestMajorityShareShuffle, + requiredExternalDuchyIds: Set, + ) { + require(!HmssProtocolConfig::protocolConfig.isInitialized) + + HmssProtocolConfig.protocolConfig = protocolConfig + HmssProtocolConfig.requiredExternalDuchyIds = requiredExternalDuchyIds + } +} + +class HmssProtocolConfigFlags { + @CommandLine.Option( + names = ["--hmss-protocol-config-config"], + description = ["HmssProtocolConfigConfig proto message in text format."], + required = true + ) + lateinit var config: File + private set +} diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/BUILD.bazel index de32e165843..7cc6caf1af0 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/BUILD.bazel @@ -14,6 +14,7 @@ kt_jvm_library( "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:duchy_ids", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:llv2_protocol_config", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:ro_llv2_protocol_config", + "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:hmss_protocol_config", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/common", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/queries", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/readers", diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateMeasurements.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateMeasurements.kt index fce856f60db..4cf8d401389 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateMeasurements.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/deploy/gcloud/spanner/writers/CreateMeasurements.kt @@ -39,6 +39,7 @@ import org.wfanet.measurement.internal.kingdom.Requisition import org.wfanet.measurement.internal.kingdom.RequisitionKt import org.wfanet.measurement.internal.kingdom.copy import org.wfanet.measurement.kingdom.deploy.common.DuchyIds +import org.wfanet.measurement.kingdom.deploy.common.HmssProtocolConfig import org.wfanet.measurement.kingdom.deploy.common.Llv2ProtocolConfig import org.wfanet.measurement.kingdom.deploy.common.RoLlv2ProtocolConfig import org.wfanet.measurement.kingdom.deploy.gcloud.spanner.common.CertificateIsInvalidException @@ -82,7 +83,8 @@ class CreateMeasurements(private val requests: List) : ?: @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields are never null. when (it.measurement.details.protocolConfig.protocolCase) { ProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2, - ProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> { + ProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2, + ProtocolConfig.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> { createComputedMeasurement(it, measurementConsumerId) } ProtocolConfig.ProtocolCase.DIRECT -> createDirectMeasurement(it, measurementConsumerId) @@ -98,12 +100,16 @@ class CreateMeasurements(private val requests: List) : val initialMeasurementState = Measurement.State.PENDING_REQUISITION_PARAMS val requiredExternalDuchyIds = - if ( - createMeasurementRequest.measurement.details.protocolConfig.protocolCase == - ProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2 - ) - Llv2ProtocolConfig.requiredExternalDuchyIds - else RoLlv2ProtocolConfig.requiredExternalDuchyIds + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields are never null. + when (createMeasurementRequest.measurement.details.protocolConfig.protocolCase) { + ProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2 -> Llv2ProtocolConfig.requiredExternalDuchyIds + ProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> + RoLlv2ProtocolConfig.requiredExternalDuchyIds + ProtocolConfig.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> + HmssProtocolConfig.requiredExternalDuchyIds + ProtocolConfig.ProtocolCase.DIRECT, + ProtocolConfig.ProtocolCase.PROTOCOL_NOT_SET -> error("Invalid protocol.") + } val requiredDuchyIds = requiredExternalDuchyIds + readDataProviderRequiredDuchies( @@ -117,12 +123,17 @@ class CreateMeasurements(private val requests: List) : } } val minimumNumberOfRequiredDuchies = - if ( - createMeasurementRequest.measurement.details.protocolConfig.protocolCase == - ProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2 - ) - Llv2ProtocolConfig.minimumNumberOfRequiredDuchies - else RoLlv2ProtocolConfig.minimumNumberOfRequiredDuchies + @Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields are never null. + when (createMeasurementRequest.measurement.details.protocolConfig.protocolCase) { + ProtocolConfig.ProtocolCase.LIQUID_LEGIONS_V2 -> + Llv2ProtocolConfig.minimumNumberOfRequiredDuchies + ProtocolConfig.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> + RoLlv2ProtocolConfig.minimumNumberOfRequiredDuchies + ProtocolConfig.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> + HmssProtocolConfig.requiredExternalDuchyIds.size + ProtocolConfig.ProtocolCase.DIRECT, + ProtocolConfig.ProtocolCase.PROTOCOL_NOT_SET -> error("Invalid protocol.") + } val includedDuchyEntries = if (requiredDuchyEntries.size < minimumNumberOfRequiredDuchies) { diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt index 2c9abbcf634..6ff4bcd7b86 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/api/v2alpha/ProtoConversions.kt @@ -66,6 +66,7 @@ import org.wfanet.measurement.api.v2alpha.PopulationKt.populationBlob import org.wfanet.measurement.api.v2alpha.ProtocolConfig import org.wfanet.measurement.api.v2alpha.ProtocolConfig.NoiseMechanism import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.direct +import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.honestMajorityShareShuffle import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.liquidLegionsV2 import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.protocol import org.wfanet.measurement.api.v2alpha.ProtocolConfigKt.reachOnlyLiquidLegionsV2 @@ -89,6 +90,7 @@ import org.wfanet.measurement.api.v2alpha.population import org.wfanet.measurement.api.v2alpha.protocolConfig import org.wfanet.measurement.api.v2alpha.reachOnlyLiquidLegionsSketchParams import org.wfanet.measurement.api.v2alpha.setMessage +import org.wfanet.measurement.api.v2alpha.shareShuffleSketchParams import org.wfanet.measurement.api.v2alpha.signedMessage import org.wfanet.measurement.api.v2alpha.unpack import org.wfanet.measurement.common.ProtoReflection @@ -452,6 +454,19 @@ private fun buildMpcProtocolConfig( } } } + InternalProtocolConfig.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> { + protocol { + honestMajorityShareShuffle = honestMajorityShareShuffle { + sketchParams = shareShuffleSketchParams { + registerCount = protocolConfig.honestMajorityShareShuffle.sketchParams.registerCount + bytesPerRegister = + protocolConfig.honestMajorityShareShuffle.sketchParams.bytesPerRegister + } + noiseMechanism = + protocolConfig.honestMajorityShareShuffle.noiseMechanism.toNoiseMechanism() + } + } + } InternalProtocolConfig.ProtocolCase.DIRECT -> { error("Direct protocol cannot be used for MPC-based Measurements") } @@ -977,6 +992,8 @@ fun Map.Entry.toDataProviderEntry(apiVersion: Version): * Converts a public [Measurement] to an internal [InternalMeasurement] for creation. * * @throws [IllegalStateException] if MeasurementType not specified + * + * TODO(@renjie): Enable HMSS protocol based on feature flag. */ fun Measurement.toInternal( measurementConsumerCertificateKey: MeasurementConsumerCertificateKey, diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/BUILD.bazel index b07d9e09280..d0fb453d839 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/BUILD.bazel +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/BUILD.bazel @@ -10,6 +10,7 @@ kt_jvm_library( deps = KINGDOM_INTERNAL_PROTOS + [ "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:llv2_protocol_config", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:ro_llv2_protocol_config", + "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common:hmss_protocol_config", "//src/main/kotlin/org/wfanet/measurement/kingdom/deploy/common/testing", "@wfa_common_jvm//imports/java/com/google/common/truth", "@wfa_common_jvm//imports/java/com/google/common/truth/extensions/proto", diff --git a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt index da75bd69ef9..15e7cc2fe43 100644 --- a/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt +++ b/src/main/kotlin/org/wfanet/measurement/kingdom/service/internal/testing/MeasurementsServiceTest.kt @@ -80,6 +80,7 @@ import org.wfanet.measurement.internal.kingdom.setMeasurementResultRequest import org.wfanet.measurement.internal.kingdom.streamMeasurementsRequest import org.wfanet.measurement.internal.kingdom.streamRequisitionsRequest import org.wfanet.measurement.kingdom.deploy.common.DuchyIds +import org.wfanet.measurement.kingdom.deploy.common.HmssProtocolConfig import org.wfanet.measurement.kingdom.deploy.common.Llv2ProtocolConfig import org.wfanet.measurement.kingdom.deploy.common.RoLlv2ProtocolConfig import org.wfanet.measurement.kingdom.deploy.common.testing.DuchyIdSetter @@ -121,6 +122,20 @@ private val REACH_ONLY_MEASUREMENT = } } +private val HMSS_MEASUREMENT = measurement { + providedMeasurementId = PROVIDED_MEASUREMENT_ID + details = + MeasurementKt.details { + apiVersion = API_VERSION + measurementSpec = ByteString.copyFromUtf8("MeasurementSpec") + measurementSpecSignature = ByteString.copyFromUtf8("MeasurementSpec signature") + measurementSpecSignatureAlgorithmOid = "2.9999" + protocolConfig = protocolConfig { + honestMajorityShareShuffle = ProtocolConfig.HonestMajorityShareShuffle.getDefaultInstance() + } + } +} + private val INVALID_WORKER_DUCHY = DuchyIds.Entry(4, "worker3", Instant.now().minusSeconds(100L)..Instant.now().minusSeconds(50L)) @@ -489,6 +504,39 @@ abstract class MeasurementsServiceTest { .isEqualTo(measurement.copy { state = Measurement.State.PENDING_REQUISITION_PARAMS }) } + @Test + fun `createMeasurement for duchy HMSS measurement succeeds`() = runBlocking { + val measurementConsumer = + population.createMeasurementConsumer(measurementConsumersService, accountsService) + val dataProvider = population.createDataProvider(dataProvidersService) + + val measurement = + HMSS_MEASUREMENT.copy { + externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId + externalMeasurementConsumerCertificateId = + measurementConsumer.certificate.externalCertificateId + dataProviders[dataProvider.externalDataProviderId] = dataProvider.toDataProviderValue() + } + + val createdMeasurement = + measurementsService.createMeasurement( + createMeasurementRequest { this.measurement = measurement } + ) + assertThat(createdMeasurement.externalMeasurementId).isNotEqualTo(0L) + assertThat(createdMeasurement.externalComputationId).isNotEqualTo(0L) + assertThat(createdMeasurement.createTime.seconds).isGreaterThan(0L) + assertThat(createdMeasurement.updateTime).isEqualTo(createdMeasurement.createTime) + assertThat(createdMeasurement) + .ignoringFields( + Measurement.EXTERNAL_MEASUREMENT_ID_FIELD_NUMBER, + Measurement.EXTERNAL_COMPUTATION_ID_FIELD_NUMBER, + Measurement.CREATE_TIME_FIELD_NUMBER, + Measurement.UPDATE_TIME_FIELD_NUMBER, + Measurement.ETAG_FIELD_NUMBER, + ) + .isEqualTo(measurement.copy { state = Measurement.State.PENDING_REQUISITION_PARAMS }) + } + @Test fun `createMeasurement for duchy measurement contains required duchies and the aggregator`() = runBlocking { @@ -2527,6 +2575,14 @@ abstract class MeasurementsServiceTest { ), 2 ) + HmssProtocolConfig.setForTest( + ProtocolConfig.HonestMajorityShareShuffle.getDefaultInstance(), + setOf( + Population.AGGREGATOR_DUCHY.externalDuchyId, + Population.WORKER1_DUCHY.externalDuchyId, + Population.WORKER2_DUCHY.externalDuchyId, + ) + ) } } } diff --git a/src/main/proto/wfa/measurement/internal/kingdom/protocol_config.proto b/src/main/proto/wfa/measurement/internal/kingdom/protocol_config.proto index 8a6fbdda345..2721b345fd2 100644 --- a/src/main/proto/wfa/measurement/internal/kingdom/protocol_config.proto +++ b/src/main/proto/wfa/measurement/internal/kingdom/protocol_config.proto @@ -131,6 +131,15 @@ message ProtocolConfig { NoiseMechanism noise_mechanism = 5; } + // Configuration for Honest Majority Share Shuffle protocols. + message HonestMajorityShareShuffle { + // Parameters for sketches. + ShareShuffleSketchParams sketch_params = 1; + + // The mechanism to generate noise by MPC nodes during computation. + NoiseMechanism noise_mechanism = 2; + } + // Configuration for the specific protocol. oneof protocol { // Liquid Legions v2 config. @@ -145,6 +154,9 @@ message ProtocolConfig { // Direct protocol. Direct direct = 5; + + // Honest Majority Share Shuffle protocol. + HonestMajorityShareShuffle honest_majority_share_shuffle = 6; } } @@ -160,3 +172,15 @@ message LiquidLegionsSketchParams { // Reach-Only Liquid Legions protocol ignores this field. int64 sampling_indicator_size = 3; } + +// Parameters for a honest majority share shuffle sketch. +message ShareShuffleSketchParams { + // The number of registers in the sketch. + int64 register_count = 1; + + // Length of each register in bytes. + int32 bytes_per_register = 2; + + // Secret share modulus. + uint32 ring_modulus = 3; +} diff --git a/src/main/proto/wfa/measurement/internal/kingdom/protocol_config_config.proto b/src/main/proto/wfa/measurement/internal/kingdom/protocol_config_config.proto index 6e4c3470842..b911f0d0afa 100644 --- a/src/main/proto/wfa/measurement/internal/kingdom/protocol_config_config.proto +++ b/src/main/proto/wfa/measurement/internal/kingdom/protocol_config_config.proto @@ -31,3 +31,12 @@ message Llv2ProtocolConfigConfig { repeated string required_external_duchy_ids = 3; int32 minimum_duchy_participant_count = 4; } + +// The kingdom local config proto used to initialize the `ProtocolConfig` for +// the Honest Majority Share Shuffle protocol. +message HmssProtocolConfigConfig { + // The HonestMajorityShareShuffle protocol Config. + ProtocolConfig.HonestMajorityShareShuffle protocol_config = 1; + // List of required duchies for this protocol. + repeated string required_external_duchy_ids = 2; +}