Skip to content

Commit

Permalink
Update SetParticipantParams for HMSS protocol (#1470)
Browse files Browse the repository at this point in the history
  • Loading branch information
renjiezh authored Feb 16, 2024
1 parent bd00d25 commit 321b373
Show file tree
Hide file tree
Showing 17 changed files with 391 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ object HmssProtocolConfig {
lateinit var protocolConfig: ProtocolConfig.HonestMajorityShareShuffle
private set

/** A set of external duchy ids that the first one must be corresponding to the aggregator. */
/**
* Set of external IDs of required Duchies, where the first entry must correspond to the Duchy in
* the aggregator role.
*/
lateinit var requiredExternalDuchyIds: Set<String>
private set

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ class RequisitionReader : BaseSpannerReader<RequisitionReader.Result>() {
ComputationParticipant.Details.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> {
reachOnlyLiquidLegionsV2 = participantDetails.reachOnlyLiquidLegionsV2
}
ComputationParticipant.Details.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> {
honestMajorityShareShuffle = participantDetails.honestMajorityShareShuffle
}
// Protocol may only be set after computation participant sets requisition params.
ComputationParticipant.Details.ProtocolCase.PROTOCOL_NOT_SET -> Unit
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ class CreateMeasurements(private val requests: List<CreateMeasurementRequest>) :
)
}
ProtocolConfig.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> {
// For each EDP, insert a Requisition for each non-aggregator Duchy.
insertRequisitions(
measurementConsumerId,
measurementId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,20 @@ class SetParticipantRequisitionParams(private val request: SetParticipantRequisi

val participantDetails =
computationParticipant.details.copy {
if (request.hasLiquidLegionsV2()) {
liquidLegionsV2 = request.liquidLegionsV2
} else if (request.hasReachOnlyLiquidLegionsV2()) {
reachOnlyLiquidLegionsV2 = request.reachOnlyLiquidLegionsV2
@Suppress("WHEN_ENUM_CAN_BE_NULL_IN_JAVA") // Protobuf enum fields cannot be null.
when (request.protocolCase) {
SetParticipantRequisitionParamsRequest.ProtocolCase.LIQUID_LEGIONS_V2 -> {
liquidLegionsV2 = request.liquidLegionsV2
}
SetParticipantRequisitionParamsRequest.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> {
reachOnlyLiquidLegionsV2 = request.reachOnlyLiquidLegionsV2
}
SetParticipantRequisitionParamsRequest.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> {
honestMajorityShareShuffle = request.honestMajorityShareShuffle
}
SetParticipantRequisitionParamsRequest.ProtocolCase.PROTOCOL_NOT_SET -> {
error("Unspecified protocol case in SetParticipantRequisitionParamsRequest.")
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.wfanet.measurement.api.v2alpha.DataProviderCertificateKey
import org.wfanet.measurement.api.v2alpha.DataProviderKey
import org.wfanet.measurement.api.v2alpha.DuchyCertificateKey
import org.wfanet.measurement.api.v2alpha.ElGamalPublicKey
import org.wfanet.measurement.api.v2alpha.EncryptedMessage
import org.wfanet.measurement.api.v2alpha.EncryptionPublicKey
import org.wfanet.measurement.api.v2alpha.FulfillDirectRequisitionRequest
import org.wfanet.measurement.api.v2alpha.FulfillDirectRequisitionResponse
Expand All @@ -44,7 +45,7 @@ import org.wfanet.measurement.api.v2alpha.Requisition
import org.wfanet.measurement.api.v2alpha.Requisition.DuchyEntry
import org.wfanet.measurement.api.v2alpha.Requisition.Refusal
import org.wfanet.measurement.api.v2alpha.Requisition.State
import org.wfanet.measurement.api.v2alpha.RequisitionKt.DuchyEntryKt.liquidLegionsV2
import org.wfanet.measurement.api.v2alpha.RequisitionKt.DuchyEntryKt
import org.wfanet.measurement.api.v2alpha.RequisitionKt.DuchyEntryKt.value
import org.wfanet.measurement.api.v2alpha.RequisitionKt.duchyEntry
import org.wfanet.measurement.api.v2alpha.RequisitionKt.refusal
Expand Down Expand Up @@ -435,6 +436,9 @@ private fun DuchyValue.toDuchyEntryValue(
DuchyValue.ProtocolCase.REACH_ONLY_LIQUID_LEGIONS_V2 -> {
reachOnlyLiquidLegionsV2 = source.reachOnlyLiquidLegionsV2.toDuchyEntryLlV2(apiVersion)
}
DuchyValue.ProtocolCase.HONEST_MAJORITY_SHARE_SHUFFLE -> {
honestMajorityShareShuffle = source.honestMajorityShareShuffle.toDuchyEntryHmss(apiVersion)
}
DuchyValue.ProtocolCase.PROTOCOL_NOT_SET -> error("protocol not set")
}
}
Expand All @@ -444,7 +448,7 @@ private fun ComputationParticipant.LiquidLegionsV2Details.toDuchyEntryLlV2(
apiVersion: Version
): DuchyEntry.LiquidLegionsV2 {
val source = this
return liquidLegionsV2 {
return DuchyEntryKt.liquidLegionsV2 {
elGamalPublicKey = signedMessage {
setMessage(
any {
Expand All @@ -461,6 +465,27 @@ private fun ComputationParticipant.LiquidLegionsV2Details.toDuchyEntryLlV2(
}
}

private fun ComputationParticipant.HonestMajorityShareShuffleDetails.toDuchyEntryHmss(
apiVersion: Version
): DuchyEntry.HonestMajorityShareShuffle {
val source = this
return DuchyEntryKt.honestMajorityShareShuffle {
publicKey = signedMessage {
setMessage(
any {
value = source.tinkPublicKey
typeUrl =
when (apiVersion) {
Version.V2_ALPHA -> ProtoReflection.getTypeUrl(EncryptedMessage.getDescriptor())
}
}
)
signature = source.tinkPublicKeySignature
signatureAlgorithmOid = source.tinkPublicKeySignatureAlgorithmOid
}
}
}

/** Converts an internal duchy map entry to a public [DuchyEntry]. */
private fun Map.Entry<String, DuchyValue>.toDuchyEntry(apiVersion: Version): DuchyEntry {
val mapEntry = this
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import org.wfanet.measurement.internal.kingdom.Certificate
import org.wfanet.measurement.internal.kingdom.CertificateKt
import org.wfanet.measurement.internal.kingdom.CertificateKt.details
import org.wfanet.measurement.internal.kingdom.CertificatesGrpcKt.CertificatesCoroutineImplBase
import org.wfanet.measurement.internal.kingdom.ComputationParticipant
import org.wfanet.measurement.internal.kingdom.ComputationParticipantsGrpcKt.ComputationParticipantsCoroutineImplBase
import org.wfanet.measurement.internal.kingdom.DataProvidersGrpcKt.DataProvidersCoroutineImplBase
import org.wfanet.measurement.internal.kingdom.DuchyProtocolConfig
Expand Down Expand Up @@ -543,14 +544,14 @@ abstract class CertificatesServiceTest<T : CertificatesCoroutineImplBase> {
val dataProvider = population.createDataProvider(dataProvidersService)

val measurementOne =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
"measurement one",
dataProvider,
)
val measurementTwo =
population.createComputedMeasurement(
population.createLlv2Measurement(
measurementsService,
measurementConsumer,
"measurement two",
Expand Down Expand Up @@ -674,17 +675,9 @@ abstract class CertificatesServiceTest<T : CertificatesCoroutineImplBase> {
population.createMeasurementConsumer(measurementConsumersService, accountsService)

val measurementOne =
population.createComputedMeasurement(
measurementsService,
measurementConsumer,
"measurement one",
)
population.createLlv2Measurement(measurementsService, measurementConsumer, "measurement one")
val measurementTwo =
population.createComputedMeasurement(
measurementsService,
measurementConsumer,
"measurement two",
)
population.createLlv2Measurement(measurementsService, measurementConsumer, "measurement two")
measurementsService.cancelMeasurement(
cancelMeasurementRequest {
externalMeasurementConsumerId = measurementTwo.externalMeasurementConsumerId
Expand Down Expand Up @@ -796,17 +789,9 @@ abstract class CertificatesServiceTest<T : CertificatesCoroutineImplBase> {
val measurementConsumer =
population.createMeasurementConsumer(measurementConsumersService, accountsService)
val measurementOne =
population.createComputedMeasurement(
measurementsService,
measurementConsumer,
"measurement one",
)
population.createLlv2Measurement(measurementsService, measurementConsumer, "measurement one")
val measurementTwo =
population.createComputedMeasurement(
measurementsService,
measurementConsumer,
"measurement two",
)
population.createLlv2Measurement(measurementsService, measurementConsumer, "measurement two")
measurementsService.cancelMeasurement(
cancelMeasurementRequest {
externalMeasurementConsumerId = measurementTwo.externalMeasurementConsumerId
Expand All @@ -828,6 +813,7 @@ abstract class CertificatesServiceTest<T : CertificatesCoroutineImplBase> {
this.externalDuchyId = externalDuchyId
externalDuchyCertificateId = certificate.externalCertificateId
externalComputationId = measurementOne.externalComputationId
liquidLegionsV2 = ComputationParticipant.LiquidLegionsV2Details.getDefaultInstance()
}
)
val request = revokeCertificateRequest {
Expand Down
Loading

0 comments on commit 321b373

Please sign in to comment.