Skip to content

Commit

Permalink
Update SetParticipantParams for HmSS
Browse files Browse the repository at this point in the history
  • Loading branch information
renjiezh committed Feb 13, 2024
1 parent f4fcb45 commit e365562
Show file tree
Hide file tree
Showing 15 changed files with 379 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ class RequisitionReader : BaseSpannerReader<RequisitionReader.Result>() {
ComputationParticipant.Details.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> {
reachOnlyLiquidLegionsV2 = participantDetails.reachOnlyLiquidLegionsV2
}
ComputationParticipant.Details.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> {
honestMajorityShareShuffle = participantDetails.honestMajorityShareShuffle
}
// Protocol may only be set after computation participant sets requisition params.
ComputationParticipant.Details.ProtocolCase.PROTOCOL_NOT_SET -> Unit
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,20 @@ class SetParticipantRequisitionParams(private val request: SetParticipantRequisi

val participantDetails =
computationParticipant.details.copy {
if (request.hasLiquidLegionsV2()) {
liquidLegionsV2 = request.liquidLegionsV2
} else if (request.hasReachOnlyLiquidLegionsV2()) {
reachOnlyLiquidLegionsV2 = request.reachOnlyLiquidLegionsV2
@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields cannot be null.
when (request.protocolCase) {
SetParticipantRequisitionParamsRequest.ProtocolCase.LIQUID_LEGIONS_V2 -> {
liquidLegionsV2 = request.liquidLegionsV2
}
SetParticipantRequisitionParamsRequest.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> {
reachOnlyLiquidLegionsV2 = request.reachOnlyLiquidLegionsV2
}
SetParticipantRequisitionParamsRequest.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> {
honestMajorityShareShuffle = request.honestMajorityShareShuffle
}
SetParticipantRequisitionParamsRequest.ProtocolCase.PROTOCOL_NOT_SET -> {
error("Unspecified protocol case in SetParticipantRequisitionParamsRequest.")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.wfanet.measurement.api.v2alpha.DataProviderCertificateKey
import org.wfanet.measurement.api.v2alpha.DataProviderKey
import org.wfanet.measurement.api.v2alpha.DuchyCertificateKey
import org.wfanet.measurement.api.v2alpha.ElGamalPublicKey
import org.wfanet.measurement.api.v2alpha.EncryptedMessage
import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey
import org.wfanet.measurement.api.v2alpha.FulfillDirectRequisitionRequest
import org.wfanet.measurement.api.v2alpha.FulfillDirectRequisitionResponse
Expand All @@ -44,7 +45,7 @@ import org.wfanet.measurement.api.v2alpha.Requisition
import org.wfanet.measurement.api.v2alpha.Requisition.DuchyEntry
import org.wfanet.measurement.api.v2alpha.Requisition.Refusal
import org.wfanet.measurement.api.v2alpha.Requisition.State
import org.wfanet.measurement.api.v2alpha.RequisitionKt.DuchyEntryKt.liquidLegionsV2
import org.wfanet.measurement.api.v2alpha.RequisitionKt.DuchyEntryKt
import org.wfanet.measurement.api.v2alpha.RequisitionKt.DuchyEntryKt.value
import org.wfanet.measurement.api.v2alpha.RequisitionKt.duchyEntry
import org.wfanet.measurement.api.v2alpha.RequisitionKt.refusal
Expand Down Expand Up @@ -435,6 +436,9 @@ private fun DuchyValue.toDuchyEntryValue(
DuchyValue.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> {
reachOnlyLiquidLegionsV2 = source.reachOnlyLiquidLegionsV2.toDuchyEntryLlV2(apiVersion)
}
DuchyValue.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> {
honestMajorityShareShuffle = source.honestMajorityShareShuffle.toDuchyEntryHmss(apiVersion)
}
DuchyValue.ProtocolCase.PROTOCOL_NOT_SET -> error("protocol not set")
}
}
Expand All @@ -444,7 +448,7 @@ private fun ComputationParticipant.LiquidLegionsV2Details.toDuchyEntryLlV2(
apiVersion: Version
): DuchyEntry.LiquidLegionsV2 {
val source = this
return liquidLegionsV2 {
return DuchyEntryKt.liquidLegionsV2 {
elGamalPublicKey = signedMessage {
setMessage(
any {
Expand All @@ -461,6 +465,27 @@ private fun ComputationParticipant.LiquidLegionsV2Details.toDuchyEntryLlV2(
}
}

private fun ComputationParticipant.HonestMajorityShareShuffleDetails.toDuchyEntryHmss(
apiVersion: Version
): DuchyEntry.HonestMajorityShareShuffle {
val source = this
return DuchyEntryKt.honestMajorityShareShuffle {
publicKey = signedMessage {
setMessage(
any {
value = source.tinkPublicKey
typeUrl =
when (apiVersion) {
Version.V2_ALPHA -> ProtoReflection.getTypeUrl(EncryptedMessage.getDescriptor())
}
}
)
signature = source.tinkPublicKeySignature
signatureAlgorithmOid = source.tinkPublicKeySignatureAlgorithmOid
}
}
}

/** Converts an internal duchy map entry to a public [DuchyEntry]. */
private fun Map.Entry<String, DuchyValue>.toDuchyEntry(apiVersion: Version): DuchyEntry {
val mapEntry = this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -543,14 +543,14 @@ abstract class CertificatesServiceTest<T : CertificatesCoroutineImplBase> {
val dataProvider = population.createDataProvider(dataProvidersService)

val measurementOne =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
"measurement one",
dataProvider,
)
val measurementTwo =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
"measurement two",
Expand Down Expand Up @@ -674,17 +674,9 @@ abstract class CertificatesServiceTest<T : CertificatesCoroutineImplBase> {
population.createMeasurementConsumer(measurementConsumersService, accountsService)

val measurementOne =
population.createComputedMeasurement(
measurementsService,
measurementConsumer,
"measurement one",
)
population.createLlv2Measurement(measurementsService, measurementConsumer, "measurement one")
val measurementTwo =
population.createComputedMeasurement(
measurementsService,
measurementConsumer,
"measurement two",
)
population.createLlv2Measurement(measurementsService, measurementConsumer, "measurement two")
measurementsService.cancelMeasurement(
cancelMeasurementRequest {
externalMeasurementConsumerId = measurementTwo.externalMeasurementConsumerId
Expand Down Expand Up @@ -796,17 +788,9 @@ abstract class CertificatesServiceTest<T : CertificatesCoroutineImplBase> {
val measurementConsumer =
population.createMeasurementConsumer(measurementConsumersService, accountsService)
val measurementOne =
population.createComputedMeasurement(
measurementsService,
measurementConsumer,
"measurement one",
)
population.createLlv2Measurement(measurementsService, measurementConsumer, "measurement one")
val measurementTwo =
population.createComputedMeasurement(
measurementsService,
measurementConsumer,
"measurement two",
)
population.createLlv2Measurement(measurementsService, measurementConsumer, "measurement two")
measurementsService.cancelMeasurement(
cancelMeasurementRequest {
externalMeasurementConsumerId = measurementTwo.externalMeasurementConsumerId
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import org.wfanet.measurement.internal.kingdom.Certificate
import org.wfanet.measurement.internal.kingdom.CertificatesGrpcKt.CertificatesCoroutineImplBase
import org.wfanet.measurement.internal.kingdom.ComputationParticipant
import org.wfanet.measurement.internal.kingdom.ComputationParticipantKt.details
import org.wfanet.measurement.internal.kingdom.ComputationParticipantKt.honestMajorityShareShuffleDetails
import org.wfanet.measurement.internal.kingdom.ComputationParticipantKt.liquidLegionsV2Details
import org.wfanet.measurement.internal.kingdom.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineImplBase
import org.wfanet.measurement.internal.kingdom.DataProvidersGrpcKt.DataProvidersCoroutineImplBase
Expand All @@ -62,6 +63,7 @@ import org.wfanet.measurement.internal.kingdom.revokeCertificateRequest
import org.wfanet.measurement.internal.kingdom.setMeasurementResultRequest
import org.wfanet.measurement.internal.kingdom.setParticipantRequisitionParamsRequest
import org.wfanet.measurement.internal.kingdom.streamRequisitionsRequest
import org.wfanet.measurement.kingdom.deploy.common.HmssProtocolConfig
import org.wfanet.measurement.kingdom.deploy.common.Llv2ProtocolConfig
import org.wfanet.measurement.kingdom.deploy.common.testing.DuchyIdSetter
import org.wfanet.measurement.kingdom.service.internal.testing.Population.Companion.DUCHIES
Expand All @@ -72,6 +74,9 @@ private const val PROVIDED_MEASUREMENT_ID = "measurement"
private val EL_GAMAL_PUBLIC_KEY = ByteString.copyFromUtf8("This is an ElGamal Public Key.")
private val EL_GAMAL_PUBLIC_KEY_SIGNATURE =
ByteString.copyFromUtf8("This is an ElGamal Public Key signature.")
private val TINK_PUBLIC_KEY = ByteString.copyFromUtf8("This is an Tink Public Key.")
private val TINK_PUBLIC_KEY_SIGNATURE =
ByteString.copyFromUtf8("This is an Tink Public Key signature.")

@RunWith(JUnit4::class)
abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCoroutineImplBase> {
Expand Down Expand Up @@ -147,7 +152,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
val dataProvider = population.createDataProvider(dataProvidersService)

val measurement =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
PROVIDED_MEASUREMENT_ID,
Expand Down Expand Up @@ -181,7 +186,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
measurementConsumer.certificate.externalCertificateId
val dataProvider = population.createDataProvider(dataProvidersService)

population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
PROVIDED_MEASUREMENT_ID,
Expand Down Expand Up @@ -216,7 +221,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
val dataProvider = population.createDataProvider(dataProvidersService)

val measurement =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
PROVIDED_MEASUREMENT_ID,
Expand Down Expand Up @@ -249,7 +254,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
val dataProvider = population.createDataProvider(dataProvidersService)

val measurement =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
PROVIDED_MEASUREMENT_ID,
Expand Down Expand Up @@ -290,7 +295,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
val dataProvider = population.createDataProvider(dataProvidersService)

val measurement =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
PROVIDED_MEASUREMENT_ID,
Expand Down Expand Up @@ -330,7 +335,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
val dataProvider = population.createDataProvider(dataProvidersService)

val measurement =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
PROVIDED_MEASUREMENT_ID,
Expand Down Expand Up @@ -371,7 +376,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
val dataProvider = population.createDataProvider(dataProvidersService)

val measurement =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
PROVIDED_MEASUREMENT_ID,
Expand Down Expand Up @@ -415,7 +420,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
val dataProvider = population.createDataProvider(dataProvidersService)

val measurement =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
PROVIDED_MEASUREMENT_ID,
Expand Down Expand Up @@ -459,6 +464,59 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
assertThat(nonUpdatedMeasurement.state).isEqualTo(Measurement.State.PENDING_REQUISITION_PARAMS)
}

@Test
fun `setParticipantRequisitionParams succeeds for non-final HMSS Duchy`() = runBlocking {
createDuchyCertificates()
val measurementConsumer =
population.createMeasurementConsumer(measurementConsumersService, accountsService)
val externalMeasurementConsumerId = measurementConsumer.externalMeasurementConsumerId
val dataProvider = population.createDataProvider(dataProvidersService)

val measurement =
population.createHmssMeasurement(
measurementsService,
measurementConsumer,
PROVIDED_MEASUREMENT_ID,
dataProvider,
)

val request = setParticipantRequisitionParamsRequest {
externalComputationId = measurement.externalComputationId
externalDuchyId = DUCHIES[0].externalDuchyId
externalDuchyCertificateId =
duchyCertificates[DUCHIES[0].externalDuchyId]!!.externalCertificateId
honestMajorityShareShuffle = honestMajorityShareShuffleDetails {
tinkPublicKey = TINK_PUBLIC_KEY
tinkPublicKeySignature = TINK_PUBLIC_KEY_SIGNATURE
}
}

val expectedComputationParticipant = computationParticipant {
state = ComputationParticipant.State.REQUISITION_PARAMS_SET
this.externalMeasurementConsumerId = externalMeasurementConsumerId
externalMeasurementId = measurement.externalMeasurementId
externalComputationId = measurement.externalComputationId
externalDuchyId = DUCHIES[0].externalDuchyId
details = details { honestMajorityShareShuffle = request.honestMajorityShareShuffle }
apiVersion = measurement.details.apiVersion
duchyCertificate = duchyCertificates[DUCHIES[0].externalDuchyId]!!
}

val computationParticipant =
computationParticipantsService.setParticipantRequisitionParams(request)
assertThat(computationParticipant)
.ignoringFields(ComputationParticipant.UPDATE_TIME_FIELD_NUMBER)
.isEqualTo(expectedComputationParticipant)

val nonUpdatedMeasurement =
measurementsService.getMeasurementByComputationId(
GetMeasurementByComputationIdRequest.newBuilder()
.apply { externalComputationId = measurement.externalComputationId }
.build()
)
assertThat(nonUpdatedMeasurement.state).isEqualTo(Measurement.State.PENDING_REQUISITION_PARAMS)
}

@Test
fun `setParticipantRequisitionParams for final Duchy updates Measurement and Requisition state`() {
runBlocking {
Expand All @@ -468,7 +526,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
val dataProvider = population.createDataProvider(dataProvidersService)
val externalComputationId =
population
.createComputedMeasurement(
.createLlv2Measurement(
measurementsService,
measurementConsumer,
PROVIDED_MEASUREMENT_ID,
Expand Down Expand Up @@ -543,7 +601,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
fun `confirmComputationParticipant succeeds for non-last duchy`(): Unit = runBlocking {
createDuchyCertificates()
val measurement =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
population.createMeasurementConsumer(measurementConsumersService, accountsService),
"measurement",
Expand Down Expand Up @@ -630,7 +688,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
fun `confirmComputationParticipant succeeds for last duchy`() = runBlocking {
createDuchyCertificates()
val measurement =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
population.createMeasurementConsumer(measurementConsumersService, accountsService),
"measurement",
Expand Down Expand Up @@ -706,7 +764,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
val duchies = DUCHIES.dropLast(1)
createDuchyCertificates(duchies.map { it.externalDuchyId })
val measurement =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
population.createMeasurementConsumer(measurementConsumersService, accountsService),
"measurement",
Expand Down Expand Up @@ -790,7 +848,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
val dataProvider = population.createDataProvider(dataProvidersService)

val measurement =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
"measurement 1",
Expand Down Expand Up @@ -842,7 +900,7 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
val dataProvider = population.createDataProvider(dataProvidersService)

val measurement =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
"measurement 1",
Expand Down Expand Up @@ -904,6 +962,14 @@ abstract class ComputationParticipantsServiceTest<T : ComputationParticipantsCor
setOf(Population.AGGREGATOR_DUCHY.externalDuchyId),
2,
)
HmssProtocolConfig.setForTest(
ProtocolConfig.HonestMajorityShareShuffle.getDefaultInstance(),
setOf(
Population.AGGREGATOR_DUCHY.externalDuchyId,
Population.WORKER1_DUCHY.externalDuchyId,
Population.WORKER2_DUCHY.externalDuchyId,
),
)
}
}
}
Loading

0 comments on commit e365562

Please sign in to comment.